Merge pull request 'dev' (#12) from dev into master

Reviewed-on: #12
This commit is contained in:
Trianta 2023-12-08 22:52:22 -06:00
commit 4f14763875
3 changed files with 93 additions and 54 deletions

View File

@ -2,9 +2,12 @@
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
#include <iterator>
#include <random> #include <random>
#include <sstream>
std::default_random_engine generator; std::default_random_engine generator;
void InitializeGenerator(void) void InitializeGenerator(void)
{ {
generator.seed(std::random_device{}()); generator.seed(std::random_device{}());
@ -15,40 +18,65 @@ int GenerateRandomNumber(int generationLimit)
{ {
int generatedNumber; int generatedNumber;
std::uniform_int_distribution<> distribution(0, generationLimit - 1); std::uniform_int_distribution<> distribution(0, generationLimit - 1);
generatedNumber = distribution(snakeplusplus::generator); generatedNumber = distribution(generator);
return generatedNumber; return generatedNumber;
} }
TrieNode::TrieNode(void) {
isEndOfWord = false;
}
void Trie::insert(const std::deque<char>& currentKGram) { Trie::Trie(void) {
root = new TrieNode();
}
void Trie::Insert(std::string kgram) {
TrieNode* current = root; TrieNode* current = root;
for (char ch : currentKGram) { for (char ch : kgram) {
if (current->children.find(ch) == current->children.end()) { if (current->children.find(ch) == current->children.end()) {
current->children[ch] = new TrieNode(); current->children[ch] = new TrieNode();
} else { ++current->occurances; } } else { ++current->occurrences; }
Recalculate(current->children);
current = current->children[ch]; current = current->children[ch];
} }
current->isEndOfWord = true; current->isEndOfWord = true;
} }
bool Trie::search(const std::deque<char>& currentKGram) const { char Trie::GetNextCharacter(const std::string& kgram) {
TrieNode* current = root; TrieNode* current = root;
char nextChar = ' ';
for (char ch : currentKGram) { for (char ch: kgram) {
if (current->children.find(ch) == current->children.end()) { auto it = current->children.find(ch);
return false; if (it == current->children.end()) {
return nextChar;
} }
current = it->second;
current = current->children[ch];
} }
double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX);
return current->isEndOfWord; double minimum = 0.;
for (const auto& i : current->children) {
minimum += i.second->probability;
if (roll <= minimum) {
nextChar = i.first;
break;
}
}
return nextChar;
} }
void Generator::SetArguments(int argc, char* argv[]) { void Trie::Recalculate(std::unordered_map<char, TrieNode*> children) {
int total = 0;
for (auto i : children) {
total += i.second->occurrences;
}
for (auto i : children) {
i.second->probability = ((double)i.second->occurrences) / ((double) total);
}
}
void Generator::SetArguments(const int argc, char* argv[]) {
std::string tempStr; std::string tempStr;
for (int i = 1; i < argc; i += 2) { for (int i = 1; i < argc; i += 2) {
tempStr.assign(argv[i]); tempStr.assign(argv[i]);
@ -58,11 +86,11 @@ void Generator::SetArguments(int argc, char* argv[]) {
} }
if (tempStr == "-k") { if (tempStr == "-k") {
setup.isPrefixSet = true; setup.isPrefixSet = true;
setup.prefixLength = std::stoi(argv[i+1]); setup.prefixLength = std::stoi(argv[i+1]) + 1;
} }
if (tempStr == "-n") { if (tempStr == "-n") {
setup.isOutputSet = true; setup.isOutputSet = true;
setup.outputLength = std::stoi(argv[i+1]); setup.outputLength = std::stoi(argv[i+1]) + 1;
} }
if (tempStr == "-h") { if (tempStr == "-h") {
PrintUsage(); PrintUsage();
@ -75,35 +103,47 @@ void Generator::SetArguments(int argc, char* argv[]) {
if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); } if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); }
} }
void Generator::ReadFile(void) { void Generator::Train(void) {
std::ifstream inputFile(setup.filename); std::ifstream inputFile(setup.filename);
if (!inputFile.is_open()) { if (!inputFile.is_open()) {
std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl; std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl;
exit(1); exit(1);
} }
std::deque<char> currentKGram; std::stringstream iss;
char tempChar; iss << inputFile.rdbuf();
// Read in first k-gram std::vector<std::string> words(std::istream_iterator<std::string>{iss},
{ std::istream_iterator<std::string>());
std::string initializeKGram;
inputFile.get(&initializeKGram[0], setup.prefixLength); std::cout << "[Setup - Info] Begin training" << std::endl;
for (char ch : initializeKGram) { currentKGram.emplace_back(ch); } for (const auto& word : words) {
trie.insert(currentKGram); if (word.size() < setup.prefixLength) {
trie.Insert(word);
continue;
} else {
for (int i = 0; i < word.size() - setup.prefixLength; ++i) {
trie.Insert(word.substr(i, setup.prefixLength));
} }
// Read rest of file
while (inputFile.get(tempChar)) {
currentKGram.emplace_back(tempChar);
if (currentKGram.size() > setup.prefixLength) { currentKGram.pop_front(); }
trie.insert(currentKGram);
} }
}
std::cout << "[Setup - Info] Finished training" << std::endl;
} }
void Generator::GenerateOutput(void) { void Generator::Generation(void) {
} std::cout << "[Generation - Info] Output start" << std::endl;
for (int i = 0; i < setup.outputLength; ++i) {
char Generator::GenerateCharacter(void) { char nextChar = trie.GetNextCharacter(currentKGram);
double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX); std::cout << nextChar;
return 'z'; if (nextChar == ' ') {
currentKGram.clear();
continue;
}
if (currentKGram.size() < setup.prefixLength) {
currentKGram += nextChar;
} else {
currentKGram = currentKGram.substr(1) + nextChar;
}
}
std::cout << std::endl << "[Generation - Info] Output finished" << std::endl;
} }
void PrintUsage(void) { void PrintUsage(void) {

View File

@ -1,7 +1,6 @@
#ifndef GENERATOR_HPP #ifndef GENERATOR_HPP
#define GENERATOR_HPP #define GENERATOR_HPP
#include <deque>
#include <unordered_map> #include <unordered_map>
#include <string> #include <string>
@ -17,21 +16,21 @@ struct ArgumentList {
bool isOutputSet = false; bool isOutputSet = false;
}; };
class TrieNode { struct TrieNode {
public: public:
TrieNode();
std::unordered_map<char, TrieNode*> children; std::unordered_map<char, TrieNode*> children;
int occurances = 1; int occurrences = 1;
double probability = 0.;
bool isEndOfWord; bool isEndOfWord;
TrieNode() : isEndOfWord(false) {}
}; };
class Trie { struct Trie {
public: public:
Trie() : root(new TrieNode()) {} Trie();
void insert(const std::deque<char>& currentKGram); void Insert(std::string kgram);
bool search(const std::deque<char>& currentKGram) const; void Recalculate(std::unordered_map<char, TrieNode*> children);
private: char GetNextCharacter(const std::string& kgram);
TrieNode* root; TrieNode* root;
}; };
@ -40,11 +39,11 @@ struct Generator {
public: public:
Generator(void) = default; Generator(void) = default;
~Generator(void) = default; ~Generator(void) = default;
void SetArguments(int argc, char* argv[]); void SetArguments(const int argc, char* argv[]);
void ReadFile(void); void Train(void);
void GenerateOutput(void); void Generation(void);
private: private:
char GenerateCharacter(void); std::string currentKGram;
ArgumentList setup; ArgumentList setup;
Trie trie; Trie trie;
}; };

View File

@ -8,7 +8,7 @@ int main(int argc, char* argv[]) {
Generator markovChain; Generator markovChain;
markovChain.SetArguments(argc, argv); markovChain.SetArguments(argc, argv);
InitializeGenerator(); InitializeGenerator();
markovChain.ReadFile(); markovChain.Train();
markovChain.GenerateOutput(); markovChain.Generation();
return 0; return 0;
} }