#include "generator.hpp"

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

namespace {

// Parses `text` as a non-negative integer that can still be incremented by one
// without overflowing an int. Returns true and stores the result in `value` on
// success; returns false (leaving `value` untouched) when the text is not a
// number, is out of range, or is negative.
bool ParseCount(const char* text, int& value) {
    try {
        const int parsed = std::stoi(std::string(text));
        if (parsed < 0 || parsed == std::numeric_limits<int>::max()) {
            return false;
        }
        value = parsed;
        return true;
    } catch (const std::logic_error&) {
        // std::stoi reports failures via std::invalid_argument / std::out_of_range,
        // both of which derive from std::logic_error.
        return false;
    }
}

}  // namespace

// Process-wide random engine shared by every random draw in this file.
std::default_random_engine generator;

// Seeds the shared random engine from std::random_device.
void InitializeGenerator(void) {
    generator.seed(std::random_device{}());
}

// Returns a uniformly distributed integer in [0, generationLimit - 1].
// A non-positive limit would make the distribution's bounds invalid, so 0 is
// returned in that case instead of invoking undefined behaviour.
int GenerateRandomNumber(int generationLimit) {
    if (generationLimit <= 0) {
        return 0;
    }
    std::uniform_int_distribution<> distribution(0, generationLimit - 1);
    return distribution(generator);
}

// Creates an empty node: not a word end, never seen, zero selection probability.
TrieNode::TrieNode(void) {
    isEndOfWord = false;
    occurrences = 0;
    probability = 0.0;
}

// Creates a trie consisting of just the root node.
// NOTE(review): nodes are allocated with `new` and no destructor is defined in
// this file - confirm whether the header releases them.
Trie::Trie(void) {
    root = new TrieNode();
}

// Inserts `kgram` into the trie one character at a time.
// Every node on the path has its occurrence count bumped (a freshly created
// node starts at 1), and the sibling probabilities of each visited level are
// refreshed so that they reflect the new counts. The final node is flagged as
// the end of a word.
void Trie::Insert(std::string kgram) {
    TrieNode* current = root;
    for (const char ch : kgram) {
        // operator[] value-initialises a missing entry to nullptr, which lets us
        // tell "new child" from "existing child" with a single lookup.
        TrieNode*& child = current->children[ch];
        if (child == nullptr) {
            child = new TrieNode();
            child->occurrences = 1;
        } else {
            // Count the child that was traversed, not its parent.
            ++child->occurrences;
        }
        Recalculate(current->children);
        current = child;
    }
    current->isEndOfWord = true;
}

// Walks the trie along `kgram` and randomly picks one of the children of the
// node reached, weighted by each child's stored probability.
// Returns ' ' when `kgram` is not present in the trie or the node reached has
// no children.
char Trie::GetNextCharacter(const std::string& kgram) {
    TrieNode* current = root;
    char nextChar = ' ';

    for (const char ch : kgram) {
        const auto it = current->children.find(ch);
        if (it == current->children.end()) {
            return nextChar;
        }
        current = it->second;
    }

    // roll lies in [0, 1) because GenerateRandomNumber never returns RAND_MAX.
    const double roll = static_cast<double>(GenerateRandomNumber(RAND_MAX)) /
                        static_cast<double>(RAND_MAX);

    double cumulative = 0.0;
    for (const auto& entry : current->children) {
        cumulative += entry.second->probability;
        // Remember the child just examined: if rounding leaves the cumulative sum
        // marginally below the roll, the last child is used instead of ' '.
        nextChar = entry.first;
        if (roll < cumulative) {
            break;
        }
    }
    return nextChar;
}

// Recomputes the probability of every node in `children` as its share of the
// summed occurrence counts. When the sum is zero all probabilities are set to 0
// rather than dividing by zero.
void Trie::Recalculate(std::unordered_map<char, TrieNode*> children) {
    int total = 0;
    for (const auto& entry : children) {
        total += entry.second->occurrences;
    }

    for (const auto& entry : children) {
        entry.second->probability =
            (total > 0)
                ? static_cast<double>(entry.second->occurrences) / static_cast<double>(total)
                : 0.0;
    }
}

