Update for completed project #1

Merged
Trianta merged 6 commits from dev into master 2023-11-12 20:28:49 -06:00
10 changed files with 11468 additions and 3 deletions

View File

@ -1,5 +1,6 @@
# Build the `filter` executable from the two sources in this directory.
add_executable(filter
./main.cpp
./filter.cpp
)
# Add this directory to the target's include path so the sources can find
# headers that live next to this file.
target_include_directories(filter PUBLIC ${CMAKE_CURRENT_LIST_DIR})

249
src/filter.cpp Normal file
View File

@ -0,0 +1,249 @@
#include "filter.hpp"
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
// Builds a labelled SMS message.
//   given_type:    stored as is_ham (the label the message came with).
//   given_message: the message text; taken by value and moved into place.
// is_ham_filter (the classifier's verdict) starts as false so it never holds
// an indeterminate value before a filter pass assigns it.
SMSMessage::SMSMessage(bool given_type, std::string given_message)
    : is_ham(given_type),
      message(std::move(given_message)),
      is_ham_filter(false) {}
// Builds the per-word probability dictionary from the labelled messages
// stored in generation_file_path.
//
// Each line is expected to look like "<label>\t<text>" where <label> is
// "ham" or "spam". Lines with an empty label are skipped silently; lines
// with any other label are reported on stderr and skipped. Every non-empty
// sanitized word of a ham line has its weight raised (capped at 1), every
// word of a spam line has it lowered (kept strictly above 0).
//
// Terminates the process with EXIT_FAILURE when the file cannot be opened.
void SMSMessageFilter::GenerateProbability() {
  std::ifstream input_messages(generation_file_path);
  if (!input_messages.is_open()) {
    std::cerr << "[SMSMessageFilter - Error - Generator] ";
    std::cerr << "Error opening '" << generation_file_path << "'" << std::endl;
    std::exit(EXIT_FAILURE);
  }

  std::string line;
  // Loop on the read itself: testing eof() beforehand never terminates if
  // the stream enters a failed state without reaching end-of-file.
  while (std::getline(input_messages, line)) {
    std::istringstream fields(line);
    std::string token;
    std::getline(fields, token, '\t');

    const bool is_ham_line = (token == "ham");
    if (!is_ham_line && token != "spam") {
      if (!token.empty()) {
        std::cerr << "[SMSMessageFilter - Warning]"
                  << " Could not determine message type (probably bad cut)"
                  << " ignoring..." << std::endl;
        std::cerr << "\t Contains: '" << token << "'" << std::endl;
      }
      continue;  // Blank line or a badly cut line
    }

    // The rest of the line is the message text, split on single spaces.
    while (std::getline(fields, token, ' ')) {
      token = SanitizeToken(token);
      if (token.empty()) { continue; }

      // One lookup per word; inserts the default weight on first sight.
      double& weight = probability_dictionary[token].value;
      if (is_ham_line) {
        weight += weight * 0.99;
        if (weight > 1.) { weight = 1.; }
      } else {
        weight -= weight * 0.000000001;
        if (weight <= 0.) { weight = 0.0000000001; }
      }
    }
  }
}
// Loads the labelled messages from filter_file_path into filtered_messages.
//
// Each line is expected to look like "<label>\t<text>" where <label> is
// "ham" or "spam". Lines with an empty label are skipped silently; lines
// with any other label are reported on stderr and skipped.
//
// Terminates the process with EXIT_FAILURE when the file cannot be opened.
void SMSMessageFilter::Prepare() {
  std::ifstream input_messages(filter_file_path);
  if (!input_messages.is_open()) {
    std::cerr << "[SMSMessageFilter - Error - Filter] ";
    std::cerr << "Error opening '" << filter_file_path << "'" << std::endl;
    std::exit(EXIT_FAILURE);
  }

  std::string line;
  // Loop on the read itself: testing eof() beforehand never terminates if
  // the stream enters a failed state without reaching end-of-file.
  while (std::getline(input_messages, line)) {
    // Everything before the first tab is the label (the whole line when
    // there is no tab at all).
    const std::string::size_type tab_pos = line.find('\t');
    const std::string label = line.substr(0, tab_pos);

    const bool is_ham_line = (label == "ham");
    if (!is_ham_line && label != "spam") {
      if (!label.empty()) {
        std::cerr << "[SMSMessageFilter - Warning]"
                  << " Could not determine message type (probably bad cut)"
                  << " ignoring..." << std::endl;
        std::cerr << "\t Contains: '" << label << "'" << std::endl;
      }
      continue;  // Blank line or a badly cut line
    }

    // Store only the text after the label. Keeping the whole line would put
    // "ham"/"spam" inside the message, where it fuses with the first word
    // once the message is split on spaces and sanitized.
    filtered_messages.emplace_back(
        is_ham_line,
        tab_pos == std::string::npos ? std::string() : line.substr(tab_pos + 1));
  }
}
void SMSMessageFilter::Filter(void) {
double type_probability = 0.5;
std::string token;
std::stringstream full_message;
for (int i = 0; i < filtered_messages.size(); ++i) {
full_message.str(filtered_messages[i].message);
while (std::getline(full_message, token, ' ')) {
token = SanitizeToken(token);
type_probability = probability_dictionary[token].value * type_probability;
}
double final_probability;
final_probability = (1. - sentence_probability_ham) * type_probability;
final_probability = final_probability / (final_probability + ((1. - type_probability) * sentence_probability_ham));
if (final_probability >= sentence_probability_ham) {
filtered_messages[i].is_ham_filter = true;
} else { filtered_messages[i].is_ham_filter = false; }
type_probability = 0.5;
full_message.clear();
}
}
void SMSMessageFilter::Report(void) {
PrintReport(GenerateReport());
}
void SMSMessageFilter::ReadArguments(int argc, char* argv[]) {
if (argc == 1) {
std::cerr << "Usage error: Please specify file to filter" << std::endl;
PrintHelp();
std::exit(1);
}
std::string argument_string;
for (int i = 0; i < argc; ++i) {
argument_string.assign(argv[i]);
if (argument_string == "-g") {
is_generator_defined = true;
generation_file_path.assign(argv[i+1]);
}
if (argument_string == "-f") {
is_input_defined = true;
filter_file_path.assign(argv[i+1]);
}
}
if (!is_input_defined) {
std::cerr << "Usage error: Please specify file to filter" << std::endl;
PrintHelp();
std::exit(1);
}
}
// Compares each message's label (is_ham) with the filter's verdict
// (is_ham_filter), prints the four outcome counts to stdout and returns the
// derived metrics.
//
// Counter meaning:
//   true_ham   - ham, classified as ham
//   true_spam  - spam, classified as spam
//   false_ham  - spam that was wrongly classified as ham
//   false_spam - ham that was wrongly classified as spam
//
// A ratio whose denominator is zero (e.g. no message was classified as
// spam) is reported as 0 rather than NaN.
ReportData SMSMessageFilter::GenerateReport(void) {
  double true_ham = 0.;
  double true_spam = 0.;
  double false_ham = 0.;
  double false_spam = 0.;

  // Iterate by reference: copying every message (and its text) is wasteful.
  for (const SMSMessage& message : filtered_messages) {
    if (message.is_ham == message.is_ham_filter) {
      if (message.is_ham) { ++true_ham; }
      else { ++true_spam; }
    } else {
      // A wrong verdict is named after what the filter *said*, which is what
      // the precision/recall formulas below rely on.
      if (message.is_ham_filter) { ++false_ham; }
      else { ++false_spam; }
    }
  }

  std::cout << std::endl;
  std::cout << "[SMSMessageFilter - Info] Ham barrier: ";
  std::cout << sentence_probability_ham << std::endl;
  std::cout << "[SMSMessageFilter - Info] True ham: ";
  std::cout << true_ham << std::endl;
  std::cout << "[SMSMessageFilter - Info] True spam: ";
  std::cout << true_spam << std::endl;
  std::cout << "[SMSMessageFilter - Info] False ham: ";
  std::cout << false_ham << std::endl;
  std::cout << "[SMSMessageFilter - Info] False spam: ";
  std::cout << false_spam << std::endl;
  std::cout << std::endl;

  // Division that yields 0 instead of NaN/inf for an empty denominator.
  const auto ratio = [](double numerator, double denominator) {
    return denominator > 0. ? numerator / denominator : 0.;
  };

  ReportData new_report;
  new_report.ham_precision = ratio(true_ham, true_ham + false_ham);
  new_report.ham_recall = ratio(true_ham, true_ham + false_spam);
  new_report.spam_precision = ratio(true_spam, true_spam + false_spam);
  new_report.spam_recall = ratio(true_spam, true_spam + false_ham);
  new_report.spam_f_score =
      ratio(2.0 * (new_report.spam_precision * new_report.spam_recall),
            new_report.spam_precision + new_report.spam_recall);
  new_report.ham_f_score =
      ratio(2.0 * (new_report.ham_precision * new_report.ham_recall),
            new_report.ham_precision + new_report.ham_recall);
  new_report.accuracy = (new_report.spam_recall + new_report.ham_recall) / 2.0;
  return new_report;
}
// Writes the given report to stdout between a start and an end banner, one
// metric per line.
//
// Metric definitions used by the report (spam is the positive class):
//   Spam precision: (true positives) / (true positives + false positives)
//   Spam recall:    (true positives) / (true positives + false negatives)
//   Ham precision:  (true negatives) / (true negatives + false negatives)
//   Ham recall:     (true negatives) / (true negatives + false positives)
//   Spam F-Score:   2 * (spam precision * spam recall) / (spam precision + spam recall)
//   Ham F-Score:    2 * (ham precision * ham recall) / (ham precision + ham recall)
//   Accuracy:       (spam recall + ham recall) / 2
void SMSMessageFilter::PrintReport(ReportData report) {
  // Prints one "<label><value>" line and flushes, like every other line here.
  const auto emit = [](const char* label, double value) {
    std::cout << label;
    std::cout << value << std::endl;
  };

  std::cout << std::endl;
  std::cout << "============ [SMSMessageFilter - Report - Start] ============" << std::endl;
  emit("[SMSMessageFilter - Report] Spam precision: ", report.spam_precision);
  emit("[SMSMessageFilter - Report] Spam recall: ", report.spam_recall);
  emit("[SMSMessageFilter - Report] Ham precision: ", report.ham_precision);
  emit("[SMSMessageFilter - Report] Ham recall: ", report.ham_recall);
  emit("[SMSMessageFilter - Report] Spam F-Score: ", report.spam_f_score);
  emit("[SMSMessageFilter - Report] Ham F-Score: ", report.ham_f_score);
  emit("[SMSMessageFilter - Report] Accuracy: ", report.accuracy);
  std::cout << "============ [SMSMessageFilter - Report - End] ============" << std::endl;
}
// Returns `token` reduced to its ASCII letters, converted to lower case.
// Every character outside 'A'-'Z' / 'a'-'z' (digits, punctuation,
// whitespace, non-ASCII bytes) is dropped; the result may be empty.
std::string SanitizeToken(std::string token) {
  // Explicit range checks with explicit grouping, instead of relying on the
  // relative precedence of && and ||.
  const auto is_not_ascii_letter = [](char c) {
    const bool is_upper = (c >= 'A' && c <= 'Z');
    const bool is_lower = (c >= 'a' && c <= 'z');
    return !(is_upper || is_lower);
  };

  // Erase-remove: a single linear pass rather than one erase() per character.
  token.erase(std::remove_if(token.begin(), token.end(), is_not_ascii_letter),
              token.end());

  std::transform(token.begin(), token.end(), token.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return token;
}
// Writes the command-line usage summary to stderr, one line at a time.
void PrintHelp(void) {
  static const char* const kHelpLines[] = {
      "Usage:\tfilter [-g file_for_probability] -f file_to_be_filtered",
      "\t -g description: Generates probabilities of words from this file",
      "\t\t (If not specified, then all words default to a 0.5 weight)",
      "\t -f description: File that will actually be used to filter from",
      "\t\t (Report is generated from this)",
  };
  for (const char* help_line : kHelpLines) {
    std::cerr << help_line << std::endl;
  }
}

56
src/filter.hpp Normal file
View File

@ -0,0 +1,56 @@
#ifndef FILTER_HPP
#define FILTER_HPP
#include <map>
#include <string>
#include <vector>
// std::map has no way to specify the value a missing key starts with, so the
// mapped type is this wrapper, whose default-constructed value is 0.5.
struct DoubleDefaultedToHalf {
  double value{0.5};
};
// Metrics produced by one evaluation run of the filter.
// Every field starts at zero so a default-constructed report never exposes
// indeterminate values.
struct ReportData {
  double spam_precision = 0.0;
  double spam_recall = 0.0;
  double ham_precision = 0.0;
  double ham_recall = 0.0;
  double spam_f_score = 0.0;
  double ham_f_score = 0.0;
  double accuracy = 0.0;
};
// One SMS message together with its label and the filter's verdict.
struct SMSMessage {
  SMSMessage(bool given_type, std::string given_message);

  bool is_ham;           // Label the message was constructed with
  std::string message;   // Message text
  // Verdict assigned by the filter; defaults to false so it is never read
  // while indeterminate.
  bool is_ham_filter = false;
};
// Word-probability based ham/spam filter for SMS messages.
class SMSMessageFilter {
 public:
  SMSMessageFilter() = default;
  ~SMSMessageFilter() = default;

  // Flags recording which command-line inputs were supplied.
  bool is_generator_defined = false;
  bool is_input_defined = false;

  // Builds the word dictionary from the generation file.
  void GenerateProbability();
  // Loads the messages that are going to be classified.
  void Prepare();
  // Classifies the loaded messages.
  void Filter();
  // Computes and prints the evaluation report.
  void Report();
  // Parses the command line into the paths and flags of this object.
  void ReadArguments(int argc, char* argv[]);

 private:
  double sentence_probability_ham = 0.2;  // Sentence is spam if < this value
  std::map<std::string, DoubleDefaultedToHalf> probability_dictionary;
  std::string generation_file_path;
  std::string filter_file_path;
  std::vector<SMSMessage> filtered_messages;

  ReportData GenerateReport();
  void PrintReport(ReportData report);
};
// Returns a cleaned-up copy of `token`; see the definition for the exact rules.
std::string SanitizeToken(std::string token);
// Prints the command-line usage help.
void PrintHelp(void);
#endif // !FILTER_HPP

View File

@ -1,6 +1,13 @@
#include <iostream>
#include "filter.hpp"
int main(void) {
std::cout << "Hello world" << std::endl;
int main(int argc, char* argv[]) {
SMSMessageFilter single_filter;
single_filter.ReadArguments(argc, argv);
if (single_filter.is_generator_defined) {
single_filter.GenerateProbability();
}
single_filter.Prepare();
single_filter.Filter();
single_filter.Report();
return 0;
}

1414
test/SMSFilterTest-1.txt Normal file

File diff suppressed because it is too large Load Diff

1403
test/SMSFilterTest-2.txt Normal file

File diff suppressed because it is too large Load Diff

2816
test/SMSFilterTest.txt Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff