From c19461fac2cf18213462bb648546ba0077a85343 Mon Sep 17 00:00:00 2001 From: Trianta <56975502+Trimutex@users.noreply.github.com> Date: Fri, 8 Dec 2023 22:45:30 -0600 Subject: [PATCH 1/2] Fixed generator crashes and improved generations. k=0 is only putting spaces though... --- src/generator.cpp | 112 +++++++++++++++++++++++++++++++--------------- src/generator.hpp | 27 ++++++----- src/main.cpp | 4 +- 3 files changed, 91 insertions(+), 52 deletions(-) diff --git a/src/generator.cpp b/src/generator.cpp index 93b6a42..74ef1d5 100644 --- a/src/generator.cpp +++ b/src/generator.cpp @@ -2,9 +2,12 @@ #include #include #include +#include #include +#include std::default_random_engine generator; + void InitializeGenerator(void) { generator.seed(std::random_device{}()); @@ -15,40 +18,65 @@ int GenerateRandomNumber(int generationLimit) { int generatedNumber; std::uniform_int_distribution<> distribution(0, generationLimit - 1); - generatedNumber = distribution(snakeplusplus::generator); + generatedNumber = distribution(generator); return generatedNumber; } +TrieNode::TrieNode(void) { + isEndOfWord = false; +} -void Trie::insert(const std::deque& currentKGram) { +Trie::Trie(void) { + root = new TrieNode(); +} + +void Trie::Insert(std::string kgram) { TrieNode* current = root; - for (char ch : currentKGram) { + for (char ch : kgram) { if (current->children.find(ch) == current->children.end()) { current->children[ch] = new TrieNode(); - } else { ++current->occurances; } - + } else { ++current->occurrences; } + Recalculate(current->children); current = current->children[ch]; } current->isEndOfWord = true; } -bool Trie::search(const std::deque& currentKGram) const { +char Trie::GetNextCharacter(const std::string& kgram) { TrieNode* current = root; - - for (char ch : currentKGram) { - if (current->children.find(ch) == current->children.end()) { - return false; + char nextChar = ' '; + for (char ch: kgram) { + auto it = current->children.find(ch); + if (it == current->children.end()) { + return nextChar; } - - current = current->children[ch]; + current = it->second; } - - return current->isEndOfWord; + double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX); + double minimum = 0.; + for (const auto& i : current->children) { + minimum += i.second->probability; + if (roll <= minimum) { + nextChar = i.first; + break; + } + } + return nextChar; } -void Generator::SetArguments(int argc, char* argv[]) { +void Trie::Recalculate(std::unordered_map children) { + int total = 0; + for (auto i : children) { + total += i.second->occurrences; + } + for (auto i : children) { + i.second->probability = ((double)i.second->occurrences) / ((double) total); + } +} + +void Generator::SetArguments(const int argc, char* argv[]) { std::string tempStr; for (int i = 1; i < argc; i += 2) { tempStr.assign(argv[i]); @@ -75,35 +103,47 @@ void Generator::SetArguments(int argc, char* argv[]) { if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); } } -void Generator::ReadFile(void) { +void Generator::Train(void) { std::ifstream inputFile(setup.filename); if (!inputFile.is_open()) { std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl; exit(1); } - std::deque currentKGram; - char tempChar; - // Read in first k-gram - { - std::string initializeKGram; - inputFile.get(&initializeKGram[0], setup.prefixLength); - for (char ch : initializeKGram) { currentKGram.emplace_back(ch); } - trie.insert(currentKGram); - } - // Read rest of file - while (inputFile.get(tempChar)) { - currentKGram.emplace_back(tempChar); - if (currentKGram.size() > setup.prefixLength) { currentKGram.pop_front(); } - trie.insert(currentKGram); + std::stringstream iss; + iss << inputFile.rdbuf(); + std::vector words(std::istream_iterator{iss}, + std::istream_iterator()); + + std::cout << "[Setup - Info] Begin training" << std::endl; + for (const auto& word : words) { + if (word.size() < setup.prefixLength) { + trie.Insert(word); + continue; + } else { + for (int i = 0; i < word.size() - setup.prefixLength; ++i) { + trie.Insert(word.substr(i, setup.prefixLength)); + } + } } + std::cout << "[Setup - Info] Finished training" << std::endl; } -void Generator::GenerateOutput(void) { -} - -char Generator::GenerateCharacter(void) { - double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX); - return 'z'; +void Generator::Generation(void) { + std::cout << "[Generation - Info] Output start" << std::endl; + for (int i = 0; i < setup.outputLength; ++i) { + char nextChar = trie.GetNextCharacter(currentKGram); + std::cout << nextChar; + if (nextChar == ' ') { + currentKGram.clear(); + continue; + } + if (currentKGram.size() < setup.prefixLength) { + currentKGram += nextChar; + } else { + currentKGram = currentKGram.substr(1) + nextChar; + } + } + std::cout << std::endl << "[Generation - Info] Output finished" << std::endl; } void PrintUsage(void) { diff --git a/src/generator.hpp b/src/generator.hpp index 12bb53d..8ea27b4 100644 --- a/src/generator.hpp +++ b/src/generator.hpp @@ -1,7 +1,6 @@ #ifndef GENERATOR_HPP #define GENERATOR_HPP -#include #include #include @@ -17,21 +16,21 @@ struct ArgumentList { bool isOutputSet = false; }; -class TrieNode { +struct TrieNode { public: + TrieNode(); std::unordered_map children; - int occurances = 1; + int occurrences = 1; + double probability = 0.; bool isEndOfWord; - - TrieNode() : isEndOfWord(false) {} }; -class Trie { +struct Trie { public: - Trie() : root(new TrieNode()) {} - void insert(const std::deque& currentKGram); - bool search(const std::deque& currentKGram) const; -private: + Trie(); + void Insert(std::string kgram); + void Recalculate(std::unordered_map children); + char GetNextCharacter(const std::string& kgram); TrieNode* root; }; @@ -40,11 +39,11 @@ struct Generator { public: Generator(void) = default; ~Generator(void) = default; - void SetArguments(int argc, char* argv[]); - void ReadFile(void); - void GenerateOutput(void); + void SetArguments(const int argc, char* argv[]); + void Train(void); + void Generation(void); private: - char GenerateCharacter(void); + std::string currentKGram; ArgumentList setup; Trie trie; }; diff --git a/src/main.cpp b/src/main.cpp index c356a50..37ac5ba 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -8,7 +8,7 @@ int main(int argc, char* argv[]) { Generator markovChain; markovChain.SetArguments(argc, argv); InitializeGenerator(); - markovChain.ReadFile(); - markovChain.GenerateOutput(); + markovChain.Train(); + markovChain.Generation(); return 0; } -- 2.45.3 From be7ef06a2d434d8400da4a9619abbf3d45eb321a Mon Sep 17 00:00:00 2001 From: Trianta <56975502+Trimutex@users.noreply.github.com> Date: Fri, 8 Dec 2023 22:51:44 -0600 Subject: [PATCH 2/2] Fixed k=0 issue --- src/generator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/generator.cpp b/src/generator.cpp index 74ef1d5..2aaa27a 100644 --- a/src/generator.cpp +++ b/src/generator.cpp @@ -86,11 +86,11 @@ void Generator::SetArguments(const int argc, char* argv[]) { } if (tempStr == "-k") { setup.isPrefixSet = true; - setup.prefixLength = std::stoi(argv[i+1]); + setup.prefixLength = std::stoi(argv[i+1]) + 1; } if (tempStr == "-n") { setup.isOutputSet = true; - setup.outputLength = std::stoi(argv[i+1]); + setup.outputLength = std::stoi(argv[i+1]) + 1; } if (tempStr == "-h") { PrintUsage(); -- 2.45.3