#pragma once

#include <string>
#include <vector>

class FormatVisitor : public BaseVisitor {
 public:
  FormatVisitor() : ws(0) {}
  void Visit(const BaseNode *node) override {
    node->Visit(this);
  }

  void Visit(const ClassDeclarationNode *node) override {
    std::string str = std::string(ws, ' ') +
        "class " + node->ClassName() + " {";
    formatted.push_back(str);
    int prev_ws = ws;
    ws += 2;
    std::vector<BaseNode *> fields = node->PublicFields();
    if (fields.size() > 0) {
      str = std::string(ws, ' ') + "public:";
      formatted.push_back(str);
      ws += 2;

      for ( auto const &method : fields ) {
        if (MethodDeclarationNode *method_ =
            dynamic_cast<MethodDeclarationNode *>(method)) {
          method_->Visit(this);
        } else if (VarDeclarationNode *var_ =
            dynamic_cast<VarDeclarationNode *>(method)) {
          var_->Visit(this);
        }
      }
      ws -= 2;
    }

    fields = node->ProtectedFields();
    if (fields.size() > 0) {
      formatted.push_back("");
      str = std::string(ws, ' ') + "protected:";
      formatted.push_back(str);
      ws += 2;

      for ( auto const &method : fields ) {
        if (MethodDeclarationNode *method_ =
            dynamic_cast<MethodDeclarationNode *>(method)) {
          method_->Visit(this);
        } else if (VarDeclarationNode *var_ =
            dynamic_cast<VarDeclarationNode *>(method)) {
          var_->Visit(this);
        }
      }
      ws -= 2;
    }
    fields = node->PrivateFields();
    if (fields.size() > 0) {
      formatted.push_back("");
      str = std::string(ws, ' ') + "private:";
      formatted.push_back(str);
      ws += 2;
    }
    for ( auto const &method : fields ) {
      if (MethodDeclarationNode *method_ =
          dynamic_cast<MethodDeclarationNode *>(method)) {
        method_->Visit(this);
      } else if (VarDeclarationNode *var_ =
          dynamic_cast<VarDeclarationNode *>(method)) {
        var_->Visit(this);
      } else if (ClassDeclarationNode *class_ =
          dynamic_cast<ClassDeclarationNode *>(method)) {
        class_->Visit(this);
      }
    }
    formatted.push_back(std::string(prev_ws, ' ') + "};");
  }

  void Visit(const MethodDeclarationNode *node) override {
    std::string str;
    str.clear();
    std::vector<BaseNode *> args = node->Arguments();
    str = std::string(ws, ' ') + node->ReturnTypeName()
        + ' ' + node->MethodName() + '(';
    for ( auto const &arg : args ) {
      VarDeclarationNode *arg_ = dynamic_cast<VarDeclarationNode *>(arg);
      str += arg_->TypeName() + ' ' + arg_->VarName();
      if (&arg != &args.back()) {
        str += ", ";
      }
    }
    str += ");";
    formatted.push_back(str);
  };

  void Visit(const VarDeclarationNode *node) override {
    std::string str = std::string(ws, ' ') +
        node->TypeName() + ' ' + node->VarName() + ';';
    formatted.push_back(str);
  };

  const std::vector<std::string> &GetFormattedCode() const {
    return formatted;
  }

 private:
  std::vector<std::string> formatted;
  int ws;
};