From cb83e074dd70317e9090fe043829bfbfd1259ee1 Mon Sep 17 00:00:00 2001 From: Trianta <56975502+Trimutex@users.noreply.github.com> Date: Sun, 12 Nov 2023 20:28:01 -0600 Subject: [PATCH] Added argument input for files --- src/filter.cpp | 90 ++++++++++++++++++++++++++++++++++++++------------ src/filter.hpp | 12 +++++-- src/main.cpp | 9 +++-- 3 files changed, 84 insertions(+), 27 deletions(-) diff --git a/src/filter.cpp b/src/filter.cpp index be36846..946f74d 100644 --- a/src/filter.cpp +++ b/src/filter.cpp @@ -13,11 +13,12 @@ SMSMessage::SMSMessage(bool given_type, std::string given_message) { } // Takes in a file name from which it will generate a probability filter -void SMSMessageFilter::GenerateProbability(std::string file_name) { +void SMSMessageFilter::GenerateProbability() { // Open file - std::ifstream input_messages(file_name); + std::ifstream input_messages(generation_file_path); if (!input_messages.is_open()) { - std::cerr << "Error opening '" << file_name << "'" << std::endl; + std::cerr << "[SMSMessageFilter - Error - Generator] "; + std::cerr << "Error opening '" << generation_file_path << "'" << std::endl; std::exit(EXIT_FAILURE); } @@ -31,10 +32,16 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) { std::getline(full_message, token, '\t'); if (token == "ham") { is_ham_temp = true; } else if (token == "spam") { is_ham_temp = false; } + else if (token.empty()) { + temp_message.clear(); + full_message.clear(); + continue; + } else { - std::cerr << "[SMSMessageFilter - Warning]" << - " Could not determine message type" << std::endl; - std::cerr << "\t Contains: <" << token << ">" << std::endl; + std::cerr << "[SMSMessageFilter - Warning]" + << " Could not determine message type (probably bad cut)" + << " ignoring..." << std::endl; + std::cerr << "\t Contains: '" << token << "'" << std::endl; temp_message.clear(); full_message.clear(); continue; // Probably a bad line cut @@ -43,12 +50,15 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) { token = SanitizeToken(token); if (token.empty()) { continue; } if (is_ham_temp) { - probability_dictionary[token].value += probability_dictionary[token].value * 0.0000000001; + probability_dictionary[token].value += probability_dictionary[token].value * 0.99; if (probability_dictionary[token].value > 1.) { probability_dictionary[token].value = 1.; } } else { - probability_dictionary[token].value -= probability_dictionary[token].value * 0.0000000001; + probability_dictionary[token].value -= probability_dictionary[token].value * 0.000000001; + if (probability_dictionary[token].value <= 0.) { + probability_dictionary[token].value = 0.0000000001; + } } } temp_message.clear(); @@ -56,11 +66,12 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) { } } -void SMSMessageFilter::Prepare(std::string file_name) { +void SMSMessageFilter::Prepare() { // Open file - std::ifstream input_messages(file_name); + std::ifstream input_messages(filter_file_path); if (!input_messages.is_open()) { - std::cerr << "Error opening '" << file_name << "'" << std::endl; + std::cerr << "[SMSMessageFilter - Error - Filter] "; + std::cerr << "Error opening '" << filter_file_path << "'" << std::endl; std::exit(EXIT_FAILURE); } @@ -74,10 +85,16 @@ void SMSMessageFilter::Prepare(std::string file_name) { std::getline(full_message, token, '\t'); if (token == "ham") { is_ham_temp = true; } else if (token == "spam") { is_ham_temp = false; } + else if (token.empty()) { + temp_message.clear(); + full_message.clear(); + continue; + } else { - std::cerr << "[SMSMessageFilter - Warning]" << - " Could not determine message type" << std::endl; - std::cerr << "\t Contains: <" << token << ">" << std::endl; + std::cerr << "[SMSMessageFilter - Warning]" + << " Could not determine message type (probably bad cut)" + << " ignoring..." << std::endl; + std::cerr << "\t Contains: '" << token << "'" << std::endl; temp_message.clear(); full_message.clear(); continue; // Probably a bad line cut @@ -101,11 +118,9 @@ void SMSMessageFilter::Filter(void) { double final_probability; final_probability = (1. - sentence_probability_ham) * type_probability; final_probability = final_probability / (final_probability + ((1. - type_probability) * sentence_probability_ham)); - if (type_probability <= sentence_probability_ham) { + if (final_probability >= sentence_probability_ham) { filtered_messages[i].is_ham_filter = true; } else { filtered_messages[i].is_ham_filter = false; } - std::cout << "[SMSMessageFilter - Info] Final probability of " - << i << ": " << final_probability << std::endl; type_probability = 0.5; full_message.clear(); } @@ -115,6 +130,31 @@ void SMSMessageFilter::Report(void) { PrintReport(GenerateReport()); } +void SMSMessageFilter::ReadArguments(int argc, char* argv[]) { + if (argc == 1) { + std::cerr << "Usage error: Please specify file to filter" << std::endl; + PrintHelp(); + std::exit(1); + } + std::string argument_string; + for (int i = 0; i < argc; ++i) { + argument_string.assign(argv[i]); + if (argument_string == "-g") { + is_generator_defined = true; + generation_file_path.assign(argv[i+1]); + } + if (argument_string == "-f") { + is_input_defined = true; + filter_file_path.assign(argv[i+1]); + } + } + if (!is_input_defined) { + std::cerr << "Usage error: Please specify file to filter" << std::endl; + PrintHelp(); + std::exit(1); + } +} + ReportData SMSMessageFilter::GenerateReport(void) { double true_ham = 0.; double true_spam = 0.; @@ -146,10 +186,10 @@ ReportData SMSMessageFilter::GenerateReport(void) { // Calculate report data ReportData new_report; - new_report.spam_precision = (true_ham) / (true_ham + false_ham); - new_report.spam_recall = (true_ham) / (true_ham + false_spam); - new_report.ham_precision = (true_spam) / (true_spam + false_spam); - new_report.ham_recall = (true_spam) / (true_spam + false_ham); + new_report.ham_precision = (true_ham) / (true_ham + false_ham); + new_report.ham_recall = (true_ham) / (true_ham + false_spam); + new_report.spam_precision = (true_spam) / (true_spam + false_spam); + new_report.spam_recall = (true_spam) / (true_spam + false_ham); new_report.spam_f_score = 2.0 * (new_report.spam_precision * new_report.spam_recall) / (new_report.spam_precision + new_report.spam_recall); new_report.ham_f_score = 2.0 * (new_report.ham_precision * new_report.ham_recall) / (new_report.ham_precision + new_report.ham_recall); new_report.accuracy = (new_report.spam_recall + new_report.ham_recall) / 2.0; @@ -199,3 +239,11 @@ std::string SanitizeToken(std::string token) { [](unsigned char c){ return std::tolower(c); }); return token; } + +void PrintHelp(void) { + std::cerr << "Usage:\tfilter [-g file_for_probability] -f file_to_be_filtered" << std::endl; + std::cerr << "\t -g description: Generates probabilities of words from this file" << std::endl; + std::cerr << "\t\t (If not specified, then all words default to a 0.5 weight)" << std::endl; + std::cerr << "\t -f description: File that will actually be used to filter from" << std::endl; + std::cerr << "\t\t (Report is generated from this)" << std::endl; +} diff --git a/src/filter.hpp b/src/filter.hpp index 31a1579..d3aaf56 100644 --- a/src/filter.hpp +++ b/src/filter.hpp @@ -31,14 +31,19 @@ class SMSMessageFilter { public: SMSMessageFilter(void) = default; ~SMSMessageFilter(void) = default; - void GenerateProbability(std::string file_name); - void Prepare(std::string file_name); + bool is_generator_defined = false; + bool is_input_defined = false; + void GenerateProbability(); + void Prepare(); void Filter(void); void Report(void); + void ReadArguments(int argc, char* argv[]); private: - double sentence_probability_ham = 0.5; // Spam is 1 - sentence_probability_ham + double sentence_probability_ham = 0.2; // Sentence is spam if < this value std::map probability_dictionary; + std::string generation_file_path; + std::string filter_file_path; std::vector filtered_messages; ReportData GenerateReport(void); void PrintReport(ReportData report); @@ -46,5 +51,6 @@ private: }; std::string SanitizeToken(std::string token); +void PrintHelp(void); #endif // !FILTER_HPP diff --git a/src/main.cpp b/src/main.cpp index 558150a..8d161ba 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,9 +1,12 @@ #include "filter.hpp" -int main(void) { +int main(int argc, char* argv[]) { SMSMessageFilter single_filter; - single_filter.GenerateProbability("test/SMSProbabilityGeneration.txt"); - single_filter.Prepare("test/SMSFilterTest.txt"); + single_filter.ReadArguments(argc, argv); + if (single_filter.is_generator_defined) { + single_filter.GenerateProbability(); + } + single_filter.Prepare(); single_filter.Filter(); single_filter.Report(); return 0;