Update for completed project #1

Merged
Trianta merged 6 commits from dev into master 2023-11-12 20:28:49 -06:00
3 changed files with 84 additions and 27 deletions
Showing only changes of commit cb83e074dd - Show all commits

View File

@ -13,11 +13,12 @@ SMSMessage::SMSMessage(bool given_type, std::string given_message) {
}
// Takes in a file name from which it will generate a probability filter
void SMSMessageFilter::GenerateProbability(std::string file_name) {
void SMSMessageFilter::GenerateProbability() {
// Open file
std::ifstream input_messages(file_name);
std::ifstream input_messages(generation_file_path);
if (!input_messages.is_open()) {
std::cerr << "Error opening '" << file_name << "'" << std::endl;
std::cerr << "[SMSMessageFilter - Error - Generator] ";
std::cerr << "Error opening '" << generation_file_path << "'" << std::endl;
std::exit(EXIT_FAILURE);
}
@ -31,10 +32,16 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) {
std::getline(full_message, token, '\t');
if (token == "ham") { is_ham_temp = true; }
else if (token == "spam") { is_ham_temp = false; }
else if (token.empty()) {
temp_message.clear();
full_message.clear();
continue;
}
else {
std::cerr << "[SMSMessageFilter - Warning]" <<
" Could not determine message type" << std::endl;
std::cerr << "\t Contains: <" << token << ">" << std::endl;
std::cerr << "[SMSMessageFilter - Warning]"
<< " Could not determine message type (probably bad cut)"
<< " ignoring..." << std::endl;
std::cerr << "\t Contains: '" << token << "'" << std::endl;
temp_message.clear();
full_message.clear();
continue; // Probably a bad line cut
@ -43,12 +50,15 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) {
token = SanitizeToken(token);
if (token.empty()) { continue; }
if (is_ham_temp) {
probability_dictionary[token].value += probability_dictionary[token].value * 0.0000000001;
probability_dictionary[token].value += probability_dictionary[token].value * 0.99;
if (probability_dictionary[token].value > 1.) {
probability_dictionary[token].value = 1.;
}
} else {
probability_dictionary[token].value -= probability_dictionary[token].value * 0.0000000001;
probability_dictionary[token].value -= probability_dictionary[token].value * 0.000000001;
if (probability_dictionary[token].value <= 0.) {
probability_dictionary[token].value = 0.0000000001;
}
}
}
temp_message.clear();
@ -56,11 +66,12 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) {
}
}
void SMSMessageFilter::Prepare(std::string file_name) {
void SMSMessageFilter::Prepare() {
// Open file
std::ifstream input_messages(file_name);
std::ifstream input_messages(filter_file_path);
if (!input_messages.is_open()) {
std::cerr << "Error opening '" << file_name << "'" << std::endl;
std::cerr << "[SMSMessageFilter - Error - Filter] ";
std::cerr << "Error opening '" << filter_file_path << "'" << std::endl;
std::exit(EXIT_FAILURE);
}
@ -74,10 +85,16 @@ void SMSMessageFilter::Prepare(std::string file_name) {
std::getline(full_message, token, '\t');
if (token == "ham") { is_ham_temp = true; }
else if (token == "spam") { is_ham_temp = false; }
else if (token.empty()) {
temp_message.clear();
full_message.clear();
continue;
}
else {
std::cerr << "[SMSMessageFilter - Warning]" <<
" Could not determine message type" << std::endl;
std::cerr << "\t Contains: <" << token << ">" << std::endl;
std::cerr << "[SMSMessageFilter - Warning]"
<< " Could not determine message type (probably bad cut)"
<< " ignoring..." << std::endl;
std::cerr << "\t Contains: '" << token << "'" << std::endl;
temp_message.clear();
full_message.clear();
continue; // Probably a bad line cut
@ -101,11 +118,9 @@ void SMSMessageFilter::Filter(void) {
double final_probability;
final_probability = (1. - sentence_probability_ham) * type_probability;
final_probability = final_probability / (final_probability + ((1. - type_probability) * sentence_probability_ham));
if (type_probability <= sentence_probability_ham) {
if (final_probability >= sentence_probability_ham) {
filtered_messages[i].is_ham_filter = true;
} else { filtered_messages[i].is_ham_filter = false; }
std::cout << "[SMSMessageFilter - Info] Final probability of "
<< i << ": " << final_probability << std::endl;
type_probability = 0.5;
full_message.clear();
}
@ -115,6 +130,31 @@ void SMSMessageFilter::Report(void) {
PrintReport(GenerateReport());
}
void SMSMessageFilter::ReadArguments(int argc, char* argv[]) {
if (argc == 1) {
std::cerr << "Usage error: Please specify file to filter" << std::endl;
PrintHelp();
std::exit(1);
}
std::string argument_string;
for (int i = 0; i < argc; ++i) {
argument_string.assign(argv[i]);
if (argument_string == "-g") {
is_generator_defined = true;
generation_file_path.assign(argv[i+1]);
}
if (argument_string == "-f") {
is_input_defined = true;
filter_file_path.assign(argv[i+1]);
}
}
if (!is_input_defined) {
std::cerr << "Usage error: Please specify file to filter" << std::endl;
PrintHelp();
std::exit(1);
}
}
ReportData SMSMessageFilter::GenerateReport(void) {
double true_ham = 0.;
double true_spam = 0.;
@ -146,10 +186,10 @@ ReportData SMSMessageFilter::GenerateReport(void) {
// Calculate report data
ReportData new_report;
new_report.spam_precision = (true_ham) / (true_ham + false_ham);
new_report.spam_recall = (true_ham) / (true_ham + false_spam);
new_report.ham_precision = (true_spam) / (true_spam + false_spam);
new_report.ham_recall = (true_spam) / (true_spam + false_ham);
new_report.ham_precision = (true_ham) / (true_ham + false_ham);
new_report.ham_recall = (true_ham) / (true_ham + false_spam);
new_report.spam_precision = (true_spam) / (true_spam + false_spam);
new_report.spam_recall = (true_spam) / (true_spam + false_ham);
new_report.spam_f_score = 2.0 * (new_report.spam_precision * new_report.spam_recall) / (new_report.spam_precision + new_report.spam_recall);
new_report.ham_f_score = 2.0 * (new_report.ham_precision * new_report.ham_recall) / (new_report.ham_precision + new_report.ham_recall);
new_report.accuracy = (new_report.spam_recall + new_report.ham_recall) / 2.0;
@ -199,3 +239,11 @@ std::string SanitizeToken(std::string token) {
[](unsigned char c){ return std::tolower(c); });
return token;
}
// Writes the command-line usage summary to standard error, one line at a
// time, flushing after each line.
void PrintHelp(void) {
    static const char* const kHelpLines[] = {
        "Usage:\tfilter [-g file_for_probability] -f file_to_be_filtered",
        "\t -g description: Generates probabilities of words from this file",
        "\t\t (If not specified, then all words default to a 0.5 weight)",
        "\t -f description: File that will actually be used to filter from",
        "\t\t (Report is generated from this)",
    };

    for (const char* help_line : kHelpLines) {
        std::cerr << help_line << std::endl;
    }
}

View File

@ -31,14 +31,19 @@ class SMSMessageFilter {
public:
SMSMessageFilter(void) = default;
~SMSMessageFilter(void) = default;
void GenerateProbability(std::string file_name);
void Prepare(std::string file_name);
bool is_generator_defined = false;
bool is_input_defined = false;
void GenerateProbability();
void Prepare();
void Filter(void);
void Report(void);
void ReadArguments(int argc, char* argv[]);
private:
double sentence_probability_ham = 0.5; // Spam is 1 - sentence_probability_ham
double sentence_probability_ham = 0.2; // Sentence is spam if < this value
std::map<std::string, DoubleDefaultedToHalf> probability_dictionary;
std::string generation_file_path;
std::string filter_file_path;
std::vector<SMSMessage> filtered_messages;
ReportData GenerateReport(void);
void PrintReport(ReportData report);
@ -46,5 +51,6 @@ private:
};
std::string SanitizeToken(std::string token);
void PrintHelp(void);
#endif // !FILTER_HPP

View File

@ -1,9 +1,12 @@
#include "filter.hpp"
int main(void) {
int main(int argc, char* argv[]) {
SMSMessageFilter single_filter;
single_filter.GenerateProbability("test/SMSProbabilityGeneration.txt");
single_filter.Prepare("test/SMSFilterTest.txt");
single_filter.ReadArguments(argc, argv);
if (single_filter.is_generator_defined) {
single_filter.GenerateProbability();
}
single_filter.Prepare();
single_filter.Filter();
single_filter.Report();
return 0;