markov-generator/src/generator.cpp

156 lines
4.8 KiB
C++
Raw Normal View History

#include "generator.hpp"
2023-12-08 18:19:14 -06:00
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
2023-12-08 18:19:14 -06:00
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
2023-12-08 18:19:14 -06:00
// Process-wide pseudo-random engine shared by the random helpers below.
std::default_random_engine generator;

// Seeds `generator` from std::random_device so that each run draws a
// different sequence.
void InitializeGenerator(void)
{
    generator.seed(std::random_device{}());
}

// Returns a uniformly distributed integer in [0, generationLimit - 1].
//
// generationLimit: exclusive upper bound of the range.
// A non-positive generationLimit describes an empty range. Passing it to
// std::uniform_int_distribution would violate its a <= b precondition
// (undefined behavior), so 0 is returned in that case instead.
int GenerateRandomNumber(int generationLimit)
{
    if (generationLimit <= 0) {
        return 0;
    }
    std::uniform_int_distribution<int> distribution(0, generationLimit - 1);
    return distribution(generator);
}
// Creates an empty node: not the end of an inserted string, never traversed,
// and with probability 0.
// occurrences and probability are zeroed explicitly because Trie::Insert
// increments `occurrences` and Trie::Recalculate sums it. Without this they
// could hold indeterminate values if the header gives them no default member
// initializer. Assignment in the body (not an init list) avoids depending on
// member declaration order.
TrieNode::TrieNode(void) {
    isEndOfWord = false;
    occurrences = 0;
    probability = 0.0;
}
// Builds an empty trie. The root node stands for the empty context and is
// the starting point of every Insert and GetNextCharacter walk.
// NOTE(review): no destructor is visible in this file, so the root and the
// nodes allocated by Insert appear never to be freed. Confirm against
// generator.hpp.
Trie::Trie(void) : root(new TrieNode()) {}
void Trie::Insert(std::string kgram) {
TrieNode* current = root;
for (char ch : kgram) {
if (current->children.find(ch) == current->children.end()) {
current->children[ch] = new TrieNode();
} else { ++current->occurrences; }
Recalculate(current->children);
current = current->children[ch];
}
current->isEndOfWord = true;
}
char Trie::GetNextCharacter(const std::string& kgram) {
TrieNode* current = root;
char nextChar = ' ';
for (char ch: kgram) {
auto it = current->children.find(ch);
if (it == current->children.end()) {
return nextChar;
}
current = it->second;
}
double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX);
double minimum = 0.;
for (const auto& i : current->children) {
minimum += i.second->probability;
if (roll <= minimum) {
nextChar = i.first;
break;
}
}
return nextChar;
}
// Sets every child's `probability` to its share of the siblings' total
// `occurrences`.
//
// children: sibling map of one node. It is received by value because the
// signature is fixed by the header, so the map itself is copied on every call.
// Only the pointed-to nodes are modified.
// If no sibling has been counted yet (total == 0), every probability is set to
// 0 instead of computing 0/0, which previously produced NaN and made
// GetNextCharacter unable to ever select these children.
void Trie::Recalculate(std::unordered_map<char, TrieNode*> children) {
    long long total = 0;  // wider than int so large corpora cannot overflow the sum
    for (const auto& entry : children) {
        total += entry.second->occurrences;
    }
    for (const auto& entry : children) {
        entry.second->probability =
            total > 0 ? static_cast<double>(entry.second->occurrences) / static_cast<double>(total)
                      : 0.0;
    }
}
void Generator::SetArguments(const int argc, char* argv[]) {
std::string tempStr;
2023-12-07 23:58:54 -06:00
for (int i = 1; i < argc; i += 2) {
tempStr.assign(argv[i]);
if (tempStr == "-i") {
setup.isFileSet = true;
2023-12-07 23:58:54 -06:00
setup.filename.assign(argv[i+1]);
}
if (tempStr == "-k") {
setup.isPrefixSet = true;
2023-12-07 23:58:54 -06:00
setup.prefixLength = std::stoi(argv[i+1]);
}
if (tempStr == "-n") {
setup.isOutputSet = true;
2023-12-07 23:58:54 -06:00
setup.outputLength = std::stoi(argv[i+1]);
}
if (tempStr == "-h") {
PrintUsage();
exit(0);
}
}
if (!setup.isFileSet) { std::cerr << "[Setup - Error] Filename not specified" << std::endl; }
if (!setup.prefixLength) { std::cerr << "[Setup - Error] Prefix length not specified" << std::endl; }
if (!setup.outputLength) { std::cerr << "[Setup - Error] Output length not specified" << std::endl; }
if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); }
}
// Builds the Markov model from the whitespace-separated words of setup.filename.
//
// Generation uses a context of up to setup.prefixLength characters and asks the
// trie for the character that follows it. The trie must therefore store windows
// of prefixLength + 1 characters (context plus successor); otherwise a full
// context always ends on a leaf. Every such window of a word is inserted,
// including the last one. Words shorter than a window are inserted whole.
// Exits with status 1 if the prefix length is not positive or the file cannot
// be opened.
void Generator::Train(void) {
    if (setup.prefixLength < 1) {
        std::cerr << "[Setup - Error] Prefix length must be positive" << std::endl;
        std::exit(1);
    }
    std::ifstream inputFile(setup.filename);
    if (!inputFile.is_open()) {
        std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl;
        std::exit(1);
    }
    const std::size_t windowLength = static_cast<std::size_t>(setup.prefixLength) + 1;

    std::cout << "[Setup - Info] Begin training" << std::endl;
    // Stream words directly instead of buffering the whole file and a word vector.
    std::string word;
    while (inputFile >> word) {
        if (word.size() < windowLength) {
            trie.Insert(word);
            continue;
        }
        // `<=` so the final window (ending at the last character) is included.
        for (std::size_t start = 0; start + windowLength <= word.size(); ++start) {
            trie.Insert(word.substr(start, windowLength));
        }
    }
    std::cout << "[Setup - Info] Finished training" << std::endl;
}
// Writes setup.outputLength characters sampled from the trained trie to stdout.
//
// currentKGram is the sampling context. It holds the most recent characters
// and grows one character per draw until its size reaches setup.prefixLength;
// after that, the oldest character is dropped each time a new one is added.
// A ' ' from GetNextCharacter is printed as a word break and resets the context
// to empty.
void Generator::Generation(void) {
    std::cout << "[Generation - Info] Output start" << std::endl;
    for (int step = 0; step < setup.outputLength; ++step) {
        const char produced = trie.GetNextCharacter(currentKGram);
        std::cout << produced;
        if (produced == ' ') {
            currentKGram.clear();
        } else if (currentKGram.size() < setup.prefixLength) {
            currentKGram += produced;
        } else {
            // Slide the window: drop the first character, append the new one.
            currentKGram = currentKGram.substr(1) + produced;
        }
    }
    std::cout << std::endl << "[Generation - Info] Output finished" << std::endl;
}
// Prints the command-line usage summary to stdout.
// The -n description says "characters" because Generator::Generation emits
// exactly one character (letters or word-separating spaces) per unit of
// output length. The previous text claimed "words".
void PrintUsage(void) {
    std::cout << "Usage: markov -i input_file -k prefix_length -n output_length" << std::endl;
    std::cout << " -i: Direct path to input file for basis" << std::endl;
    std::cout << " -k: Prefix length for Markov chain" << std::endl;
    std::cout << " -n: Length of output to be generated (characters)" << std::endl;
    std::cout << " -h: Prints this usage text" << std::endl;
}