from __future__ import annotations

import datetime
import logging
import os
import pickle
import platform
import random
import string
import tempfile
from collections.abc import Iterator
from os import makedirs, path
from threading import Lock
from typing import Generic, NoReturn, TypeVar

from filelock import FileLock, Timeout
from typing_extensions import NamedTuple, Self

from . import constants
from .constants import ENV_VAR_TEST_MODE

# Module-level aliases for the clock and the file-mtime lookup used in this file.
now = datetime.datetime.now
getmtime = os.path.getmtime

# Type of the payload wrapped by CacheEntry.
T = TypeVar("T")

logger = logging.getLogger(__name__)

# True when the environment variable named by ENV_VAR_TEST_MODE is set to "true"
# (case-insensitive). It enables the "lock must be held" assertions in the
# non-locking cache helpers.
test_mode = os.getenv(ENV_VAR_TEST_MODE, "").lower() == "true"


class CacheEntry(NamedTuple, Generic[T]):
    """A cached value paired with the moment it stops being valid."""

    # Point in time at (or after) which the entry counts as expired.
    expiry: datetime.datetime
    # The cached payload itself.
    entry: T


# Key and value types of the cache classes defined below.
K = TypeVar("K")
V = TypeVar("V")


def is_expired(d: datetime.datetime) -> bool:
    """Tell whether the expiry timestamp ``d`` has been reached or passed."""
    current_time = now()
    return current_time >= d


