#pragma once

#include <memory>
#include "Test.hpp"

namespace smart_pointer {
    // `exception` class definition
    //
    // Error type thrown by SmartPointer (e.g. when an empty pointer is
    // dereferenced). It must inherit *publicly* from `std::exception`:
    // with the default private inheritance of `class`, a handler written as
    // `catch (const std::exception &)` would not match and the exception
    // would escape it.
    class exception : public std::exception {
        using base_class = std::exception;

    public:
        using base_class::base_class;
    };

    // `SmartPointer` class declaration
    //
    // Reference-counted owning pointer. All SmartPointers that share the same
    // `Core` co-own the pointee; the last one to release its share deletes the
    // pointee (with `delete`) and the `Core` itself. An empty SmartPointer
    // holds no `Core` at all (`core == nullptr`).
    // NOTE: the `Allocator` template parameter is accepted but not used by
    // this implementation.
    template<
        typename T,
        typename Allocator
    >
    class SmartPointer {
        // don't remove this macro
        ENABLE_CLASS_TESTS;

    public:

        // Take ownership of `p`; a null `p` yields an empty pointer.
        // If allocating the control block throws, `p` is deleted before the
        // exception propagates, so the adopted object never leaks.
        explicit SmartPointer(T *p = nullptr) : core(make_core(p)) {}

        // copy constructor: share ownership of `p`'s object (if any).
        SmartPointer(const SmartPointer &p) : core(p.core) {
            if (core != nullptr)
                core->pCount += 1;
        }

        // move constructor: take over `p`'s share, leaving `p` empty.
        SmartPointer(SmartPointer &&p) noexcept : core(p.core) {
            p.core = nullptr;
        }

        // Release this pointer's share; the last owner deletes the object.
        ~SmartPointer() {
            release();
        }

        // move assigment: drop the current share and take over `p`'s,
        // leaving `p` empty. Self-move leaves the pointer unchanged.
        SmartPointer &operator=(SmartPointer &&p) noexcept {
            if (this != &p) {
                release();
                core = p.core;
                p.core = nullptr;
            }
            return *this;
        }

        // copy assigment: share `p`'s object.
        SmartPointer &operator=(const SmartPointer &p) {
            // Bump the incoming count *before* releasing the current share so
            // that self-assignment (or assignment from a co-owner) can never
            // destroy the object that is about to be shared.
            Core *incoming = p.core;
            if (incoming != nullptr)
                incoming->pCount += 1;
            release();
            core = incoming;
            return *this;
        }

        // Replace the managed object with a freshly owned `p` (nullptr empties
        // the pointer). Only this pointer's share of the previous object is
        // released; other owners keep pointing at the previous object.
        SmartPointer &operator=(T *p) {
            // Re-adopting the pointer already managed here would delete it and
            // then keep a dangling pointer, so treat it as a no-op.
            if (p != nullptr && p == get())
                return *this;
            // Allocate first: if it throws, `p` is deleted by make_core and
            // *this is left untouched.
            Core *fresh = make_core(p);
            release();
            core = fresh;
            return *this;
        }

        // return reference to the object of class/type T
        // if SmartPointer contains nullptr throw `smart_pointer::exception`
        T &operator*() {
            if (get() == nullptr)
                throw exception();
            return *core->myObj;
        }

        const T &operator*() const {
            if (get() == nullptr)
                throw exception();
            return *core->myObj;
        }

        // return pointer to the object of class/type T (nullptr if empty)
        T *operator->() const noexcept {
            return get();
        }

        // return the managed raw pointer, or nullptr if empty
        T *get() const noexcept {
            return core ? core->myObj : nullptr;
        }

        // if pointer == nullptr => return false
        operator bool() const noexcept {
            return core != nullptr;
        }

        // true if both pointers point to the same address or both are null.
        // Only addresses are compared: the pointees are never dereferenced,
        // so T and U need not be equality-comparable.
        template<typename U, typename AnotherAllocator>
        bool operator==(const SmartPointer<U, AnotherAllocator> &p) const {
            return static_cast<const void *>(p.get()) ==
                   static_cast<const void *>(get());
        }

        // negation of operator==: false if both point to the same address
        // or both are null.
        template<typename U, typename AnotherAllocator>
        bool operator!=(const SmartPointer<U, AnotherAllocator> &p) const {
            return !(*this == p);
        }

        // if smart pointer contains non-nullptr => return count owners
        // if smart pointer contains nullptr => return 0
        [[nodiscard]] std::size_t count_owners() const noexcept {
            return core != nullptr ? core->pCount : 0;
        }

    private:

        // Shared control block: the owned object and its owner count.
        class Core {
        public:
            T *myObj;
            std::size_t pCount;
        };

        // Allocate a control block owning `p` with one owner, or return
        // nullptr for a null `p`. If the allocation throws, `p` is deleted
        // before rethrowing because ownership was already handed to us.
        static Core *make_core(T *p) {
            if (p == nullptr)
                return nullptr;
            try {
                return new Core{p, 1};
            } catch (...) {
                delete p;
                throw;
            }
        }

        // Give up this pointer's share: decrement the owner count, or, if this
        // was the last owner, delete the object and the control block.
        // Always leaves `core` null.
        void release() noexcept {
            if (core == nullptr)
                return;
            if (core->pCount > 1) {
                core->pCount -= 1;
            } else {
                delete core->myObj;
                delete core;
            }
            core = nullptr;
        }

        Core *core = nullptr;
    };
}  // namespace smart_pointer