From c19461fac2cf18213462bb648546ba0077a85343 Mon Sep 17 00:00:00 2001
From: Trianta <56975502+Trimutex@users.noreply.github.com>
Date: Fri, 8 Dec 2023 22:45:30 -0600
Subject: [PATCH] Fixed generator crashes and improved generation. k=0 only
 outputs spaces, though...

---
 src/generator.cpp | 112 +++++++++++++++++++++++++++++++---------------
 src/generator.hpp |  27 ++++++-----
 src/main.cpp      |   4 +-
 3 files changed, 91 insertions(+), 52 deletions(-)

diff --git a/src/generator.cpp b/src/generator.cpp
index 93b6a42..74ef1d5 100644
--- a/src/generator.cpp
+++ b/src/generator.cpp
@@ -2,9 +2,12 @@
 #include <cstdlib>
 #include <fstream>
 #include <iostream>
+#include <iterator>
 #include <random>
+#include <sstream>
 
 std::default_random_engine generator;
+
 void InitializeGenerator(void)
 {
     generator.seed(std::random_device{}());
@@ -15,40 +18,65 @@ int GenerateRandomNumber(int generationLimit)
 {
     int generatedNumber;
     std::uniform_int_distribution<> distribution(0, generationLimit - 1);
-    generatedNumber = distribution(snakeplusplus::generator);
+    generatedNumber = distribution(generator);
     return generatedNumber;
 }
 
+TrieNode::TrieNode(void) {
+    isEndOfWord = false;
+}
 
-void Trie::insert(const std::deque<char>& currentKGram) {
+Trie::Trie(void) {
+    root = new TrieNode();
+}
+
+void Trie::Insert(std::string kgram) {
     TrieNode* current = root;
 
-    for (char ch : currentKGram) {
+    for (char ch : kgram) {
         if (current->children.find(ch) == current->children.end()) {
             current->children[ch] = new TrieNode();
-        } else { ++current->occurances; }
-
+        } else { ++current->occurrences; }
+        Recalculate(current->children);
         current = current->children[ch];
     }
 
     current->isEndOfWord = true;
 }
 
-bool Trie::search(const std::deque<char>& currentKGram) const {
+char Trie::GetNextCharacter(const std::string& kgram) {
     TrieNode* current = root;
-
-    for (char ch : currentKGram) {
-        if (current->children.find(ch) == current->children.end()) {
-            return false;
+    char nextChar = ' ';
+    for (char ch: kgram) {
+        auto it = current->children.find(ch);
+        if (it == current->children.end()) {
+            return nextChar;
         }
-
-        current = current->children[ch];
+        current = it->second;
     }
-
-    return current->isEndOfWord;
+    double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX);
+    double minimum = 0.;
+    for (const auto& i : current->children) {
+        minimum += i.second->probability;
+        if (roll <= minimum) { 
+            nextChar = i.first;
+            break;
+        }
+    }
+    return nextChar;
 }
 
-void Generator::SetArguments(int argc, char* argv[]) {
+void Trie::Recalculate(std::unordered_map<char, TrieNode*> children) {
+    int total = 0;
+    for (auto i : children) {
+        total += i.second->occurrences;
+    }
+    for (auto i : children) {
+        i.second->probability = ((double)i.second->occurrences) / ((double) total);
+    }
+}
+
+void Generator::SetArguments(const int argc, char* argv[]) {
     std::string tempStr;
     for (int i = 1; i < argc; i += 2) {
         tempStr.assign(argv[i]);
@@ -75,35 +103,47 @@ void Generator::SetArguments(int argc, char* argv[]) {
     if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); }
 }
 
-void Generator::ReadFile(void) {
+void Generator::Train(void) {
     std::ifstream inputFile(setup.filename);
     if (!inputFile.is_open()) {
         std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl;
         exit(1);
     }
-    std::deque<char> currentKGram;
-    char tempChar;
-    // Read in first k-gram
-    {
-        std::string initializeKGram;
-        inputFile.get(&initializeKGram[0], setup.prefixLength);
-        for (char ch : initializeKGram) { currentKGram.emplace_back(ch); }
-        trie.insert(currentKGram);
-    }
-    // Read rest of file
-    while (inputFile.get(tempChar)) {
-        currentKGram.emplace_back(tempChar);
-        if (currentKGram.size() > setup.prefixLength) { currentKGram.pop_front(); }
-        trie.insert(currentKGram);
+    std::stringstream iss;
+    iss << inputFile.rdbuf();
+    std::vector<std::string> words(std::istream_iterator<std::string>{iss},
+                                   std::istream_iterator<std::string>());
+
+    std::cout << "[Setup - Info] Begin training" << std::endl;
+    for (const auto& word : words) {
+        if (word.size() < setup.prefixLength) {
+            trie.Insert(word);
+            continue;
+        } else {
+            for (int i = 0; i < word.size() - setup.prefixLength; ++i) {
+                trie.Insert(word.substr(i, setup.prefixLength));
+            }
+        }
     }
+    std::cout << "[Setup - Info] Finished training" << std::endl;
 }
 
-void Generator::GenerateOutput(void) {
-}
-
-char Generator::GenerateCharacter(void) {
-    double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX);
-    return 'z';
+void Generator::Generation(void) {
+    std::cout << "[Generation - Info] Output start" << std::endl;
+    for (int i = 0; i < setup.outputLength; ++i) {
+        char nextChar = trie.GetNextCharacter(currentKGram);
+        std::cout << nextChar;
+        if (nextChar == ' ') { 
+            currentKGram.clear();
+            continue;
+        }
+        if (currentKGram.size() < setup.prefixLength) {
+            currentKGram += nextChar;
+        } else {
+            currentKGram = currentKGram.substr(1) + nextChar;
+        }
+    }
+    std::cout << std::endl << "[Generation - Info] Output finished" << std::endl;
 }
 
 void PrintUsage(void) {
diff --git a/src/generator.hpp b/src/generator.hpp
index 12bb53d..8ea27b4 100644
--- a/src/generator.hpp
+++ b/src/generator.hpp
@@ -1,7 +1,6 @@
 #ifndef GENERATOR_HPP
 #define GENERATOR_HPP
 
-#include <deque>
 #include <unordered_map>
 #include <string>
 
@@ -17,21 +16,21 @@ struct ArgumentList {
     bool isOutputSet = false;
 };
 
-class TrieNode {
+struct TrieNode {
 public:
+    TrieNode();
     std::unordered_map<char, TrieNode*> children;
-    int occurances = 1;
+    int occurrences = 1;
+    double probability = 0.;
     bool isEndOfWord;
-
-    TrieNode() : isEndOfWord(false) {}
 };
 
-class Trie {
+struct Trie {
 public:
-    Trie() : root(new TrieNode()) {}
-    void insert(const std::deque<char>& currentKGram);
-    bool search(const std::deque<char>& currentKGram) const;
-private:
+    Trie();
+    void Insert(std::string kgram);
+    void Recalculate(std::unordered_map<char, TrieNode*> children);
+    char GetNextCharacter(const std::string& kgram);
     TrieNode* root;
 };
 
@@ -40,11 +39,11 @@ struct Generator {
 public:
     Generator(void) = default;
     ~Generator(void) = default;
-    void SetArguments(int argc, char* argv[]);
-    void ReadFile(void);
-    void GenerateOutput(void);
+    void SetArguments(const int argc, char* argv[]);
+    void Train(void);
+    void Generation(void);
 private:
-    char GenerateCharacter(void);
+    std::string currentKGram;
     ArgumentList setup;
     Trie trie;
 };
diff --git a/src/main.cpp b/src/main.cpp
index c356a50..37ac5ba 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -8,7 +8,7 @@ int main(int argc, char* argv[]) {
     Generator markovChain;
     markovChain.SetArguments(argc, argv);
     InitializeGenerator();
-    markovChain.ReadFile();
-    markovChain.GenerateOutput();
+    markovChain.Train();
+    markovChain.Generation();
     return 0;
 }