class SFDictCache(Generic[K, V]):
    """A generic in-memory cache that acts somewhat like a dictionary.

    Every value is stored together with an expiry timestamp. Expired entries are
    dropped lazily, when they are looked up or when the cache is swept.

    Unlike normal dictionaries, whose keys(), values() and items() return view
    objects, these methods return lists materialized at call time.

    Public methods acquire self._lock; the underscore-prefixed counterparts
    expect the caller to already hold it.
    """

    def __init__(
        self,
        entry_lifetime: int = constants.DAY_IN_SECONDS,
    ) -> None:
        """Inits a SFDictCache.

        Args:
            entry_lifetime: Number of seconds an entry stays valid after it
                has been inserted.
        """
        self._entry_lifetime = datetime.timedelta(seconds=entry_lifetime)
        self._cache: dict[K, CacheEntry[V]] = {}
        self._lock = Lock()
        self._reset_telemetry()

    def __len__(self) -> int:
        """Return the number of stored entries, expired ones included."""
        with self._lock:
            return len(self._cache)

    @classmethod
    def from_dict(
        cls,
        _dict: dict[K, V],
        **kw,
    ) -> Self:
        """Create a dictionary cache from an already existing dictionary.

        Note that the cache stores the same object references as the
        dictionary provided. Keyword arguments are forwarded to the
        constructor.
        """
        cache = cls(**kw)
        cache.update(_dict)
        return cache

    def _getitem(
        self,
        k: K,
        *,
        should_record_hits: bool = True,
    ) -> V:
        """Non-locking version of __getitem__.

        This should only be used by internal functions when already
        holding self._lock.

        Raises:
            KeyError: if ``k`` is not cached, or if its entry has expired (in
                which case the entry is also removed).
        """
        if test_mode:
            assert (
                self._lock.locked()
            ), "The mutex self._lock should be locked by this thread"
        try:
            t, v = self._cache[k]
        except KeyError:
            self._miss(k)
            raise
        if is_expired(t):
            # Raises KeyError after deleting the entry.
            self._expire(k)
        if should_record_hits:
            self._hit(k)
        return v

    # aliasing _getitem to unify the api with SFDictFileCache
    _getitem_non_locking = _getitem

    def _setitem(self, k: K, v: V) -> None:
        """Non-locking version of __setitem__.

        This should only be used by internal functions when already
        holding self._lock.
        """
        if test_mode:
            assert (
                self._lock.locked()
            ), "The mutex self._lock should be locked by this thread"
        self._cache[k] = CacheEntry(
            expiry=now() + self._entry_lifetime,
            entry=v,
        )
        self._add_or_remove()

    def __getitem__(
        self,
        k: K,
    ) -> V:
        """Returns an element if it hasn't expired yet in a thread-safe way."""
        with self._lock:
            return self._getitem(k, should_record_hits=True)

    def __setitem__(
        self,
        k: K,
        v: V,
    ) -> None:
        """Inserts an element in a thread-safe way."""
        with self._lock:
            self._setitem(k, v)

    def __iter__(self) -> Iterator[K]:
        """Iterate over a snapshot of the unexpired keys."""
        return iter(self.keys())

    def keys(self) -> list[K]:
        """Return the unexpired keys as a list."""
        return [k for k, _ in self.items()]

    def items(self) -> list[tuple[K, V]]:
        """Return the unexpired (key, value) pairs as a list.

        Expired entries met along the way are removed. No hits are recorded.
        """
        with self._lock:
            values: list[tuple[K, V]] = []
            # Iterate over a copy: _getitem deletes expired entries.
            for k in list(self._cache):
                try:
                    values.append((k, self._getitem(k, should_record_hits=False)))
                except KeyError:
                    # The entry expired and has just been dropped.
                    pass
        return values

    def values(self) -> list[V]:
        """Return the unexpired values as a list."""
        return [v for _, v in self.items()]

    def get(
        self,
        k: K,
        default: V | None = None,
    ) -> V | None:
        """Return the value cached for ``k``, or ``default`` if missing/expired."""
        try:
            return self[k]
        except KeyError:
            return default

    def clear(self) -> None:
        """Remove every entry and reset the telemetry counters."""
        with self._lock:
            self._cache.clear()
            self._reset_telemetry()

    def _delitem(
        self,
        key: K,
    ) -> None:
        """Non-locking version of __delitem__.

        This should only be used by internal functions when already
        holding self._lock.
        """
        if test_mode:
            assert (
                self._lock.locked()
            ), "The mutex self._lock should be locked by this thread"
        del self._cache[key]
        self._add_or_remove()

    def __delitem__(
        self,
        key: K,
    ) -> None:
        """Delete an entry in a thread-safe way; KeyError if it is absent."""
        with self._lock:
            self._delitem(key)

    def __contains__(
        self,
        key: K,
    ) -> bool:
        """Tell whether ``key`` is cached and unexpired (counts as a hit/miss)."""
        with self._lock:
            try:
                self._getitem(key, should_record_hits=True)
                return True
            except KeyError:
                # Missing or expired
                return False

    def _update(
        self,
        other: dict[K, V] | list[tuple[K, V]] | SFDictCache[K, V],
        update_newer_only: bool = False,
    ) -> bool:
        """Non-locking version of update.

        This should only be used by internal functions when already
        holding self._lock. other._lock must NOT be held by the calling
        thread, because other.clear_expired_entries() acquires it.

        Returns:
            Whether at least one entry was inserted or overwritten.

        Raises:
            TypeError: if ``other`` is not a dict, a list or an SFDictCache.
        """
        if test_mode:
            assert (
                self._lock.locked()
            ), "The mutex self._lock should be locked by this thread"
        to_insert: dict[K, CacheEntry[V]]
        self._clear_expired_entries()
        if other is self:
            # Merging a cache into itself cannot teach it anything, and
            # other.clear_expired_entries() below would block forever trying
            # to re-acquire the non-reentrant self._lock we already hold.
            return False
        if isinstance(other, (list, dict)):
            # Plain containers carry no timestamps: stamp everything alike.
            expiry = now() + self._entry_lifetime
            pairs = other.items() if isinstance(other, dict) else other
            to_insert = {k: CacheEntry(expiry=expiry, entry=v) for k, v in pairs}
        elif isinstance(other, SFDictCache):
            other.clear_expired_entries()
            others_items = list(other._cache.items())
            # Only accept values from another cache if their key is not in self,
            #  or if expiry is later than the one self knows about
            to_insert = {
                k: v
                for k, v in others_items
                if (
                    # self doesn't have this key
                    k not in self._cache
                    # we should update entries, regardless of whether they are newer
                    or (not update_newer_only)
                    # other has a newer expiry time and we only want newer values
                    or self._cache[k].expiry < v.expiry
                )
            }
        else:
            raise TypeError(
                "other must be a dict, a list of (key, value) tuples or an "
                f"SFDictCache, not {type(other).__name__}"
            )
        self._cache.update(to_insert)
        if to_insert:
            self._add_or_remove()
        # TODO: this should really save_if_should
        return len(to_insert) > 0

    def update(
        self,
        other: dict[K, V] | list[tuple[K, V]] | SFDictCache[K, V],
        update_newer_only: bool = False,
    ) -> bool:
        """Insert multiple values at the same time, if self could learn from the other.

        If this function is given a dictionary or a list, all expiration
        timestamps will be the same: self._entry_lifetime from now. If it's
        given another SFDictCache then the timestamps will be taken
        from the other cache.

        Returns a boolean. It describes whether self learnt anything from other.

        Note that expired entries are cleared on both caches. other's lock is
        only taken while its expired entries are cleared; its entries are
        then read without holding other._lock. The intended behavior is to
        use this function with an unpickled/unused cache. If live caches are
        being merged then use .items() on them first and merge those into
        the other caches.
        """
        with self._lock:
            return self._update(other, update_newer_only)

    def update_newer(
        self,
        other: dict[K, V] | list[tuple[K, V]] | SFDictCache[K, V],
    ) -> bool:
        """This function is like update, but it only updates newer elements."""
        with self._lock:
            return self._update(
                other,
                update_newer_only=True,
            )

    def _clear_expired_entries(self) -> None:
        """Non-locking version of clear_expired_entries.

        This should only be used by internal functions when already
        holding self._lock.
        """
        if test_mode:
            assert (
                self._lock.locked()
            ), "The mutex self._lock should be locked by this thread"
        cache_updated = False
        # Iterate over a copy: _getitem deletes expired entries.
        for k in list(self._cache):
            try:
                self._getitem(k, should_record_hits=False)
            except KeyError:
                # the only case KeyError raised in this method
                # is that k is expired
                cache_updated = True
        if cache_updated:
            self._add_or_remove()

    def clear_expired_entries(self) -> None:
        """Remove expired entries from the cache."""
        with self._lock:
            self._clear_expired_entries()

    # Telemetry related functions, these can be plugged by child classes
    def _reset_telemetry(self) -> None:
        """(Re)set telemetry fields.

        This function will be called by the initializer and other functions that
        should reset telemetry entries.
        """
        self.telemetry = {
            "hit": 0,
            "miss": 0,
            "expiration": 0,
            "size": 0,
        }

    def _hit(self, k: K) -> None:
        """This function gets called when a hit occurs.

        Functions that touch every entry (like values) do not count as hits.

        Note that this function does not interact with the lock itself; it is
        only called from contexts where the lock is already held.
        """
        self.telemetry["hit"] += 1

    def _miss(self, k: K) -> None:
        """This function gets called when a miss occurs.

        Note that this function does not interact with the lock itself; it is
        only called from contexts where the lock is already held.
        """
        self.telemetry["miss"] += 1

    def _expiration(self, k: K) -> None:
        """This function gets called when an expiration occurs.

        Note that this function does not interact with the lock itself; it is
        only called from contexts where the lock is already held.
        """
        self.telemetry["expiration"] += 1

    def _expire(self, k: K) -> NoReturn:
        """Record the expiration of ``k``, delete it and raise KeyError(k)."""
        self._expiration(k)
        self._delitem(k)
        # Carry the key, like the KeyError raised for a plain miss does.
        raise KeyError(k)

    def _add_or_remove(self) -> None:
        """This function gets called when an element is added, or removed.

        Note that this function does not interact with the lock itself; it is
        only called from contexts where the lock is already held.
        """
        self.telemetry["size"] = len(self._cache)


