From be01e1891669acc2220c48b36ea7ffc24bebb653 Mon Sep 17 00:00:00 2001 From: Trianta <56975502+Trimutex@users.noreply.github.com> Date: Sun, 12 Nov 2023 19:01:59 -0600 Subject: [PATCH] Fixed probabilities not adjusting mathematically --- src/filter.cpp | 35 +++++++++++++++++++++-------------- src/filter.hpp | 7 ++++++- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index 4d36a2f..be36846 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -42,13 +42,13 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) { while (std::getline(full_message, token, ' ')) { token = SanitizeToken(token); if (token.empty()) { continue; } - if (probability_dictionary[token] == 0) { - probability_dictionary[token] = 0.5; - } if (is_ham_temp) { - probability_dictionary[token] += probability_dictionary[token] * 0.1; + probability_dictionary[token].value += probability_dictionary[token].value * 0.0000000001; + if (probability_dictionary[token].value > 1.) 
{ + probability_dictionary[token].value = 1.; + } } else { - probability_dictionary[token] -= probability_dictionary[token] * 0.1; + probability_dictionary[token].value -= probability_dictionary[token].value * 0.0000000001; } } temp_message.clear(); @@ -82,7 +82,6 @@ void SMSMessageFilter::Prepare(std::string file_name) { full_message.clear(); continue; // Probably a bad line cut } - full_message.ignore('\t'); filtered_messages.emplace_back(is_ham_temp, full_message.str()); temp_message.clear(); full_message.clear(); @@ -97,15 +96,18 @@ void SMSMessageFilter::Filter(void) { full_message.str(filtered_messages[i].message); while (std::getline(full_message, token, ' ')) { token = SanitizeToken(token); - if (probability_dictionary[token] == 0) { - probability_dictionary[token] = 0.5; - } - type_probability = probability_dictionary[token] * type_probability; + type_probability = probability_dictionary[token].value * type_probability; } + double final_probability; + final_probability = (1. - sentence_probability_ham) * type_probability; + final_probability = final_probability / (final_probability + ((1. 
- type_probability) * sentence_probability_ham)); if (type_probability <= sentence_probability_ham) { filtered_messages[i].is_ham_filter = true; } else { filtered_messages[i].is_ham_filter = false; } + std::cout << "[SMSMessageFilter - Info] Final probability of " + << i << ": " << final_probability << std::endl; type_probability = 0.5; + full_message.clear(); } } @@ -114,10 +116,10 @@ void SMSMessageFilter::Report(void) { } ReportData SMSMessageFilter::GenerateReport(void) { - double true_ham = 0; - double true_spam = 0; - double false_ham = 0; - double false_spam = 0; + double true_ham = 0.; + double true_spam = 0.; + double false_ham = 0.; + double false_spam = 0.; for (SMSMessage message : filtered_messages) { // Get total count if (!(message.is_ham ^ message.is_ham_filter)) { @@ -129,6 +131,7 @@ ReportData SMSMessageFilter::GenerateReport(void) { } } + std::cout << std::endl; std::cout << "[SMSMessageFilter - Info] Ham barrier: "; std::cout << sentence_probability_ham << std::endl; std::cout << "[SMSMessageFilter - Info] True ham: "; @@ -139,6 +142,7 @@ ReportData SMSMessageFilter::GenerateReport(void) { std::cout << false_ham << std::endl; std::cout << "[SMSMessageFilter - Info] False spam: "; std::cout << false_spam << std::endl; + std::cout << std::endl; // Calculate report data ReportData new_report; @@ -153,6 +157,8 @@ ReportData SMSMessageFilter::GenerateReport(void) { } void SMSMessageFilter::PrintReport(ReportData report) { + std::cout << std::endl; + std::cout << "============ [SMSMessageFilter - Report - Start] ============" << std::endl; // Spam precision: (true positives) / (true positives + false positives) std::cout << "[SMSMessageFilter - Report] Spam precision: "; std::cout << report.spam_precision << std::endl; @@ -180,6 +186,7 @@ void SMSMessageFilter::PrintReport(ReportData report) { // Accuracy: (spam recall + ham recall) / 2 std::cout << "[SMSMessageFilter - Report] Accuracy: "; std::cout << report.accuracy << std::endl; + std::cout << 
"============ [SMSMessageFilter - Report - End] ============" << std::endl; } std::string SanitizeToken(std::string token) { diff --git a/src/filter.hpp b/src/filter.hpp index abc1f9f..31a1579 100644 --- a/src/filter.hpp +++ b/src/filter.hpp @@ -5,6 +5,11 @@ #include #include +// std::map::operator[] value-initializes a missing key (0.0 for a double) and takes no custom default, so wrap the double in a struct that defaults to 0.5 +struct DoubleDefaultedToHalf { + double value = 0.5; +}; + struct ReportData { double spam_precision; double spam_recall; @@ -33,7 +38,7 @@ public: private: double sentence_probability_ham = 0.5; // Spam is 1 - sentence_probability_ham - std::map probability_dictionary; + std::map probability_dictionary; std::vector filtered_messages; ReportData GenerateReport(void); void PrintReport(ReportData report);