#pragma once

#include <cstddef>
#include <exception>
#include <memory>
#include <utility>

#include "Test.hpp"

namespace smart_pointer {
/// Error type thrown by SmartPointer (e.g. when an empty pointer is
/// dereferenced).
///
/// The inheritance is explicitly public: `class` defaults to private
/// inheritance, which would hide std::exception and make handlers written as
/// `catch (const std::exception &)` miss this type entirely.
class exception : public std::exception {
 public:
    using base_class = std::exception;
    using base_class::base_class;
};

/// Reference-counted owning pointer, similar in spirit to std::shared_ptr.
///
/// All SmartPointer instances sharing one object point at the same
/// heap-allocated Core, which stores the raw object pointer and the number of
/// owners. When the last owner goes away it deletes the object (with plain
/// `delete`, so the object must have been created with `new`) and the Core.
/// An empty SmartPointer has `core == nullptr`; a non-empty one never holds a
/// Core whose `myObj` is null.
///
/// @tparam T          pointee type.
/// @tparam Allocator  accepted for interface compatibility; this
///                    implementation does not use it.
///
/// NOTE(review): the owner count is a plain std::size_t, so sharing one object
/// across threads requires external synchronization.
template<typename T, typename Allocator>
class SmartPointer {
    // Project test hook from Test.hpp (presumably grants the test harness
    // access to private members — TODO confirm).
    ENABLE_CLASS_TESTS;

 public:
    using value_type = T;

    /// Takes ownership of `p`; a null `p` yields an empty pointer.
    /// @throws std::bad_alloc if the control block cannot be allocated. In that
    ///         case `p` is deleted, so the caller's object does not leak.
    explicit SmartPointer(value_type *p = nullptr) {
        if (p == nullptr) {
            return;  // empty pointer: no control block at all
        }
        try {
            core = new Core{p, 1};
        } catch (...) {
            delete p;
            throw;
        }
    }

    /// Drops this owner; destroys the object and its Core if it was the last.
    ~SmartPointer() {
        if (core != nullptr && --core->pCount == 0) {
            delete core->myObj;
            delete core;
        }
    }

    /// Takes over `p`'s ownership without touching the count; `p` becomes empty.
    SmartPointer(SmartPointer &&p) noexcept : core(p.core) {
        p.core = nullptr;
    }

    /// Becomes an additional owner of `p`'s object (if any).
    SmartPointer(const SmartPointer &p) noexcept : core(p.core) {
        if (core != nullptr) {
            ++core->pCount;
        }
    }

    /// Move assignment; self-move leaves `*this` unchanged.
    SmartPointer &operator=(SmartPointer &&p) noexcept {
        // Steal p's state into a temporary, then swap. The temporary's
        // destructor releases our previous object only after the new state
        // is in place, which also makes self-move harmless.
        SmartPointer tmp(std::move(p));
        std::swap(core, tmp.core);
        return *this;
    }

    /// Copy assignment; self-assignment leaves `*this` unchanged.
    SmartPointer &operator=(const SmartPointer &p) {
        // Copy-and-swap: taking the new reference before releasing the old
        // one keeps `a = a` (count 1) from deleting the shared object.
        SmartPointer tmp(p);
        std::swap(core, tmp.core);
        return *this;
    }

    /// Replaces the managed object with `p`; a null `p` empties the pointer.
    /// Assigning the pointer that is already managed is a no-op, so that
    /// object is neither deleted nor given a second, independent count.
    /// @throws std::bad_alloc as the constructor does; `*this` is then unchanged.
    SmartPointer &operator=(value_type *p) {
        if (p == get()) {
            return *this;
        }
        SmartPointer tmp(p);
        std::swap(core, tmp.core);
        return *this;
    }

    /// @throws exception if the pointer is empty.
    value_type &operator*() {
        value_type *const obj = get();
        if (obj == nullptr) {
            throw exception();
        }
        return *obj;
    }

    /// @throws exception if the pointer is empty.
    const value_type &operator*() const {
        const value_type *const obj = get();
        if (obj == nullptr) {
            throw exception();
        }
        return *obj;
    }

    /// Raw access; returns nullptr for an empty pointer (no check, no throw).
    value_type *operator->() const {
        return get();
    }

    /// The managed object, or nullptr when empty.
    value_type *get() const {
        return core ? core->myObj : nullptr;
    }

    /// True when an object is managed. Left non-explicit for compatibility
    /// with existing callers.
    operator bool() const {
        return core != nullptr;
    }

    /// Equal when both refer to the same object (by address) or both are empty.
    /// Only addresses are compared: identical addresses already mean the same
    /// object. Also comparing `*p.get() == *get()` would make a pointer to NaN
    /// unequal to itself, and it would require U and T to support `==`.
    template<typename U, typename AnotherAllocator>
    bool operator==(const SmartPointer<U, AnotherAllocator> &p) const {
        return static_cast<const void *>(p.get()) ==
               static_cast<const void *>(get());
    }

    template<typename U, typename AnotherAllocator>
    bool operator!=(const SmartPointer<U, AnotherAllocator> &p) const {
        return !(*this == p);
    }

    /// Number of SmartPointers sharing the object; 0 when empty.
    [[nodiscard]] std::size_t count_owners() const {
        return core != nullptr ? core->pCount : 0;
    }

 private:
    /// Shared control block: the owned object and how many SmartPointers own it.
    class Core {
     public:
        value_type *myObj = nullptr;
        std::size_t pCount = 0;
    };

    Core *core = nullptr;
};
}  // namespace smart_pointer