From 9b3d373ee02e4935068a3404750e4eb58bb185e9 Mon Sep 17 00:00:00 2001 From: Trianta <56975502+Trimutex@users.noreply.github.com> Date: Sun, 12 Nov 2023 05:16:44 -0600 Subject: [PATCH] Main part of filter working. Math seems to be wrong multiplying floats --- src/filter.cpp | 150 ++++++++++++++++++++++++++++++++++++++++++------- src/filter.hpp | 26 ++++++--- src/main.cpp | 8 ++- 3 files changed, 154 insertions(+), 30 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index bb421f3..4d36a2f 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -1,10 +1,14 @@ #include "filter.hpp" +#include +#include #include #include #include +#include +#include -SMSMessage::SMSMessage(MessageType given_type, std::string given_message) { - this->actual_type = given_type; +SMSMessage::SMSMessage(bool given_type, std::string given_message) { + this->is_ham = given_type; this->message = given_message; } @@ -18,22 +22,41 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) { } // Read in messages - MessageType temp_type; - std::string temp_message; + bool is_ham_temp; + std::string temp_message, token; + std::stringstream full_message; while (!input_messages.eof()) { - input_messages.get(&temp_message[0], 10, '\t'); - if (temp_message == "spam") { temp_type = kSpam; } - else if (temp_message == "ham") { temp_type = kHam; } + std::getline(input_messages, temp_message); + full_message.str(temp_message); + std::getline(full_message, token, '\t'); + if (token == "ham") { is_ham_temp = true; } + else if (token == "spam") { is_ham_temp = false; } else { std::cerr << "[SMSMessageFilter - Warning]" << " Could not determine message type" << std::endl; - std::cerr << "\t Contains: <" << temp_message << ">" << std::endl; + std::cerr << "\t Contains: <" << token << ">" << std::endl; + temp_message.clear(); + full_message.clear(); + continue; // Probably a bad line cut } - input_messages >> temp_message; + while (std::getline(full_message, token, ' ')) { + token = SanitizeToken(token); + if (token.empty()) { continue; } + if (probability_dictionary[token] == 0) { + probability_dictionary[token] = 0.5; + } + if (is_ham_temp) { + probability_dictionary[token] += probability_dictionary[token] * 0.1; + } else { + probability_dictionary[token] -= probability_dictionary[token] * 0.1; + } + } + temp_message.clear(); + full_message.clear(); } } -void SMSMessageFilter::Filter(std::string file_name) { +void SMSMessageFilter::Prepare(std::string file_name) { // Open file std::ifstream input_messages(file_name); if (!input_messages.is_open()) { @@ -42,41 +65,130 @@ void SMSMessageFilter::Filter(std::string file_name) { } // Read in messages - MessageType temp_type; - std::string temp_message; + bool is_ham_temp; + std::string temp_message, token; + std::stringstream full_message; while (!input_messages.eof()) { - input_messages.get(&temp_message[0], 10, '\t'); - if (temp_message == "spam") { temp_type = kSpam; } - else if (temp_message == "ham") { temp_type = kHam; } + std::getline(input_messages, temp_message); + full_message.str(temp_message); + std::getline(full_message, token, '\t'); + if (token == "ham") { is_ham_temp = true; } + else if (token == "spam") { is_ham_temp = false; } else { std::cerr << "[SMSMessageFilter - Warning]" << " Could not determine message type" << std::endl; - std::cerr << "\t Contains: <" << temp_message << ">" << std::endl; + std::cerr << "\t Contains: <" << token << ">" << std::endl; + temp_message.clear(); + full_message.clear(); + continue; // Probably a bad line cut } - input_messages >> temp_message; - filtered_messages.emplace_back(temp_type, temp_message); + full_message.ignore('\t'); + filtered_messages.emplace_back(is_ham_temp, full_message.str()); + temp_message.clear(); + full_message.clear(); } } -void SMSMessageFilter::PrintReport(void) { +void SMSMessageFilter::Filter(void) { + double type_probability = 0.5; + std::string token; + std::stringstream full_message; + for (int i = 0; i < filtered_messages.size(); ++i) { + full_message.str(filtered_messages[i].message); + while (std::getline(full_message, token, ' ')) { + token = SanitizeToken(token); + if (probability_dictionary[token] == 0) { + probability_dictionary[token] = 0.5; + } + type_probability = probability_dictionary[token] * type_probability; + } + if (type_probability <= sentence_probability_ham) { + filtered_messages[i].is_ham_filter = true; + } else { filtered_messages[i].is_ham_filter = false; } + type_probability = 0.5; + } +} + +void SMSMessageFilter::Report(void) { + PrintReport(GenerateReport()); +} + +ReportData SMSMessageFilter::GenerateReport(void) { + double true_ham = 0; + double true_spam = 0; + double false_ham = 0; + double false_spam = 0; + for (SMSMessage message : filtered_messages) { + // Get total count + if (!(message.is_ham ^ message.is_ham_filter)) { + if (message.is_ham) { ++true_ham; } + else { ++true_spam; } + } else { + if (message.is_ham) { ++false_ham; } + else { ++false_spam; } + + } + } + std::cout << "[SMSMessageFilter - Info] Ham barrier: "; + std::cout << sentence_probability_ham << std::endl; + std::cout << "[SMSMessageFilter - Info] True ham: "; + std::cout << true_ham << std::endl; + std::cout << "[SMSMessageFilter - Info] True spam: "; + std::cout << true_spam << std::endl; + std::cout << "[SMSMessageFilter - Info] False ham: "; + std::cout << false_ham << std::endl; + std::cout << "[SMSMessageFilter - Info] False spam: "; + std::cout << false_spam << std::endl; + + // Calculate report data + ReportData new_report; + new_report.spam_precision = (true_ham) / (true_ham + false_ham); + new_report.spam_recall = (true_ham) / (true_ham + false_spam); + new_report.ham_precision = (true_spam) / (true_spam + false_spam); + new_report.ham_recall = (true_spam) / (true_spam + false_ham); + new_report.spam_f_score = 2.0 * (new_report.spam_precision * new_report.spam_recall) / (new_report.spam_precision + new_report.spam_recall); + new_report.ham_f_score = 2.0 * (new_report.ham_precision * new_report.ham_recall) / (new_report.ham_precision + new_report.ham_recall); + new_report.accuracy = (new_report.spam_recall + new_report.ham_recall) / 2.0; + return new_report; +} + +void SMSMessageFilter::PrintReport(ReportData report) { // Spam precision: (true positives) / (true positives + false positives) std::cout << "[SMSMessageFilter - Report] Spam precision: "; + std::cout << report.spam_precision << std::endl; // Spam recall: (true positives) / (true positives + false negatives) std::cout << "[SMSMessageFilter - Report] Spam recall: "; + std::cout << report.spam_recall << std::endl; // Ham precision: (true negatives) / (true negatives + false negatives) std::cout << "[SMSMessageFilter - Report] Ham precision: "; + std::cout << report.ham_precision << std::endl; // Ham recall: (true negatives) / (true negatives + false positives) std::cout << "[SMSMessageFilter - Report] Ham recall: "; + std::cout << report.ham_recall << std::endl; // Spam F-Score: 2* (spam precision * spam recall) / (spam precision + spam recall) std::cout << "[SMSMessageFilter - Report] Spam F-Score: "; + std::cout << report.spam_f_score << std::endl; // Ham F-Score: 2* (ham precision * ham recall) / (ham precision + ham recall) std::cout << "[SMSMessageFilter - Report] Ham F-Score: "; + std::cout << report.ham_f_score << std::endl; // Accuracy: (spam recall + ham recall) / 2 std::cout << "[SMSMessageFilter - Report] Accuracy: "; + std::cout << report.accuracy << std::endl; +} + +std::string SanitizeToken(std::string token) { + for (int i = 0; i < token.size(); ) { + if (token[i] < 'A' || token[i] > 'Z' + && token[i] < 'a' || token[i] > 'z') { token.erase(i, 1); } + else { ++i; } + } + std::transform(token.begin(), token.end(), token.begin(), + [](unsigned char c){ return std::tolower(c); }); + return token; } diff --git a/src/filter.hpp b/src/filter.hpp index 9555090..abc1f9f 100644 --- a/src/filter.hpp +++ b/src/filter.hpp @@ -5,17 +5,21 @@ #include #include -enum MessageType { - kSpam = 0, - kHam = 1, - kUnknown = 2, +struct ReportData { + double spam_precision; + double spam_recall; + double ham_precision; + double ham_recall; + double spam_f_score; + double ham_f_score; + double accuracy; }; struct SMSMessage { - SMSMessage(MessageType given_type, std::string given_message); - MessageType actual_type; + SMSMessage(bool given_type, std::string given_message); + bool is_ham; std::string message; - MessageType filter_type; + bool is_ham_filter; }; class SMSMessageFilter { @@ -23,15 +27,19 @@ public: SMSMessageFilter(void) = default; ~SMSMessageFilter(void) = default; void GenerateProbability(std::string file_name); - void Filter(std::string file_name); - void PrintReport(void); + void Prepare(std::string file_name); + void Filter(void); + void Report(void); private: double sentence_probability_ham = 0.5; // Spam is 1 - sentence_probability_ham std::map probability_dictionary; std::vector filtered_messages; + ReportData GenerateReport(void); + void PrintReport(ReportData report); }; +std::string SanitizeToken(std::string token); #endif // !FILTER_HPP diff --git a/src/main.cpp b/src/main.cpp index 29b2fa0..558150a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,6 +1,10 @@ -#include +#include "filter.hpp" int main(void) { - std::cout << "Hello world" << std::endl; + SMSMessageFilter single_filter; + single_filter.GenerateProbability("test/SMSProbabilityGeneration.txt"); + single_filter.Prepare("test/SMSFilterTest.txt"); + single_filter.Filter(); + single_filter.Report(); return 0; }