#pragma once
#include <string>
#include <vector>
class FormatVisitor: public BaseVisitor {
private:
std::vector<std::string> formatted = {""};
public:
void Visit(const BaseNode* node) override {
node->Visit(this);
}
void Visit(const ClassDeclarationNode* node) override {
int32_t white_spaces = formatted[formatted.size() - 1].length();
formatted[formatted.size() - 1] += "class " + node->ClassName() + " {";
// public
if (!node->PublicFields().empty()) {
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += " public:";
for (auto item : node->PublicFields()) {
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += " ";
Visit(item);
}
}
// protected
if (!node->ProtectedFields().empty()) {
formatted.emplace_back("");
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += " protected:";
for (auto item : node->ProtectedFields()) {
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += " ";
Visit(item);
}
}
// private
if (!node->PrivateFields().empty()) {
formatted.emplace_back("");
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += " private:";
for (auto item : node->PrivateFields()) {
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += " ";
Visit(item);
}
}
formatted.emplace_back("");
for (int32_t space = 0; space < white_spaces; ++space)
formatted[formatted.size() - 1] += " ";
formatted[formatted.size() - 1] += "};";
}
void Visit(const VarDeclarationNode* node) override {
int32_t len = formatted[formatted.size() - 1].length();
if (formatted[formatted.size() - 1][len - 2] == ';') {
formatted[formatted.size() - 1][len - 2] = ',';
}
std::string arg = node->TypeName() + " " + node->VarName() + ";";
formatted[formatted.size() - 1] += arg;
}
void Visit(const MethodDeclarationNode* node) override {
std::string s = node->ReturnTypeName() + " ";
s += node->MethodName();
s += "(";
formatted[formatted.size() - 1] += s;
std::vector <BaseNode*> args = node->Arguments();
for (size_t i = 0; i < args.size(); ++i) {
Visit(args[i]);
if (i < args.size() - 1) {
formatted[formatted.size() - 1] += " ";
}
}
int32_t len = formatted[formatted.size() - 1].length();
if (formatted[formatted.size() - 1][len - 1] == ';') {
formatted[formatted.size() - 1][len - 1] = ')';
formatted[formatted.size() - 1] += ";";
} else {
formatted[formatted.size() - 1] += ");";
}
}
const std::vector<std::string>& GetFormattedCode() const {
return formatted;
}
};