class SFDictFileCache(SFDictCache):
    """An SFDictCache whose content is additionally persisted to a file.

    The whole cache object is pickled to ``file_path``. Access to that file is
    guarded by a FileLock on ``<file_path>.lock``. After an insertion the cache
    is written out with a probability of 1 / (MAX_RAND_INT + 1), and lookups
    that miss (or hit an expired entry) reload the file when it has changed
    on disk since it was last read.
    """

    # This number decides the chance of saving after writing
    # (probability: 1 / (MAX_RAND_INT + 1))
    MAX_RAND_INT = 9
    # Only these instance attributes are pickled; locks are recreated on load.
    _ATTRIBUTES_TO_PICKLE = (
        "_entry_lifetime",
        "_cache",
        "telemetry",
        "file_path",
        "file_timeout",
        "_file_lock_path",
        "last_loaded",
    )

    def __init__(
        self,
        file_path: str | dict[str, str],
        entry_lifetime: int = constants.DAY_IN_SECONDS,
        file_timeout: int = 0,
        load_if_file_exists: bool = True,
    ) -> None:
        """Inits an SFDictFileCache with path, lifetime.

        File path can be a dictionary that contains different paths for different OSes,
        possible keys are: 'darwin', 'linux' and 'windows'. If the current platform
        cannot be determined, or is not in the dictionary we'll use the first value.
        A leading ``~`` in the selected path is expanded in every case.

        Once we select a location based on file path, we write and read a random
        temporary file to check for read/write permissions. If this fails OSError
        (PermissionError for permission problems) is raised.

        Args:
            file_path: Location of the cache file, or a per-platform mapping.
            entry_lifetime: Number of seconds an entry stays valid.
            file_timeout: Timeout handed to the FileLock guarding the file.
            load_if_file_exists: Whether to load an already existing cache file.
        """
        super().__init__(
            entry_lifetime=entry_lifetime,
        )
        if isinstance(file_path, str):
            selected_path = file_path
        else:
            # platform.system() returns "" when the platform is unknown, which
            # falls through to the first configured path as well.
            current_platform = platform.system().lower()
            if current_platform in file_path:
                selected_path = file_path[current_platform]
            else:
                selected_path = next(iter(file_path.values()))
        self.file_path = os.path.expanduser(selected_path)
        # Once we decided on where to put the file cache make sure that this
        #  place is readable/writable by us
        random_string = "".join(random.choice(string.ascii_letters) for _ in range(5))
        cache_folder = os.path.dirname(self.file_path)
        if not path.exists(cache_folder):
            try:
                makedirs(cache_folder, mode=0o700)
            except Exception as ex:
                # Not fatal here: the mkstemp probe below reports the problem.
                logger.debug(
                    "cannot create a cache directory: [%s], err=[%s]",
                    cache_folder,
                    ex,
                )

        try:
            tmp_file, tmp_file_path = tempfile.mkstemp(
                dir=cache_folder,
            )
        except OSError as o_err:
            raise PermissionError(
                o_err.errno,
                "Cache folder is not writeable",
                cache_folder,
            ) from o_err
        try:
            # open() takes ownership of the descriptor returned by mkstemp and
            # closes it when the with-block exits.
            with open(tmp_file, "w") as w_file:
                # If mkstemp didn't fail this shouldn't throw an error
                w_file.write(random_string)
            try:
                with open(tmp_file_path) as r_file:
                    content = r_file.read()
            except OSError as o_err:
                raise PermissionError(
                    o_err.errno,
                    "Cache file is not readable",
                    tmp_file_path,
                ) from o_err
            if content != random_string:
                raise OSError("Temporary file just written has wrong content")
        finally:
            if os.path.isfile(tmp_file_path):
                os.unlink(tmp_file_path)
        self.file_timeout = file_timeout
        self._file_lock_path = f"{self.file_path}.lock"
        self._file_lock = FileLock(self._file_lock_path, timeout=self.file_timeout)
        self.last_loaded: datetime.datetime | None = None
        if os.path.exists(self.file_path) and load_if_file_exists:
            with self._lock:
                self._load()
        # indicate whether the cache is modified or not, this variable is for
        # SFDictFileCache to determine whether to dump cache to file when _save is called
        self._cache_modified = False

    def _getitem_non_locking(
        self,
        k: K,
        *,
        should_record_hits: bool = True,
    ) -> V:
        """Non-locking version of __getitem__ of SFDictFileCache.

        This should only be used by internal functions when already
        holding self._lock.

        Before giving up on a missing or expired key the cache file is
        reloaded if it has changed on disk.

        Note that we do not overwrite _getitem because _getitem is used by
        self._load to clear in-memory expired caches. Overwriting would cause
        infinite recursive call.

        Raises:
            KeyError: if ``k`` is neither cached nor in a reloaded file, or
                if its entry has expired.
        """
        if k not in self._cache:
            loaded = self._load_if_should()
            if (not loaded) or k not in self._cache:
                self._miss(k)
                raise KeyError(k)
        t, v = self._cache[k]
        if is_expired(t):
            loaded = self._load_if_should()
            expire_item = True
            if loaded:
                # Loading purges expired in-memory entries, so this lookup
                # raises KeyError unless the file supplied the entry.
                t, v = self._cache[k]
                expire_item = is_expired(t)
            if expire_item:
                # Raises KeyError
                self._expire(k)
        if should_record_hits:
            self._hit(k)
        return v

    def __getitem__(self, k: K) -> V:
        """Returns an element if it hasn't expired yet in a thread-safe way."""
        with self._lock:
            return self._getitem_non_locking(k)

    def _setitem(self, k: K, v: V) -> None:
        """Non-locking insert that may also flush the cache to disk."""
        super()._setitem(k, v)
        self._save_if_should()

    def _load(self) -> bool:
        """Load cache from disk if possible, returns whether it was able to load.

        Must be called with self._lock held. The lock is released temporarily
        while the on-disk cache merges self's entries, and it is always
        re-acquired before returning or raising.
        """
        try:
            with open(self.file_path, "rb") as r_file:
                other: SFDictFileCache = self._deserialize(r_file)
            if not isinstance(other, SFDictCache):
                # Handled below like any other unreadable cache file.
                raise TypeError(
                    f"unexpected object of type {type(other).__name__} in cache file"
                )
            # Since we want to know whether we are dirty after loading
            #  we have to know whether the file could learn anything from self
            #  so instead of calling self.update we call other.update and swap
            #  the 2 underlying caches after.
            # other.update calls self.clear_expired_entries(), which needs
            #  self._lock, hence the temporary release.
            self._lock.release()
            try:
                cache_file_learnt = other.update(
                    self,
                    update_newer_only=True,
                )
            finally:
                # Callers release self._lock on their way out, so it must be
                #  held again even if the merge failed.
                self._lock.acquire()
            self._cache, other._cache = other._cache, self._cache
            self.telemetry["size"] = other.telemetry["size"]
            self._cache_modified = cache_file_learnt
            self.last_loaded = now()
            return True
        except (AssertionError, RuntimeError):
            raise
        except Exception as e:
            logger.debug("Fail to read cache from disk due to error: %s", e)
            return False

    def load(self) -> bool:
        """Load cache from disk if possible, returns whether it was able to load.

        This is the public version of _load, it makes sure that all the
        necessary locks are acquired.
        """
        with self._lock:
            return self._load()

    def _serialize(self):
        """Return this cache pickled to bytes (see __getstate__)."""
        return pickle.dumps(self)

    @classmethod
    def _deserialize(cls, r_file):
        """Unpickle a cache object from the binary file object ``r_file``."""
        # NOTE(security): unpickling can execute arbitrary code, so the cache
        # file must only be writable by trusted users.
        return pickle.load(r_file)

    def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
        """Save cache to disk if possible, returns whether it was able to save.

        This function is non-locking when it comes to self._lock.

        Args:
            load_first: Merge a changed cache file into self before writing.
            force_flush: Write even if the cache has not been modified.

        Returns:
            True if the file was written. False if there was nothing to save,
            the file lock timed out, or writing failed (failures are logged).
        """
        if test_mode:
            assert (
                self._lock.locked()
            ), "The mutex self._lock should be locked by this thread"
        self._clear_expired_entries()
        if not self._cache_modified and not force_flush:
            # cache is not updated, so there is no need to dump cache to file, we just return
            return False
        try:
            with self._file_lock:
                if load_first:
                    self._load_if_should()
                _dir, fname = os.path.split(self.file_path)
                # Pre-bind both names so the handlers below never trip over
                # an unbound variable when mkstemp itself fails.
                tmp_file = None
                tmp_file_path = None
                try:
                    tmp_file, tmp_file_path = tempfile.mkstemp(
                        prefix=fname,
                        dir=_dir,
                    )
                    # tmp_file is an opened OS level handle, which means we need to close it manually.
                    # https://docs.python.org/3/library/tempfile.html#tempfile.mkstemp
                    # ideally we shall just use the tmp_file fd to write,
                    # however, using os.write(tmp_file, bytes) causes seg fault during garbage collection when exiting
                    # python program.
                    # thus we fall back to the approach using the normal open() method to open a file and write.
                    with open(tmp_file, "wb") as w_file:
                        w_file.write(self._serialize())
                    # We write to a tmp file and then move it to have atomic write
                    os.replace(tmp_file_path, self.file_path)
                    self.last_loaded = datetime.datetime.fromtimestamp(
                        getmtime(self.file_path),
                    )
                    # after update, reset self._cache_modified to indicate it's up-to-date to avoid unnecessary flush
                    self._cache_modified = False
                    return True
                except NameError:
                    # note: when exiting python program, garbage collection will kick in
                    # leading to `open` being garbage collected,
                    # calling `open` raises NameError, we close the tmp file fd here to release the tmp file fd
                    if tmp_file is not None:
                        try:
                            os.close(tmp_file)
                        except OSError:
                            pass
                except OSError as o_err:
                    # Logged by the outer handler; the method then returns False.
                    raise PermissionError(
                        o_err.errno,
                        "Cache folder is not writeable",
                        _dir,
                    ) from o_err
                finally:
                    # Remove the temporary file if it was not moved into place.
                    if tmp_file_path is not None and os.path.isfile(tmp_file_path):
                        os.unlink(tmp_file_path)
        except Timeout:
            logger.debug(
                f"acquiring {self._file_lock_path} timed out, skipping saving..."
            )
        except (AssertionError, RuntimeError):
            raise
        except Exception as e:
            logger.debug("Fail to write cache to disk due to error: %s", e)
        return False

    def save(self, load_first: bool = True) -> bool:
        """Save cache to disk if possible, returns whether it was able to save.

        This is the public version of _save, it makes sure that all the
        necessary locks are acquired.
        """
        with self._lock:
            return self._save(load_first)

    def _save_if_should(self) -> bool:
        """Saves file to disk if necessary and returns whether it saved.

        Uses self._should_save to decide whether to save.
        """
        if self._should_save():
            return self._save()
        return False

    def _load_if_should(self) -> bool:
        """Load file from disk if necessary and returns whether it loaded.

        Uses self._should_load to decide whether to load.
        """
        if self._should_load():
            return self._load()
        return False

    def _should_save(self) -> bool:
        """Decide whether we should save.

        This is a simple random number generator to randomize writes across processes
        that are possibly saving the same values in this cache. It returns True
        with a probability of 1 / (MAX_RAND_INT + 1).
        """
        return random.randint(0, self.MAX_RAND_INT) == 0

    def _should_load(self) -> bool:
        """Decide whether we should load.

        We should load if the file on disk has changed since we have last read it.
        """
        if os.path.exists(self.file_path) and os.path.isfile(self.file_path):
            if self.last_loaded is None:
                return True
            return (
                datetime.datetime.fromtimestamp(
                    getmtime(self.file_path),
                )
                > self.last_loaded
            )
        return False

    def clear(self) -> None:
        """Empty the cache in memory and replace the file with an empty cache."""
        super().clear()
        # This unlink prevents us from loading just before saving
        with self._file_lock:
            if os.path.exists(self.file_path) and os.path.isfile(self.file_path):
                os.unlink(self.file_path)
        # TODO: is this necessary?
        with self._lock:
            self._save(load_first=False, force_flush=True)

    # Custom pickling implementation

    def __getstate__(self) -> dict:
        """Return only the attributes listed in _ATTRIBUTES_TO_PICKLE."""
        return {
            k: v
            for k, v in self.__dict__.items()
            if k in SFDictFileCache._ATTRIBUTES_TO_PICKLE
        }

    def __setstate__(self, state: dict) -> None:
        """Restore pickled attributes and recreate the unpickled locks/flags."""
        self.__dict__.update(state)
        self._cache_modified = False
        self._lock = Lock()
        self._file_lock = FileLock(self._file_lock_path, timeout=self.file_timeout)

    def _add_or_remove(self) -> None:
        """This function gets called when an element is added, or removed.

        Besides refreshing the size telemetry it marks the cache as modified
        so that the next _save actually writes the file.

        Note that this function does not interact with the lock itself; it is
        only called from contexts where the lock is already held.
        """
        super()._add_or_remove()
        self._cache_modified = True
