Main part of filter working. Math seems to be wrong multiplying floats

This commit is contained in:
Trianta 2023-11-12 05:16:44 -06:00
parent d11ca9c23c
commit 9b3d373ee0
3 changed files with 154 additions and 30 deletions

View File

@ -1,10 +1,14 @@
#include "filter.hpp"
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
SMSMessage::SMSMessage(MessageType given_type, std::string given_message) {
this->actual_type = given_type;
SMSMessage::SMSMessage(bool given_type, std::string given_message) {
this->is_ham = given_type;
this->message = given_message;
}
@ -18,22 +22,41 @@ void SMSMessageFilter::GenerateProbability(std::string file_name) {
}
// Read in messages
MessageType temp_type;
std::string temp_message;
bool is_ham_temp;
std::string temp_message, token;
std::stringstream full_message;
while (!input_messages.eof()) {
input_messages.get(&temp_message[0], 10, '\t');
if (temp_message == "spam") { temp_type = kSpam; }
else if (temp_message == "ham") { temp_type = kHam; }
std::getline(input_messages, temp_message);
full_message.str(temp_message);
std::getline(full_message, token, '\t');
if (token == "ham") { is_ham_temp = true; }
else if (token == "spam") { is_ham_temp = false; }
else {
std::cerr << "[SMSMessageFilter - Warning]" <<
" Could not determine message type" << std::endl;
std::cerr << "\t Contains: <" << temp_message << ">" << std::endl;
std::cerr << "\t Contains: <" << token << ">" << std::endl;
temp_message.clear();
full_message.clear();
continue; // Probably a bad line cut
}
input_messages >> temp_message;
while (std::getline(full_message, token, ' ')) {
token = SanitizeToken(token);
if (token.empty()) { continue; }
if (probability_dictionary[token] == 0) {
probability_dictionary[token] = 0.5;
}
if (is_ham_temp) {
probability_dictionary[token] += probability_dictionary[token] * 0.1;
} else {
probability_dictionary[token] -= probability_dictionary[token] * 0.1;
}
}
temp_message.clear();
full_message.clear();
}
}
void SMSMessageFilter::Filter(std::string file_name) {
void SMSMessageFilter::Prepare(std::string file_name) {
// Open file
std::ifstream input_messages(file_name);
if (!input_messages.is_open()) {
@ -42,41 +65,130 @@ void SMSMessageFilter::Filter(std::string file_name) {
}
// Read in messages
MessageType temp_type;
std::string temp_message;
bool is_ham_temp;
std::string temp_message, token;
std::stringstream full_message;
while (!input_messages.eof()) {
input_messages.get(&temp_message[0], 10, '\t');
if (temp_message == "spam") { temp_type = kSpam; }
else if (temp_message == "ham") { temp_type = kHam; }
std::getline(input_messages, temp_message);
full_message.str(temp_message);
std::getline(full_message, token, '\t');
if (token == "ham") { is_ham_temp = true; }
else if (token == "spam") { is_ham_temp = false; }
else {
std::cerr << "[SMSMessageFilter - Warning]" <<
" Could not determine message type" << std::endl;
std::cerr << "\t Contains: <" << temp_message << ">" << std::endl;
std::cerr << "\t Contains: <" << token << ">" << std::endl;
temp_message.clear();
full_message.clear();
continue; // Probably a bad line cut
}
input_messages >> temp_message;
filtered_messages.emplace_back(temp_type, temp_message);
full_message.ignore('\t');
filtered_messages.emplace_back(is_ham_temp, full_message.str());
temp_message.clear();
full_message.clear();
}
}
void SMSMessageFilter::PrintReport(void) {
void SMSMessageFilter::Filter(void) {
double type_probability = 0.5;
std::string token;
std::stringstream full_message;
for (int i = 0; i < filtered_messages.size(); ++i) {
full_message.str(filtered_messages[i].message);
while (std::getline(full_message, token, ' ')) {
token = SanitizeToken(token);
if (probability_dictionary[token] == 0) {
probability_dictionary[token] = 0.5;
}
type_probability = probability_dictionary[token] * type_probability;
}
if (type_probability <= sentence_probability_ham) {
filtered_messages[i].is_ham_filter = true;
} else { filtered_messages[i].is_ham_filter = false; }
type_probability = 0.5;
}
}
void SMSMessageFilter::Report(void) {
PrintReport(GenerateReport());
}
ReportData SMSMessageFilter::GenerateReport(void) {
double true_ham = 0;
double true_spam = 0;
double false_ham = 0;
double false_spam = 0;
for (SMSMessage message : filtered_messages) {
// Get total count
if (!(message.is_ham ^ message.is_ham_filter)) {
if (message.is_ham) { ++true_ham; }
else { ++true_spam; }
} else {
if (message.is_ham) { ++false_ham; }
else { ++false_spam; }
}
}
std::cout << "[SMSMessageFilter - Info] Ham barrier: ";
std::cout << sentence_probability_ham << std::endl;
std::cout << "[SMSMessageFilter - Info] True ham: ";
std::cout << true_ham << std::endl;
std::cout << "[SMSMessageFilter - Info] True spam: ";
std::cout << true_spam << std::endl;
std::cout << "[SMSMessageFilter - Info] False ham: ";
std::cout << false_ham << std::endl;
std::cout << "[SMSMessageFilter - Info] False spam: ";
std::cout << false_spam << std::endl;
// Calculate report data
ReportData new_report;
new_report.spam_precision = (true_ham) / (true_ham + false_ham);
new_report.spam_recall = (true_ham) / (true_ham + false_spam);
new_report.ham_precision = (true_spam) / (true_spam + false_spam);
new_report.ham_recall = (true_spam) / (true_spam + false_ham);
new_report.spam_f_score = 2.0 * (new_report.spam_precision * new_report.spam_recall) / (new_report.spam_precision + new_report.spam_recall);
new_report.ham_f_score = 2.0 * (new_report.ham_precision * new_report.ham_recall) / (new_report.ham_precision + new_report.ham_recall);
new_report.accuracy = (new_report.spam_recall + new_report.ham_recall) / 2.0;
return new_report;
}
void SMSMessageFilter::PrintReport(ReportData report) {
// Spam precision: (true positives) / (true positives + false positives)
std::cout << "[SMSMessageFilter - Report] Spam precision: ";
std::cout << report.spam_precision << std::endl;
// Spam recall: (true positives) / (true positives + false negatives)
std::cout << "[SMSMessageFilter - Report] Spam recall: ";
std::cout << report.spam_recall << std::endl;
// Ham precision: (true negatives) / (true negatives + false negatives)
std::cout << "[SMSMessageFilter - Report] Ham precision: ";
std::cout << report.ham_precision << std::endl;
// Ham recall: (true negatives) / (true negatives + false positives)
std::cout << "[SMSMessageFilter - Report] Ham recall: ";
std::cout << report.ham_recall << std::endl;
// Spam F-Score: 2* (spam precision * spam recall) / (spam precision + spam recall)
std::cout << "[SMSMessageFilter - Report] Spam F-Score: ";
std::cout << report.spam_f_score << std::endl;
// Ham F-Score: 2* (ham precision * ham recall) / (ham precision + ham recall)
std::cout << "[SMSMessageFilter - Report] Ham F-Score: ";
std::cout << report.ham_f_score << std::endl;
// Accuracy: (spam recall + ham recall) / 2
std::cout << "[SMSMessageFilter - Report] Accuracy: ";
std::cout << report.accuracy << std::endl;
}
std::string SanitizeToken(std::string token) {
for (int i = 0; i < token.size(); ) {
if (token[i] < 'A' || token[i] > 'Z'
&& token[i] < 'a' || token[i] > 'z') { token.erase(i, 1); }
else { ++i; }
}
std::transform(token.begin(), token.end(), token.begin(),
[](unsigned char c){ return std::tolower(c); });
return token;
}

View File

@ -5,17 +5,21 @@
#include <string>
#include <vector>
enum MessageType {
kSpam = 0,
kHam = 1,
kUnknown = 2,
struct ReportData {
double spam_precision;
double spam_recall;
double ham_precision;
double ham_recall;
double spam_f_score;
double ham_f_score;
double accuracy;
};
struct SMSMessage {
SMSMessage(MessageType given_type, std::string given_message);
MessageType actual_type;
SMSMessage(bool given_type, std::string given_message);
bool is_ham;
std::string message;
MessageType filter_type;
bool is_ham_filter;
};
class SMSMessageFilter {
@ -23,15 +27,19 @@ public:
SMSMessageFilter(void) = default;
~SMSMessageFilter(void) = default;
void GenerateProbability(std::string file_name);
void Filter(std::string file_name);
void PrintReport(void);
void Prepare(std::string file_name);
void Filter(void);
void Report(void);
private:
double sentence_probability_ham = 0.5; // Spam is 1 - sentence_probability_ham
std::map<std::string, double> probability_dictionary;
std::vector<SMSMessage> filtered_messages;
ReportData GenerateReport(void);
void PrintReport(ReportData report);
};
std::string SanitizeToken(std::string token);
#endif // !FILTER_HPP

View File

@ -1,6 +1,10 @@
#include <iostream>
#include "filter.hpp"
int main(void) {
std::cout << "Hello world" << std::endl;
SMSMessageFilter single_filter;
single_filter.GenerateProbability("test/SMSProbabilityGeneration.txt");
single_filter.Prepare("test/SMSFilterTest.txt");
single_filter.Filter();
single_filter.Report();
return 0;
}