// Parses the command line into `setup`.
// Arguments are consumed as "<flag> <value>" pairs starting at argv[1]:
//   -i <file>  input file name
//   -k <n>     prefix length (stored as n + 1)
//   -n <n>     output length (stored as n + 1)
//   -h         print the usage text and exit with status 0
// Unknown flags are ignored. A flag without a value, or with a value that is
// not a non-negative number, is reported on stderr and left unset. If any of
// the three required options is still unset afterwards, an error per missing
// option and the usage text are printed.
void Generator::SetArguments(const int argc, char* argv[]) {
    for (int i = 1; i < argc; i += 2) {
        const std::string option(argv[i]);

        if (option == "-h") {
            PrintUsage();
            std::exit(0);
        }

        if (option != "-i" && option != "-k" && option != "-n") {
            continue;
        }

        // argv[argc] is a null pointer, so a trailing flag has no value to read.
        if (i + 1 >= argc) {
            std::cerr << "[Setup - Error] Missing value for option: " << option << std::endl;
            continue;
        }
        const char* value = argv[i + 1];

        if (option == "-i") {
            setup.isFileSet = true;
            setup.filename.assign(value);
            continue;
        }

        int count = 0;
        if (!ParseCount(value, count)) {
            std::cerr << "[Setup - Error] Invalid value for option " << option << ": " << value
                      << std::endl;
            continue;
        }

        if (option == "-k") {
            setup.isPrefixSet = true;
            setup.prefixLength = count + 1;
        } else {
            setup.isOutputSet = true;
            setup.outputLength = count + 1;
        }
    }

    if (!setup.isFileSet) {
        std::cerr << "[Setup - Error] Filename not specified" << std::endl;
    }
    if (!setup.isPrefixSet) {
        std::cerr << "[Setup - Error] Prefix length not specified" << std::endl;
    }
    if (!setup.isOutputSet) {
        std::cerr << "[Setup - Error] Output length not specified" << std::endl;
    }
    if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) {
        PrintUsage();
    }
}

// Builds the trie from the whitespace-separated words of the input file.
// Words shorter than the stored prefix length are inserted whole; longer words
// contribute every window of exactly that length, including the final one.
// Exits with status 1 when the input file cannot be opened.
void Generator::Train(void) {
    std::ifstream inputFile(setup.filename);
    if (!inputFile.is_open()) {
        std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl;
        std::exit(1);
    }

    std::cout << "[Setup - Info] Begin training" << std::endl;

    const std::size_t window = static_cast<std::size_t>(setup.prefixLength);
    std::string word;
    while (inputFile >> word) {
        if (word.size() < window) {
            trie.Insert(word);
            continue;
        }
        // `start + window <= size` keeps the last window and avoids the unsigned
        // subtraction `size - window`.
        for (std::size_t start = 0; start + window <= word.size(); ++start) {
            trie.Insert(word.substr(start, window));
        }
    }

    std::cout << "[Setup - Info] Finished training" << std::endl;
}

// Writes `setup.outputLength` generated characters to stdout, framed by
// start/finish log lines.
// The lookup context holds at most `setup.prefixLength - 1` characters: the
// trie stores windows of `prefixLength` characters, so a context of that many
// characters always ends on a leaf and could never be extended. A ' ' result
// ends the current word and resets the context.
void Generator::Generation(void) {
    std::cout << "[Generation - Info] Output start" << std::endl;

    const std::size_t maxContext =
        (setup.prefixLength > 1) ? static_cast<std::size_t>(setup.prefixLength) - 1 : 0;

    for (int i = 0; i < setup.outputLength; ++i) {
        const char nextChar = trie.GetNextCharacter(currentKGram);
        std::cout << nextChar;

        if (nextChar == ' ') {
            currentKGram.clear();
            continue;
        }

        // Slide the window: append, then drop the oldest characters beyond the cap.
        currentKGram += nextChar;
        if (currentKGram.size() > maxContext) {
            currentKGram.erase(0, currentKGram.size() - maxContext);
        }
    }

    std::cout << std::endl << "[Generation - Info] Output finished" << std::endl;
}

// Prints the command-line usage summary to stdout.
// NOTE(review): the text describes -n as a length in words, while
// Generator::Generation emits one character per iteration - confirm the
// intended unit.
void PrintUsage(void) {
    std::cout << "Usage: markov -i input_file -k prefix_length -n output_length" << std::endl;
    std::cout << " -i: Direct path to input file for basis" << std::endl;
    std::cout << " -k: Prefix length for Markov chain" << std::endl;
    std::cout << " -n: Length of output to be generated (words)" << std::endl;
    std::cout << " -h: Prints this usage text" << std::endl;
}