#pragma once

#include <memory>
#include "Test.hpp"

namespace smart_pointer {
    /// Error type thrown by SmartPointer (e.g. when an empty pointer is
    /// dereferenced).
    ///
    /// Derives *publicly* from std::exception so that generic
    /// `catch (const std::exception&)` handlers can intercept it; with the
    /// default (private) inheritance of `class` such handlers would never match.
    class exception : public std::exception {
     public:
        using base_class = std::exception;
        using base_class::base_class;
    };

    /// Reference-counted owning pointer.
    ///
    /// A non-empty SmartPointer refers to a heap-allocated Core that stores the
    /// managed pointer together with the number of SmartPointer instances that
    /// currently share it. When the last owner releases its share (destruction
    /// or reassignment) the Core is deleted, and Core's destructor hands the
    /// managed pointer to `Allocator::deallocate`.
    ///
    /// The owner count is a plain std::size_t (not atomic), so instances sharing
    /// a Core must not be copied/assigned/destroyed concurrently from several
    /// threads without external synchronization.
    ///
    /// @tparam T          Type of the managed object.
    /// @tparam Allocator  Default-constructible type whose `deallocate` releases
    ///                    the managed object.
    template<typename T, typename Allocator>
    class SmartPointer {
        ENABLE_CLASS_TESTS;

     public:
        using value_type = T;

        /// Takes ownership of `ptr` with an owner count of 1.
        /// A null `ptr` produces an empty SmartPointer that owns nothing.
        explicit SmartPointer(value_type* ptr = nullptr) : core(make_core(ptr)) {}

        /// Drops this instance's share; frees the pointee if it was the last owner.
        ~SmartPointer() {
            release();
        }

        /// Joins `obj`'s ownership group (owner count + 1).
        SmartPointer(const SmartPointer& obj) : core(obj.core) {
            if (core != nullptr) {
                ++core->count;
            }
        }

        /// Takes over `dyingObj`'s share; `dyingObj` is left empty and the
        /// owner count is unchanged.
        SmartPointer(SmartPointer&& dyingObj) noexcept : core(dyingObj.core) {
            dyingObj.core = nullptr;
        }

        /// Leaves the current ownership group and joins `other`'s.
        /// The incoming count is bumped *before* the old share is released, so
        /// self-assignment (or two handles on the same Core) is safe.
        SmartPointer& operator=(const SmartPointer& other) {
            Core* incoming = other.core;
            if (incoming != nullptr) {
                ++incoming->count;
            }
            release();
            core = incoming;
            return *this;
        }

        /// Leaves the current ownership group and takes over `dyingObj`'s share;
        /// `dyingObj` is left empty. Other owners of the old Core are unaffected.
        SmartPointer& operator=(SmartPointer&& dyingObj) noexcept {
            if (this != &dyingObj) {
                release();
                core = dyingObj.core;
                dyingObj.core = nullptr;
            }
            return *this;
        }

        /// Leaves the current ownership group and takes sole ownership of
        /// `other` (owner count 1); a null `other` makes this pointer empty.
        SmartPointer& operator=(value_type* other) {
            if (other == get()) {
                // Re-assigning the pointer already managed here (or null to an
                // empty pointer) is a no-op; otherwise `other` would be freed
                // by release() and then kept as a dangling pointer.
                return *this;
            }
            // Allocate the new Core first so that a throwing allocation leaves
            // *this unchanged instead of dangling.
            Core* fresh = make_core(other);
            release();
            core = fresh;
            return *this;
        }

        /// @throws smart_pointer::exception if this SmartPointer is empty.
        value_type& operator*() {
            if (core == nullptr) {
                throw smart_pointer::exception();
            }
            return *core->ptr;
        }

        /// @throws smart_pointer::exception if this SmartPointer is empty.
        const value_type& operator*() const {
            if (core == nullptr) {
                throw smart_pointer::exception();
            }
            return *core->ptr;
        }

        /// Returns the managed pointer, or nullptr when empty (does not throw).
        value_type* operator->() const {
            return get();
        }

        /// Returns the managed pointer, or nullptr when empty.
        value_type* get() const {
            return core != nullptr ? core->ptr : nullptr;
        }

        /// True when this SmartPointer manages an object.
        operator bool() const {
            return core != nullptr;
        }

        /// Address equality of the managed objects, across element/allocator types.
        /// Casting through `const void*` also works when T or U is const-qualified.
        template<typename U, typename AnotherAllocator>
        bool operator==(const SmartPointer<U, AnotherAllocator>& ptr) const {
            return static_cast<const void*>(this->get()) ==
                   static_cast<const void*>(ptr.get());
        }

        /// Negation of operator==.
        template<typename U, typename AnotherAllocator>
        bool operator!=(const SmartPointer<U, AnotherAllocator>& ptr) const {
            return static_cast<const void*>(this->get()) !=
                   static_cast<const void*>(ptr.get());
        }

        /// Number of SmartPointer instances sharing the managed object (0 if empty).
        std::size_t count_owners() const {
            return core != nullptr ? core->count : 0;
        }

     private:
        /// Shared control block: owner count plus the managed pointer.
        class Core {
         public:
            std::size_t count;
            value_type* ptr;

            ~Core() {
                // NOTE(review): the second argument is sizeof(value_type); for
                // std::allocator-style allocators it must equal the element count
                // passed to `allocate` (usually 1) — confirm against the Allocator
                // the project uses. The pointee's destructor is not run here.
                Allocator del;
                del.deallocate(ptr, sizeof(value_type));
                this->ptr = nullptr;
            }
        };

        /// Creates a Core owning `ptr` with count 1, or returns nullptr for a null `ptr`.
        static Core* make_core(value_type* ptr) {
            if (ptr == nullptr) {
                return nullptr;
            }
            Core* fresh = new Core();
            fresh->ptr = ptr;
            fresh->count = 1;
            return fresh;
        }

        /// Drops this instance's share, deleting the Core when it was the last
        /// owner, and leaves this SmartPointer empty.
        void release() noexcept {
            if (core != nullptr) {
                if (--core->count == 0) {
                    delete core;
                }
                core = nullptr;
            }
        }

        Core* core;
    };
} // namespace smart_pointer