#pragma once

#include <cstddef>
#include <exception>
#include <memory>

#include "Test.hpp"

namespace smart_pointer {
/// Error thrown by SmartPointer when an empty pointer is dereferenced.
///
/// The base is inherited publicly so the error can be handled through a
/// generic `catch (const std::exception&)` and `what()` is callable. With
/// the default (private) inheritance of `class`, such handlers never matched.
class exception : public std::exception {
 public:
    using base_class = std::exception;
    using base_class::base_class;
};

/// Reference-counted owning pointer.
///
/// All copies of a SmartPointer share one heap-allocated control block
/// (`Core`) that holds the managed pointer and the number of owners. When the
/// last owner is destroyed or reassigned, both the managed object and the
/// control block are deleted. An empty SmartPointer has no control block
/// (`core == nullptr`) and reports zero owners.
///
/// NOTE(review): `Allocator` is currently unused. Managed objects are released
/// with plain `delete`, so pointers handed to this class must come from `new`.
/// The owner count is a plain (non-atomic) integer, so sharing one object
/// across threads requires external synchronization.
template <typename T, typename Allocator>
class SmartPointer {
    ENABLE_CLASS_TESTS;

 public:
    using value_type = T;

    /// Takes ownership of `ptr_`; a null pointer produces an empty
    /// SmartPointer. If allocating the control block throws, `ptr_` is
    /// deleted before the exception propagates, so it never leaks.
    SmartPointer(value_type* ptr_ = nullptr) : core(make_core(ptr_)) {}

    /// Shares ownership of `other`'s object (no-op for an empty `other`).
    SmartPointer(const SmartPointer& other) noexcept : core(other.core) {
        if (core != nullptr) ++core->owners;
    }

    /// Takes over `other`'s ownership; `other` becomes empty.
    SmartPointer(SmartPointer&& other) noexcept : core(other.core) {
        other.core = nullptr;
    }

    /// Drops the current ownership and shares `other`'s object.
    SmartPointer& operator=(const SmartPointer& other) noexcept {
        // Acquire before releasing: this makes self-assignment safe and keeps
        // `other`'s object alive even if `other` lives inside the object
        // being released.
        Core* incoming = other.core;
        if (incoming != nullptr) ++incoming->owners;
        release();
        core = incoming;
        return *this;
    }

    /// Drops the current ownership and takes over `other`'s; `other`
    /// becomes empty. Self-move leaves the pointer unchanged.
    SmartPointer& operator=(SmartPointer&& other) noexcept {
        // Detach `other` first: on self-move this empties *this as well, so
        // release() does nothing and the ownership is restored below.
        Core* incoming = other.core;
        other.core = nullptr;
        release();
        core = incoming;
        return *this;
    }

    /// Drops the current ownership and takes ownership of `other`
    /// (nullptr makes this SmartPointer empty). Re-assigning the pointer
    /// that is already managed is a no-op instead of a double ownership.
    SmartPointer& operator=(value_type* other) {
        if (other == get()) return *this;
        Core* incoming = make_core(other);  // may throw; *this is untouched then
        release();
        core = incoming;
        return *this;
    }

    /// Gives up this owner's share; the last owner deletes the object.
    ~SmartPointer() { release(); }

    /// Dereferences the managed object.
    /// @throws smart_pointer::exception if the pointer is empty.
    value_type& operator*() {
        if (core == nullptr) throw smart_pointer::exception();
        return *core->ptr;
    }

    /// Const dereference of the managed object.
    /// @throws smart_pointer::exception if the pointer is empty.
    const value_type& operator*() const {
        if (core == nullptr) throw smart_pointer::exception();
        return *core->ptr;
    }

    /// Member access; returns nullptr (does not throw) when empty.
    value_type* operator->() const { return get(); }

    /// Returns the managed pointer, or nullptr when empty.
    value_type* get() const {
        return core != nullptr ? core->ptr : nullptr;
    }

    /// True when an object is managed.
    operator bool() const { return core != nullptr; }

    /// True when both sides manage the same address (the pointee types may
    /// differ); two empty pointers compare equal.
    template <typename U, typename AnotherAllocator>
    bool operator==(const SmartPointer<U, AnotherAllocator>& other) const {
        return static_cast<const void*>(get()) ==
               static_cast<const void*>(other.get());
    }

    /// Negation of operator==.
    template <typename U, typename AnotherAllocator>
    bool operator!=(const SmartPointer<U, AnotherAllocator>& other) const {
        return !(*this == other);
    }

    /// Number of SmartPointers sharing the managed object; 0 when empty.
    std::size_t count_owners() const {
        return core != nullptr ? core->owners : 0;
    }

 private:
    /// Control block shared by all owners of one object; owns `ptr`.
    class Core {
     public:
        explicit Core(value_type* ptr_) : ptr(ptr_) {}
        ~Core() { delete ptr; }
        Core(const Core&) = delete;
        Core& operator=(const Core&) = delete;

        std::size_t owners = 1;  // the creating SmartPointer is the first owner
        value_type* ptr;
    };

    /// Allocates a control block for `ptr`, or returns nullptr for nullptr.
    static Core* make_core(value_type* ptr) {
        if (ptr == nullptr) return nullptr;
        try {
            return new Core(ptr);
        } catch (...) {
            delete ptr;  // ownership was handed to us; do not leak it
            throw;
        }
    }

    /// Gives up this owner's share and leaves *this empty; deletes the
    /// control block (and with it the object) when no owners remain.
    void release() noexcept {
        if (core != nullptr && --core->owners == 0) delete core;
        core = nullptr;
    }

    Core* core = nullptr;
};
}  // namespace smart_pointer