generated from Trianta/cpp-unity-template
Update for completed project #1
@ -1,5 +1,6 @@
|
||||
add_executable(filter
|
||||
./main.cpp
|
||||
./filter.cpp
|
||||
)
|
||||
|
||||
target_include_directories(filter PUBLIC ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
249
src/filter.cpp
Normal file
249
src/filter.cpp
Normal file
@ -0,0 +1,249 @@
|
||||
#include "filter.hpp"
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
SMSMessage::SMSMessage(bool given_type, std::string given_message) {
|
||||
this->is_ham = given_type;
|
||||
this->message = given_message;
|
||||
}
|
||||
|
||||
// Takes in a file name from which it will generate a probability filter
|
||||
void SMSMessageFilter::GenerateProbability() {
|
||||
// Open file
|
||||
std::ifstream input_messages(generation_file_path);
|
||||
if (!input_messages.is_open()) {
|
||||
std::cerr << "[SMSMessageFilter - Error - Generator] ";
|
||||
std::cerr << "Error opening '" << generation_file_path << "'" << std::endl;
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Read in messages
|
||||
bool is_ham_temp;
|
||||
std::string temp_message, token;
|
||||
std::stringstream full_message;
|
||||
while (!input_messages.eof()) {
|
||||
std::getline(input_messages, temp_message);
|
||||
full_message.str(temp_message);
|
||||
std::getline(full_message, token, '\t');
|
||||
if (token == "ham") { is_ham_temp = true; }
|
||||
else if (token == "spam") { is_ham_temp = false; }
|
||||
else if (token.empty()) {
|
||||
temp_message.clear();
|
||||
full_message.clear();
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
std::cerr << "[SMSMessageFilter - Warning]"
|
||||
<< " Could not determine message type (probably bad cut)"
|
||||
<< " ignoring..." << std::endl;
|
||||
std::cerr << "\t Contains: '" << token << "'" << std::endl;
|
||||
temp_message.clear();
|
||||
full_message.clear();
|
||||
continue; // Probably a bad line cut
|
||||
}
|
||||
while (std::getline(full_message, token, ' ')) {
|
||||
token = SanitizeToken(token);
|
||||
if (token.empty()) { continue; }
|
||||
if (is_ham_temp) {
|
||||
probability_dictionary[token].value += probability_dictionary[token].value * 0.99;
|
||||
if (probability_dictionary[token].value > 1.) {
|
||||
probability_dictionary[token].value = 1.;
|
||||
}
|
||||
} else {
|
||||
probability_dictionary[token].value -= probability_dictionary[token].value * 0.000000001;
|
||||
if (probability_dictionary[token].value <= 0.) {
|
||||
probability_dictionary[token].value = 0.0000000001;
|
||||
}
|
||||
}
|
||||
}
|
||||
temp_message.clear();
|
||||
full_message.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void SMSMessageFilter::Prepare() {
|
||||
// Open file
|
||||
std::ifstream input_messages(filter_file_path);
|
||||
if (!input_messages.is_open()) {
|
||||
std::cerr << "[SMSMessageFilter - Error - Filter] ";
|
||||
std::cerr << "Error opening '" << filter_file_path << "'" << std::endl;
|
||||
std::exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
// Read in messages
|
||||
bool is_ham_temp;
|
||||
std::string temp_message, token;
|
||||
std::stringstream full_message;
|
||||
while (!input_messages.eof()) {
|
||||
std::getline(input_messages, temp_message);
|
||||
full_message.str(temp_message);
|
||||
std::getline(full_message, token, '\t');
|
||||
if (token == "ham") { is_ham_temp = true; }
|
||||
else if (token == "spam") { is_ham_temp = false; }
|
||||
else if (token.empty()) {
|
||||
temp_message.clear();
|
||||
full_message.clear();
|
||||
continue;
|
||||
}
|
||||
else {
|
||||
std::cerr << "[SMSMessageFilter - Warning]"
|
||||
<< " Could not determine message type (probably bad cut)"
|
||||
<< " ignoring..." << std::endl;
|
||||
std::cerr << "\t Contains: '" << token << "'" << std::endl;
|
||||
temp_message.clear();
|
||||
full_message.clear();
|
||||
continue; // Probably a bad line cut
|
||||
}
|
||||
filtered_messages.emplace_back(is_ham_temp, full_message.str());
|
||||
temp_message.clear();
|
||||
full_message.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void SMSMessageFilter::Filter(void) {
|
||||
double type_probability = 0.5;
|
||||
std::string token;
|
||||
std::stringstream full_message;
|
||||
for (int i = 0; i < filtered_messages.size(); ++i) {
|
||||
full_message.str(filtered_messages[i].message);
|
||||
while (std::getline(full_message, token, ' ')) {
|
||||
token = SanitizeToken(token);
|
||||
type_probability = probability_dictionary[token].value * type_probability;
|
||||
}
|
||||
double final_probability;
|
||||
final_probability = (1. - sentence_probability_ham) * type_probability;
|
||||
final_probability = final_probability / (final_probability + ((1. - type_probability) * sentence_probability_ham));
|
||||
if (final_probability >= sentence_probability_ham) {
|
||||
filtered_messages[i].is_ham_filter = true;
|
||||
} else { filtered_messages[i].is_ham_filter = false; }
|
||||
type_probability = 0.5;
|
||||
full_message.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void SMSMessageFilter::Report(void) {
|
||||
PrintReport(GenerateReport());
|
||||
}
|
||||
|
||||
void SMSMessageFilter::ReadArguments(int argc, char* argv[]) {
|
||||
if (argc == 1) {
|
||||
std::cerr << "Usage error: Please specify file to filter" << std::endl;
|
||||
PrintHelp();
|
||||
std::exit(1);
|
||||
}
|
||||
std::string argument_string;
|
||||
for (int i = 0; i < argc; ++i) {
|
||||
argument_string.assign(argv[i]);
|
||||
if (argument_string == "-g") {
|
||||
is_generator_defined = true;
|
||||
generation_file_path.assign(argv[i+1]);
|
||||
}
|
||||
if (argument_string == "-f") {
|
||||
is_input_defined = true;
|
||||
filter_file_path.assign(argv[i+1]);
|
||||
}
|
||||
}
|
||||
if (!is_input_defined) {
|
||||
std::cerr << "Usage error: Please specify file to filter" << std::endl;
|
||||
PrintHelp();
|
||||
std::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
ReportData SMSMessageFilter::GenerateReport(void) {
|
||||
double true_ham = 0.;
|
||||
double true_spam = 0.;
|
||||
double false_ham = 0.;
|
||||
double false_spam = 0.;
|
||||
for (SMSMessage message : filtered_messages) {
|
||||
// Get total count
|
||||
if (!(message.is_ham ^ message.is_ham_filter)) {
|
||||
if (message.is_ham) { ++true_ham; }
|
||||
else { ++true_spam; }
|
||||
} else {
|
||||
if (message.is_ham) { ++false_ham; }
|
||||
else { ++false_spam; }
|
||||
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << "[SMSMessageFilter - Info] Ham barrier: ";
|
||||
std::cout << sentence_probability_ham << std::endl;
|
||||
std::cout << "[SMSMessageFilter - Info] True ham: ";
|
||||
std::cout << true_ham << std::endl;
|
||||
std::cout << "[SMSMessageFilter - Info] True spam: ";
|
||||
std::cout << true_spam << std::endl;
|
||||
std::cout << "[SMSMessageFilter - Info] False ham: ";
|
||||
std::cout << false_ham << std::endl;
|
||||
std::cout << "[SMSMessageFilter - Info] False spam: ";
|
||||
std::cout << false_spam << std::endl;
|
||||
std::cout << std::endl;
|
||||
|
||||
// Calculate report data
|
||||
ReportData new_report;
|
||||
new_report.ham_precision = (true_ham) / (true_ham + false_ham);
|
||||
new_report.ham_recall = (true_ham) / (true_ham + false_spam);
|
||||
new_report.spam_precision = (true_spam) / (true_spam + false_spam);
|
||||
new_report.spam_recall = (true_spam) / (true_spam + false_ham);
|
||||
new_report.spam_f_score = 2.0 * (new_report.spam_precision * new_report.spam_recall) / (new_report.spam_precision + new_report.spam_recall);
|
||||
new_report.ham_f_score = 2.0 * (new_report.ham_precision * new_report.ham_recall) / (new_report.ham_precision + new_report.ham_recall);
|
||||
new_report.accuracy = (new_report.spam_recall + new_report.ham_recall) / 2.0;
|
||||
return new_report;
|
||||
}
|
||||
|
||||
void SMSMessageFilter::PrintReport(ReportData report) {
|
||||
std::cout << std::endl;
|
||||
std::cout << "============ [SMSMessageFilter - Report - Start] ============" << std::endl;
|
||||
// Spam precision: (true positives) / (true positives + false positives)
|
||||
std::cout << "[SMSMessageFilter - Report] Spam precision: ";
|
||||
std::cout << report.spam_precision << std::endl;
|
||||
|
||||
// Spam recall: (true positives) / (true positives + false negatives)
|
||||
std::cout << "[SMSMessageFilter - Report] Spam recall: ";
|
||||
std::cout << report.spam_recall << std::endl;
|
||||
|
||||
// Ham precision: (true negatives) / (true negatives + false negatives)
|
||||
std::cout << "[SMSMessageFilter - Report] Ham precision: ";
|
||||
std::cout << report.ham_precision << std::endl;
|
||||
|
||||
// Ham recall: (true negatives) / (true negatives + false positives)
|
||||
std::cout << "[SMSMessageFilter - Report] Ham recall: ";
|
||||
std::cout << report.ham_recall << std::endl;
|
||||
|
||||
// Spam F-Score: 2* (spam precision * spam recall) / (spam precision + spam recall)
|
||||
std::cout << "[SMSMessageFilter - Report] Spam F-Score: ";
|
||||
std::cout << report.spam_f_score << std::endl;
|
||||
|
||||
// Ham F-Score: 2* (ham precision * ham recall) / (ham precision + ham recall)
|
||||
std::cout << "[SMSMessageFilter - Report] Ham F-Score: ";
|
||||
std::cout << report.ham_f_score << std::endl;
|
||||
|
||||
// Accuracy: (spam recall + ham recall) / 2
|
||||
std::cout << "[SMSMessageFilter - Report] Accuracy: ";
|
||||
std::cout << report.accuracy << std::endl;
|
||||
std::cout << "============ [SMSMessageFilter - Report - End] ============" << std::endl;
|
||||
}
|
||||
|
||||
std::string SanitizeToken(std::string token) {
|
||||
for (int i = 0; i < token.size(); ) {
|
||||
if (token[i] < 'A' || token[i] > 'Z'
|
||||
&& token[i] < 'a' || token[i] > 'z') { token.erase(i, 1); }
|
||||
else { ++i; }
|
||||
}
|
||||
std::transform(token.begin(), token.end(), token.begin(),
|
||||
[](unsigned char c){ return std::tolower(c); });
|
||||
return token;
|
||||
}
|
||||
|
||||
void PrintHelp(void) {
|
||||
std::cerr << "Usage:\tfilter [-g file_for_probability] -f file_to_be_filtered" << std::endl;
|
||||
std::cerr << "\t -g description: Generates probabilities of words from this file" << std::endl;
|
||||
std::cerr << "\t\t (If not specified, then all words default to a 0.5 weight)" << std::endl;
|
||||
std::cerr << "\t -f description: File that will actually be used to filter from" << std::endl;
|
||||
std::cerr << "\t\t (Report is generated from this)" << std::endl;
|
||||
}
|
56
src/filter.hpp
Normal file
56
src/filter.hpp
Normal file
@ -0,0 +1,56 @@
|
||||
#ifndef FILTER_HPP
|
||||
#define FILTER_HPP
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// C++ maps don't allow specifying default value
|
||||
struct DoubleDefaultedToHalf {
|
||||
double value = 0.5;
|
||||
};
|
||||
|
||||
struct ReportData {
|
||||
double spam_precision;
|
||||
double spam_recall;
|
||||
double ham_precision;
|
||||
double ham_recall;
|
||||
double spam_f_score;
|
||||
double ham_f_score;
|
||||
double accuracy;
|
||||
};
|
||||
|
||||
struct SMSMessage {
|
||||
SMSMessage(bool given_type, std::string given_message);
|
||||
bool is_ham;
|
||||
std::string message;
|
||||
bool is_ham_filter;
|
||||
};
|
||||
|
||||
class SMSMessageFilter {
|
||||
public:
|
||||
SMSMessageFilter(void) = default;
|
||||
~SMSMessageFilter(void) = default;
|
||||
bool is_generator_defined = false;
|
||||
bool is_input_defined = false;
|
||||
void GenerateProbability();
|
||||
void Prepare();
|
||||
void Filter(void);
|
||||
void Report(void);
|
||||
void ReadArguments(int argc, char* argv[]);
|
||||
|
||||
private:
|
||||
double sentence_probability_ham = 0.2; // Sentence is spam if < this value
|
||||
std::map<std::string, DoubleDefaultedToHalf> probability_dictionary;
|
||||
std::string generation_file_path;
|
||||
std::string filter_file_path;
|
||||
std::vector<SMSMessage> filtered_messages;
|
||||
ReportData GenerateReport(void);
|
||||
void PrintReport(ReportData report);
|
||||
|
||||
};
|
||||
|
||||
std::string SanitizeToken(std::string token);
|
||||
void PrintHelp(void);
|
||||
|
||||
#endif // !FILTER_HPP
|
13
src/main.cpp
13
src/main.cpp
@ -1,6 +1,13 @@
|
||||
#include <iostream>
|
||||
#include "filter.hpp"
|
||||
|
||||
int main(void) {
|
||||
std::cout << "Hello world" << std::endl;
|
||||
int main(int argc, char* argv[]) {
|
||||
SMSMessageFilter single_filter;
|
||||
single_filter.ReadArguments(argc, argv);
|
||||
if (single_filter.is_generator_defined) {
|
||||
single_filter.GenerateProbability();
|
||||
}
|
||||
single_filter.Prepare();
|
||||
single_filter.Filter();
|
||||
single_filter.Report();
|
||||
return 0;
|
||||
}
|
||||
|
1414
test/SMSFilterTest-1.txt
Normal file
1414
test/SMSFilterTest-1.txt
Normal file
File diff suppressed because it is too large
Load Diff
1403
test/SMSFilterTest-2.txt
Normal file
1403
test/SMSFilterTest-2.txt
Normal file
File diff suppressed because it is too large
Load Diff
2816
test/SMSFilterTest.txt
Normal file
2816
test/SMSFilterTest.txt
Normal file
File diff suppressed because it is too large
Load Diff
1376
test/SMSProbabilityGeneration-1.txt
Normal file
1376
test/SMSProbabilityGeneration-1.txt
Normal file
File diff suppressed because it is too large
Load Diff
1384
test/SMSProbabilityGeneration-2.txt
Normal file
1384
test/SMSProbabilityGeneration-2.txt
Normal file
File diff suppressed because it is too large
Load Diff
2759
test/SMSProbabilityGeneration.txt
Normal file
2759
test/SMSProbabilityGeneration.txt
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user