#pragma once

#include <cstddef>
#include <exception>
#include <memory>
#include <type_traits>
#include "Test.hpp"

namespace smart_pointer {
    /// Error thrown by SmartPointer when an empty pointer is dereferenced.
    ///
    /// The inheritance from std::exception must be *public*: a `class`
    /// defaults to private inheritance, and a privately-derived exception is
    /// not matched by generic `catch (const std::exception &)` handlers.
    class exception : public std::exception {
        using base_class = std::exception;
     public:
        // Re-expose std::exception's constructors.
        using base_class::base_class;
    };

    /// Reference-counted owning pointer to a single object of type `T`.
    ///
    /// All copies of a SmartPointer refer to one shared control block
    /// (`Core`) holding the raw pointer and the number of owners. When the
    /// last owner is destroyed or reassigned, the control block is freed and
    /// the managed object is destroyed with `delete`.
    ///
    /// NOTE(review): the `Allocator` parameter is not used anywhere; managed
    /// objects are assumed to come from plain `new` (not `new[]`) — confirm.
    /// The owner count is a plain `std::size_t` (no atomics, no locking), so
    /// copies sharing one object must not be copied/destroyed concurrently.
    template<typename T, typename Allocator>
    class SmartPointer {
        ENABLE_CLASS_TESTS;
     public:
        using value_type = T;

        /// Takes ownership of `value`; nullptr yields an empty pointer.
        /// If allocating the control block throws, `value` is deleted before
        /// the exception propagates, so ownership is never silently leaked.
        SmartPointer(value_type *value = nullptr) {
            if (value == nullptr)
                return;  // `core` stays nullptr (default member initializer)

            try {
                core = new Core(value);
            } catch (...) {
                delete value;
                throw;
            }
        }

        /// Copy constructor: shares ownership with `rhs`.
        SmartPointer(const SmartPointer &rhs) : core(rhs.core) {
            if (core != nullptr)
                core->Increment();
        }

        /// Move constructor: takes over `rhs`'s share; `rhs` becomes empty.
        SmartPointer(SmartPointer &&rhs) noexcept : core(rhs.core) {
            rhs.core = nullptr;
        }

        /// Copy assignment: drops the current share (deleting the object if
        /// this was its last owner), then shares ownership with `rhs`.
        SmartPointer &operator=(const SmartPointer &rhs) {
            if (core == rhs.core)  // self-assignment, or already sharing
                return *this;

            // Acquire the new share before releasing the old one: releasing
            // may destroy an object that (directly or indirectly) owns `rhs`.
            Core *incoming = rhs.core;
            if (incoming != nullptr)
                incoming->Increment();

            Release();
            core = incoming;
            return *this;
        }

        /// Move assignment: drops the current share, then takes over `rhs`'s
        /// share; `rhs` becomes empty. Self-move is a no-op.
        SmartPointer &operator=(SmartPointer &&rhs) noexcept {
            if (this != &rhs) {
                Core *incoming = rhs.core;
                rhs.core = nullptr;
                Release();
                core = incoming;
            }
            return *this;
        }

        /// Drops the current share and takes ownership of `rhs`
        /// (nullptr leaves this pointer empty).
        SmartPointer &operator=(value_type *rhs) {
            *this = SmartPointer(rhs);
            return *this;
        }

        /// Drops this pointer's share; deletes the object if it was the last.
        ~SmartPointer() {
            Release();
        }

        /// Returns a reference to the managed object.
        /// @throws smart_pointer::exception if this pointer is empty.
        value_type &operator*() {
            if (!*this)
                throw smart_pointer::exception();

            return *core->Value();
        }

        /// Const overload of operator*; same contract.
        const value_type &operator*() const {
            if (!*this)
                throw smart_pointer::exception();

            return *core->Value();
        }

        /// Returns the managed pointer, or nullptr when empty (does not throw,
        /// unlike operator*).
        value_type *operator->() const {
            return get();
        }

        /// Returns the managed pointer, or nullptr when empty.
        value_type *get() const {
            return core != nullptr ? core->Value() : nullptr;
        }

        /// True iff a (non-null) object is managed.
        /// NOTE(review): kept implicit for backward compatibility; consider
        /// making it `explicit` to prevent accidental arithmetic conversions.
        operator bool() const {
            return core != nullptr;
        }

        /// True if both pointers are empty, or both have the same pointee type
        /// and point to the same address.
        template<typename U, typename AnotherAllocator>
        bool operator==(const SmartPointer<U, AnotherAllocator> &rhs) const {
            if (!*this && !rhs)
                return true;

            if (!std::is_same<T, U>::value)
                return false;

            // T == U whenever this line matters; comparing through `const void *`
            // merely keeps the expression well-formed for every T/U pair.
            return static_cast<const void *>(get()) == static_cast<const void *>(rhs.get());
        }

        /// Exact negation of operator==.
        template<typename U, typename AnotherAllocator>
        bool operator!=(const SmartPointer<U, AnotherAllocator> &rhs) const {
            return !(*this == rhs);
        }

        /// Number of SmartPointers sharing the managed object; 0 when empty.
        std::size_t count_owners() const {
            return core != nullptr ? core->CountOwners() : 0;
        }

     private:
        /// Control block shared by all owners of one object: the raw pointer
        /// plus the owner count. Owns (and deletes) the managed object.
        class Core {
         public:
            /// Creates a block with its first owner already counted.
            explicit Core(value_type *value) : value_(value) {}

            ~Core() {
                // Same guard std::default_delete uses: deleting an incomplete
                // type would silently skip its destructor.
                static_assert(sizeof(value_type) > 0, "cannot delete an incomplete type");
                delete value_;
            }

            // A copy would lead to a double delete of value_.
            Core(const Core &) = delete;
            Core &operator=(const Core &) = delete;

            value_type *Value() const {
                return value_;
            }

            void Increment() {
                ++count_owners_;
            }

            void Decrement() {
                --count_owners_;
            }

            std::size_t CountOwners() const {
                return count_owners_;
            }

         private:
            value_type *value_;
            std::size_t count_owners_ = 1;
        };

        // Drops this pointer's share; when it was the last one, deletes the
        // control block (and thereby the object). `core` is cleared first so
        // the object's destructor never observes a dangling control block.
        void Release() noexcept {
            Core *old = core;
            core = nullptr;

            if (old == nullptr)
                return;

            old->Decrement();
            if (old->CountOwners() == 0)
                delete old;
        }

        Core *core = nullptr;
    };
}