generated from Trianta/cpp-unity-template
commit
4f14763875
@ -2,9 +2,12 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <iterator>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
std::default_random_engine generator;
|
std::default_random_engine generator;
|
||||||
|
|
||||||
void InitializeGenerator(void)
|
void InitializeGenerator(void)
|
||||||
{
|
{
|
||||||
generator.seed(std::random_device{}());
|
generator.seed(std::random_device{}());
|
||||||
@ -15,40 +18,65 @@ int GenerateRandomNumber(int generationLimit)
|
|||||||
{
|
{
|
||||||
int generatedNumber;
|
int generatedNumber;
|
||||||
std::uniform_int_distribution<> distribution(0, generationLimit - 1);
|
std::uniform_int_distribution<> distribution(0, generationLimit - 1);
|
||||||
generatedNumber = distribution(snakeplusplus::generator);
|
generatedNumber = distribution(generator);
|
||||||
return generatedNumber;
|
return generatedNumber;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds an empty node that does not (yet) terminate any inserted string.
TrieNode::TrieNode(void) : isEndOfWord(false) {}
|
||||||
|
|
||||||
void Trie::insert(const std::deque<char>& currentKGram) {
|
// Creates a trie consisting of a single, freshly allocated root node.
Trie::Trie(void) : root(new TrieNode()) {}
|
||||||
|
|
||||||
|
void Trie::Insert(std::string kgram) {
|
||||||
TrieNode* current = root;
|
TrieNode* current = root;
|
||||||
|
|
||||||
for (char ch : currentKGram) {
|
for (char ch : kgram) {
|
||||||
if (current->children.find(ch) == current->children.end()) {
|
if (current->children.find(ch) == current->children.end()) {
|
||||||
current->children[ch] = new TrieNode();
|
current->children[ch] = new TrieNode();
|
||||||
} else { ++current->occurances; }
|
} else { ++current->occurrences; }
|
||||||
|
Recalculate(current->children);
|
||||||
current = current->children[ch];
|
current = current->children[ch];
|
||||||
}
|
}
|
||||||
|
|
||||||
current->isEndOfWord = true;
|
current->isEndOfWord = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Trie::search(const std::deque<char>& currentKGram) const {
|
char Trie::GetNextCharacter(const std::string& kgram) {
|
||||||
TrieNode* current = root;
|
TrieNode* current = root;
|
||||||
|
char nextChar = ' ';
|
||||||
for (char ch : currentKGram) {
|
for (char ch: kgram) {
|
||||||
if (current->children.find(ch) == current->children.end()) {
|
auto it = current->children.find(ch);
|
||||||
return false;
|
if (it == current->children.end()) {
|
||||||
|
return nextChar;
|
||||||
|
}
|
||||||
|
current = it->second;
|
||||||
|
}
|
||||||
|
double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX);
|
||||||
|
double minimum = 0.;
|
||||||
|
for (const auto& i : current->children) {
|
||||||
|
minimum += i.second->probability;
|
||||||
|
if (roll <= minimum) {
|
||||||
|
nextChar = i.first;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nextChar;
|
||||||
}
|
}
|
||||||
|
|
||||||
current = current->children[ch];
|
// Recomputes `probability` for every node in `children` as that node's share
// of the summed `occurrences` of all nodes in the map.
//
// The map itself is received by value, but its values are pointers, so the
// updates land on the nodes the caller's map refers to.
void Trie::Recalculate(std::unordered_map<char, TrieNode*> children) {
    // First pass: total number of occurrences across all siblings.
    int total = 0;
    for (auto it = children.begin(); it != children.end(); ++it) {
        total += it->second->occurrences;
    }

    // Second pass: normalise each sibling's count by that total.
    for (auto it = children.begin(); it != children.end(); ++it) {
        TrieNode* node = it->second;
        node->probability = static_cast<double>(node->occurrences)
                            / static_cast<double>(total);
    }
}
|
||||||
|
|
||||||
return current->isEndOfWord;
|
void Generator::SetArguments(const int argc, char* argv[]) {
|
||||||
}
|
|
||||||
|
|
||||||
void Generator::SetArguments(int argc, char* argv[]) {
|
|
||||||
std::string tempStr;
|
std::string tempStr;
|
||||||
for (int i = 1; i < argc; i += 2) {
|
for (int i = 1; i < argc; i += 2) {
|
||||||
tempStr.assign(argv[i]);
|
tempStr.assign(argv[i]);
|
||||||
@ -58,11 +86,11 @@ void Generator::SetArguments(int argc, char* argv[]) {
|
|||||||
}
|
}
|
||||||
if (tempStr == "-k") {
|
if (tempStr == "-k") {
|
||||||
setup.isPrefixSet = true;
|
setup.isPrefixSet = true;
|
||||||
setup.prefixLength = std::stoi(argv[i+1]);
|
setup.prefixLength = std::stoi(argv[i+1]) + 1;
|
||||||
}
|
}
|
||||||
if (tempStr == "-n") {
|
if (tempStr == "-n") {
|
||||||
setup.isOutputSet = true;
|
setup.isOutputSet = true;
|
||||||
setup.outputLength = std::stoi(argv[i+1]);
|
setup.outputLength = std::stoi(argv[i+1]) + 1;
|
||||||
}
|
}
|
||||||
if (tempStr == "-h") {
|
if (tempStr == "-h") {
|
||||||
PrintUsage();
|
PrintUsage();
|
||||||
@ -75,35 +103,47 @@ void Generator::SetArguments(int argc, char* argv[]) {
|
|||||||
if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); }
|
if (!setup.isFileSet || !setup.isPrefixSet || !setup.isOutputSet) { PrintUsage(); }
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::ReadFile(void) {
|
void Generator::Train(void) {
|
||||||
std::ifstream inputFile(setup.filename);
|
std::ifstream inputFile(setup.filename);
|
||||||
if (!inputFile.is_open()) {
|
if (!inputFile.is_open()) {
|
||||||
std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl;
|
std::cerr << "[ReadFile - Error] Could not open file: " << setup.filename << std::endl;
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
std::deque<char> currentKGram;
|
std::stringstream iss;
|
||||||
char tempChar;
|
iss << inputFile.rdbuf();
|
||||||
// Read in first k-gram
|
std::vector<std::string> words(std::istream_iterator<std::string>{iss},
|
||||||
{
|
std::istream_iterator<std::string>());
|
||||||
std::string initializeKGram;
|
|
||||||
inputFile.get(&initializeKGram[0], setup.prefixLength);
|
std::cout << "[Setup - Info] Begin training" << std::endl;
|
||||||
for (char ch : initializeKGram) { currentKGram.emplace_back(ch); }
|
for (const auto& word : words) {
|
||||||
trie.insert(currentKGram);
|
if (word.size() < setup.prefixLength) {
|
||||||
|
trie.Insert(word);
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < word.size() - setup.prefixLength; ++i) {
|
||||||
|
trie.Insert(word.substr(i, setup.prefixLength));
|
||||||
}
|
}
|
||||||
// Read rest of file
|
|
||||||
while (inputFile.get(tempChar)) {
|
|
||||||
currentKGram.emplace_back(tempChar);
|
|
||||||
if (currentKGram.size() > setup.prefixLength) { currentKGram.pop_front(); }
|
|
||||||
trie.insert(currentKGram);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
std::cout << "[Setup - Info] Finished training" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
void Generator::GenerateOutput(void) {
|
void Generator::Generation(void) {
|
||||||
|
std::cout << "[Generation - Info] Output start" << std::endl;
|
||||||
|
for (int i = 0; i < setup.outputLength; ++i) {
|
||||||
|
char nextChar = trie.GetNextCharacter(currentKGram);
|
||||||
|
std::cout << nextChar;
|
||||||
|
if (nextChar == ' ') {
|
||||||
|
currentKGram.clear();
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
if (currentKGram.size() < setup.prefixLength) {
|
||||||
char Generator::GenerateCharacter(void) {
|
currentKGram += nextChar;
|
||||||
double roll = ((double) GenerateRandomNumber(RAND_MAX)) / ((double) RAND_MAX);
|
} else {
|
||||||
return 'z';
|
currentKGram = currentKGram.substr(1) + nextChar;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << std::endl << "[Generation - Info] Output finished" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrintUsage(void) {
|
void PrintUsage(void) {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#ifndef GENERATOR_HPP
|
#ifndef GENERATOR_HPP
|
||||||
#define GENERATOR_HPP
|
#define GENERATOR_HPP
|
||||||
|
|
||||||
#include <deque>
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
@ -17,21 +16,21 @@ struct ArgumentList {
|
|||||||
bool isOutputSet = false;
|
bool isOutputSet = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TrieNode {
|
// One node of the character trie.
struct TrieNode {
public:
    // Defined out of line in the implementation file.
    TrieNode();

    // Child nodes keyed by the character that leads to them.
    std::unordered_map<char, TrieNode*> children;
    // Occurrence counter; a freshly created node counts as seen once.
    int occurrences = 1;
    // Share of this node among its siblings, filled in by Trie::Recalculate.
    double probability = 0.;
    // True when an inserted string ends at this node. Initialised here as
    // well so the member never depends solely on the out-of-line constructor.
    bool isEndOfWord = false;
};
|
||||||
|
|
||||||
class Trie {
|
struct Trie {
|
||||||
public:
|
public:
|
||||||
Trie() : root(new TrieNode()) {}
|
Trie();
|
||||||
void insert(const std::deque<char>& currentKGram);
|
void Insert(std::string kgram);
|
||||||
bool search(const std::deque<char>& currentKGram) const;
|
void Recalculate(std::unordered_map<char, TrieNode*> children);
|
||||||
private:
|
char GetNextCharacter(const std::string& kgram);
|
||||||
TrieNode* root;
|
TrieNode* root;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -40,11 +39,11 @@ struct Generator {
|
|||||||
public:
|
public:
|
||||||
Generator(void) = default;
|
Generator(void) = default;
|
||||||
~Generator(void) = default;
|
~Generator(void) = default;
|
||||||
void SetArguments(int argc, char* argv[]);
|
void SetArguments(const int argc, char* argv[]);
|
||||||
void ReadFile(void);
|
void Train(void);
|
||||||
void GenerateOutput(void);
|
void Generation(void);
|
||||||
private:
|
private:
|
||||||
char GenerateCharacter(void);
|
std::string currentKGram;
|
||||||
ArgumentList setup;
|
ArgumentList setup;
|
||||||
Trie trie;
|
Trie trie;
|
||||||
};
|
};
|
||||||
|
@ -8,7 +8,7 @@ int main(int argc, char* argv[]) {
|
|||||||
Generator markovChain;
|
Generator markovChain;
|
||||||
markovChain.SetArguments(argc, argv);
|
markovChain.SetArguments(argc, argv);
|
||||||
InitializeGenerator();
|
InitializeGenerator();
|
||||||
markovChain.ReadFile();
|
markovChain.Train();
|
||||||
markovChain.GenerateOutput();
|
markovChain.Generation();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user