#include <vector>
#include <string>
class FormatVisitor: public BaseVisitor {
public:
void Visit(const BaseNode* node) override {
node->Visit(this);
}
void Visit(const ClassDeclarationNode* node) override {
code_.push_back("class " + node->ClassName() + " {");
code_.push_back(" public:");
for (auto pub : node->PublicFields()) {
code_.push_back(" ");
this->Visit(pub);
}
if (node->ProtectedFields().size()) {
code_.push_back("");
code_.push_back(" protected:");
for (auto prot : node->ProtectedFields()) {
code_.push_back(" ");
this->Visit(prot);
}
}
if (node->PrivateFields().size()) {
code_.push_back("");
code_.push_back(" private:");
for (auto priv : node->PrivateFields()) {
if (dynamic_cast<ClassDeclarationNode*>(priv) != nullptr) {
FormatVisitor tempVis;
tempVis.Visit(priv);
std::vector<std::string> tempStr =
tempVis.GetFormattedCode();
for (auto str : tempStr) {
code_.push_back(" " + str);
}
} else {
code_.push_back(" ");
this->Visit(priv);
}
}
}
code_.push_back("};");
}
void Visit(const VarDeclarationNode* node) override {
code_[code_.size() - 1] += node->TypeName() + " " +
node->VarName() + ";";
}
void Visit(const MethodDeclarationNode* node) override {
code_[code_.size() - 1] += node->ReturnTypeName() + " " +
node->MethodName() + "(";
if (node->Arguments().size()) {
for (size_t i = 0; i < node->Arguments().size() - 1; ++i) {
this->Visit(node->Arguments()[i]);
code_[code_.size() - 1].pop_back();
code_[code_.size() - 1] += ", ";
}
this->Visit(node->Arguments()[node->Arguments().size() - 1]);
}
code_[code_.size() - 1] += ");";
}
const std::vector<std::string>& GetFormattedCode() const {
return code_;
}
private:
std::vector<std::string> code_;
};