add oneR to main #2

Merged
Trianta merged 8 commits from oneR into main 2024-04-17 21:35:56 -05:00
3 changed files with 80 additions and 4 deletions
Showing only changes of commit 182eed641b - Show all commits

View File

@ -41,6 +41,10 @@ namespace ARFF {
this->values.resize(size);
}
AttributeEvaluation::AttributeEvaluation(AttributeType *attribute) {
this->currentAttribute = attribute;
}
// Read entire data file and parse it
void Arff::Read(std::string filename) {
std::ifstream dataFile(filename);
@ -101,9 +105,17 @@ namespace ARFF {
// Print result of applying OneR
// TODO: Create function
void PrintOneR(void) {
void Arff::OneR(void) {
AttributeEvaluation bestAttribute = _OneR();
debug::Log(kNone, "***Best 1-rule***");
debug::Log(kNone, "\t" + bestAttribute.currentAttribute->attribute + ':');
for (auto it : bestAttribute.rules) {
debug::Log(kNone, "\t\t" + it.first + " ---> " + it.second);
}
debug::Log(kNone, "Error rate: " + std::to_string(bestAttribute.totalError) + "/" + std::to_string(database.size()));
}
// Add the attribute to the list
void Arff::AddAttribute(std::string line) {
std::stringstream parser(line);
@ -182,4 +194,57 @@ namespace ARFF {
}
debug::Log(kLog, "All values exist, continuing...");
}
// Perform OneR on data that was previously read in
AttributeEvaluation Arff::_OneR(void) {
AttributeEvaluation bestEvaluation;
bestEvaluation.totalErrorRate = 1.0f;
// -1 used for ignoring test rule (eg, play=yes/no)
for (int i = 0; i < attributeList.size() - 1; ++i) {
AttributeEvaluation evaluation = EvaluateAttribute(&attributeList[i], i);
if (evaluation.totalErrorRate < bestEvaluation.totalErrorRate) {
bestEvaluation = evaluation;
bestEvaluation.currentAttribute = evaluation.currentAttribute;
}
debug::Log(kLog, "Evaluation on " + evaluation.currentAttribute->attribute + " completed");
}
return bestEvaluation;
}
// Determine error rate and best option for each value of an attribute
// Originally set up to use OneR
AttributeEvaluation Arff::EvaluateAttribute(AttributeType *attribute, const int attributePos) {
AttributeEvaluation evaluation(attribute);
std::map<std::string, int> results;
for (std::string value : attributeList.end()->values) {
results.emplace(value, 0);
}
for (int i = 0; i < attribute->values.size(); ++i) {
if (attribute->values[i] == "?") { continue; }
for (auto instance = database.begin(); instance != database.end(); ++instance) {
if (instance->values[attributePos] != attribute->values[i]) { continue; }
++results[instance->values.back()];
}
debug::Log(kTrace, "Results:");
for (auto it : results) { debug::Log(kTrace, "\t" + it.first + ": " + std::to_string(it.second)); }
int lowest = 9999;
std::string bestResult = results.begin()->first;
for (auto it = results.begin(); it != results.end(); ++it) {
if (it->second < lowest) {
lowest = it->second;
} else {
bestResult = it->first;
}
}
evaluation.rules.emplace(attribute->values[i], bestResult);
evaluation.totalError += lowest;
debug::Log(kLog, "Added rule " + attribute->values[i] + "->" + bestResult);
// Reset
for (auto it = results.begin(); it != results.end(); ++it) {
it->second = 0;
}
}
evaluation.totalErrorRate = evaluation.totalError / float(database.size());
return evaluation;
}
}

View File

@ -3,6 +3,7 @@
#include <string>
#include <vector>
#include <map>
namespace ARFF {
void ParseArguments(int argc, char* argv[]);
@ -23,13 +24,22 @@ namespace ARFF {
std::vector<std::string> values;
};
struct AttributeEvaluation {
AttributeEvaluation() = default;
AttributeEvaluation(AttributeType *attribute);
AttributeType *currentAttribute;
std::map<std::string, std::string> rules;
float totalErrorRate = 0.0f;
int totalError = 0;
};
class Arff {
public:
Arff() = default;
void Read(std::string filename);
void PrintOverview(void);
void PrintData(void);
void PrintOneR(void);
void OneR(void);
private:
std::string relation;
std::vector<AttributeType> attributeList;
@ -37,6 +47,8 @@ namespace ARFF {
void AddAttribute(std::string line);
void AddData(std::string line);
void TestIntegrity(void);
AttributeEvaluation _OneR(void);
AttributeEvaluation EvaluateAttribute(AttributeType *attribute, const int attributePos);
};
}

View File

@ -4,12 +4,11 @@
* Description: Read and store ARFF data from a file
*/
#include "arff.hpp"
#include "log.hpp"
int main(int argc, char* argv[]) {
ARFF::ParseArguments(argc, argv);
ARFF::Arff data;
data.Read(ARFF::GetDataFilename());
data.PrintOverview();
debug::Log(kLog, "Test");
data.OneR();
}