# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides global MediaFileManager object as `media_file_manager`."""

from __future__ import annotations

import collections
import threading
import uuid
from typing import TYPE_CHECKING, BinaryIO, Final, TextIO, TypedDict, cast

from streamlit.logger import get_logger
from streamlit.runtime.download_data_util import convert_data_to_bytes_and_infer_mime
from streamlit.runtime.media_file_storage import (
    MediaFileKind,
    MediaFileStorage,
    MediaFileStorageError,
)

if TYPE_CHECKING:
    import io
    from collections.abc import Callable

# Module-level logger; MediaFileManager uses it for its debug output.
_LOGGER: Final = get_logger(__name__)


class DeferredCallableEntry(TypedDict):
    """Typed metadata for deferred download callables.

    One entry is stored per placeholder ID in
    ``MediaFileManager._deferred_callables``.
    """

    # Zero-argument callable that produces the download payload when invoked.
    callable: Callable[[], bytes | str | BinaryIO | TextIO | io.RawIOBase]
    # Explicit MIME type. When None (or empty), the MIME type inferred from
    # the produced data is used instead.
    mimetype: str | None
    # Optional file name that is handed to storage together with the data.
    filename: str | None
    # Location of the owning element in the app (its "coordinates").
    coordinates: str


def _get_session_id() -> str:
    """Return the session_id of the active AppSession.

    Falls back to a fixed placeholder string when no script run context
    is available.
    """
    # Imported inside the function rather than at module import time.
    from streamlit.runtime.scriptrunner_utils.script_run_context import (
        get_script_run_ctx,
    )

    run_ctx = get_script_run_ctx()
    # A missing context only happens when running "python myscript.py"
    # instead of "streamlit run myscript.py". There is only ever one session
    # in that case, so the ID doesn't matter and any constant will do.
    return "dontcare" if run_ctx is None else run_ctx.session_id


class MediaFileMetadata:
    """Bookkeeping record the MediaFileManager keeps for every managed file."""

    def __init__(self, kind: MediaFileKind = MediaFileKind.MEDIA) -> None:
        # Records always start out unmarked; only mark_for_delete() sets
        # the flag.
        self._is_marked_for_delete = False
        self._kind = kind

    def mark_for_delete(self) -> None:
        """Flag this file for deletion. The flag is never cleared again."""
        self._is_marked_for_delete = True

    @property
    def is_marked_for_delete(self) -> bool:
        """Whether mark_for_delete() has been called on this record."""
        return self._is_marked_for_delete

    @property
    def kind(self) -> MediaFileKind:
        """The file kind that was passed in at construction time."""
        return self._kind


class MediaFileManager:
    """In-memory file manager for MediaFile objects.

    This keeps track of:
    - Which files exist, and what their IDs are. This is important so we can
      serve files by ID -- that's the whole point of this class!
    - Which files are being used by which AppSession (by ID). This is
      important so we can remove files from memory when no more sessions need
      them.
    - The exact location in the app where each file is being used (i.e. the
      file's "coordinates"). This is important so we can mark a file as "not
      being used by a certain session" if it gets replaced by another file at
      the same coordinates. For example, when doing an animation where the same
      image is constantly replaced with new frames. (This doesn't solve the case
      where the file's coordinates keep changing for some reason, though! e.g.
      if new elements keep being prepended to the app. Unlikely to happen, but
      we should address it at some point.)
    """

    def __init__(self, storage: MediaFileStorage) -> None:
        self._storage = storage

        # Dict of [file_id -> MediaFileMetadata]
        self._file_metadata: dict[str, MediaFileMetadata] = {}

        # Dict[session ID][coordinates] -> file_id.
        self._files_by_session_and_coord: dict[str, dict[str, str]] = (
            collections.defaultdict(dict)
        )

        # Dict of [file_id -> deferred callable metadata]
        # Used for deferred download button execution
        self._deferred_callables: dict[str, DeferredCallableEntry] = {}

        # MediaFileManager is used from multiple threads, so all operations
        # need to be protected with a Lock. (This is not an RLock, which
        # means taking it multiple times from the same thread will deadlock.)
        self._lock = threading.Lock()

    def _get_inactive_file_ids(self) -> set[str]:
        """Compute the set of files that are stored in the manager, but are
        not referenced by any active session. These are files that can be
        safely deleted.

        Thread safety: callers must hold `self._lock`.
        """
        # Get the set of all our file IDs.
        file_ids = set(self._file_metadata.keys())

        # Subtract all IDs that are in use by each session
        for session_file_ids_by_coord in self._files_by_session_and_coord.values():
            file_ids.difference_update(session_file_ids_by_coord.values())

        return file_ids

    def remove_orphaned_files(self) -> None:
        """Remove all files that are no longer referenced by any active session.

        Safe to call from any thread.
        """
        _LOGGER.debug("Removing orphaned files...")

        with self._lock:
            for file_id in self._get_inactive_file_ids():
                file = self._file_metadata[file_id]
                if file.kind == MediaFileKind.MEDIA:
                    self._delete_file(file_id)
                elif file.kind == MediaFileKind.DOWNLOADABLE:
                    if file.is_marked_for_delete:
                        self._delete_file(file_id)
                    else:
                        file.mark_for_delete()

            # Clean up orphaned deferred callables
            self._remove_orphaned_deferred_callables()

    def _remove_orphaned_deferred_callables(self) -> None:
        """Remove deferred callables that are not referenced by any active session.

        Thread safety: callers must hold `self._lock`.
        """
        _LOGGER.debug("Removing orphaned deferred callables...")

        # Get all file_ids currently referenced by any session
        active_file_ids = set[str]()
        for session_file_ids_by_coord in self._files_by_session_and_coord.values():
            active_file_ids.update(session_file_ids_by_coord.values())

        # Remove deferred callables that are no longer referenced
        deferred_ids_to_remove = [
            file_id
            for file_id in self._deferred_callables
            if file_id not in active_file_ids
        ]
        for file_id in deferred_ids_to_remove:
            _LOGGER.debug("Removing deferred callable: %s", file_id)
            del self._deferred_callables[file_id]

    def _delete_file(self, file_id: str) -> None:
        """Delete the given file from storage, and remove its metadata from
        self._file_metadata.

        Thread safety: callers must hold `self._lock`.
        """
        _LOGGER.debug("Deleting File: %s", file_id)
        # Storage is asked to delete first; the metadata entry is only dropped
        # once that call has returned without raising.
        self._storage.delete_file(file_id)
        del self._file_metadata[file_id]

    def clear_session_refs(self, session_id: str | None = None) -> None:
        """Remove the given session's file references.

        (This does not remove any files from the manager - you must call
        `remove_orphaned_files` for that.)

        Should be called whenever ScriptRunner completes and when a session ends.

        Safe to call from any thread.
        """
        if session_id is None:
            session_id = _get_session_id()

        _LOGGER.debug("Disconnecting files for session with ID %s", session_id)

        with self._lock:
            if session_id in self._files_by_session_and_coord:
                del self._files_by_session_and_coord[session_id]

            # Don't immediately delete deferred callables here to avoid race conditions.
            # They will be cleaned up by remove_orphaned_deferred_callables() which
            # only removes callables that are not referenced by ANY session.

        _LOGGER.debug(
            "Sessions still active: %r", self._files_by_session_and_coord.keys()
        )

        _LOGGER.debug(
            "Files: %s; Sessions with files: %s",
            len(self._file_metadata),
            len(self._files_by_session_and_coord),
        )

    def add(
        self,
        path_or_data: bytes | str,
        mimetype: str,
        coordinates: str,
        file_name: str | None = None,
        is_for_static_download: bool = False,
    ) -> str:
        """Add a new MediaFile with the given parameters and return its URL.

        If an identical file already exists, return the existing URL
        and registers the current session as a user.

        Safe to call from any thread.

        Parameters
        ----------
        path_or_data : bytes or str
            If bytes: the media file's raw data. If str: the name of a file
            to load from disk.
        mimetype : str
            The mime type for the file. E.g. "audio/mpeg".
            This string will be used in the "Content-Type" header when the file
            is served over HTTP.
        coordinates : str
            Unique string identifying an element's location.
            Prevents memory leak of "forgotten" file IDs when element media
            is being replaced-in-place (e.g. an st.image stream).
            coordinates should be of the form: "1.(3.-14).5"
        file_name : str or None
            Optional file_name. Used to set the filename in the response header.
        is_for_static_download: bool
            Indicate that data stored for downloading as a file,
            not as a media for rendering at page. [default: False]

        Returns
        -------
        str
            The url that the frontend can use to fetch the media.

        Raises
        ------
        If a filename is passed, any Exception raised when trying to read the
        file will be re-raised.
        """

        session_id = _get_session_id()

        with self._lock:
            kind = (
                MediaFileKind.DOWNLOADABLE
                if is_for_static_download
                else MediaFileKind.MEDIA
            )
            file_id = self._storage.load_and_get_id(
                path_or_data, mimetype, kind, file_name
            )
            metadata = MediaFileMetadata(kind=kind)

            self._file_metadata[file_id] = metadata
            self._files_by_session_and_coord[session_id][coordinates] = file_id

            return self._storage.get_url(file_id)

    def add_deferred(
        self,
        data_callable: Callable[[], bytes | str | BinaryIO | TextIO | io.RawIOBase],
        mimetype: str | None,
        coordinates: str,
        file_name: str | None = None,
    ) -> str:
        """Register a callable for deferred execution. Returns placeholder file_id.

        The callable will be executed later when execute_deferred() is called,
        typically when the user clicks a download button.

        Safe to call from any thread.

        Parameters
        ----------
        data_callable : Callable[[], bytes | str | BinaryIO | TextIO | io.RawIOBase]
            A callable that returns the file data when invoked.
        mimetype : str or None
            The mime type for the file. E.g. "text/csv".
            If None, the mimetype will be inferred from the data type when
            execute_deferred() is called.
        coordinates : str
            Unique string identifying an element's location.
        file_name : str or None
            Optional file_name. Used to set the filename in the response header.

        Returns
        -------
        str
            A placeholder file_id that can be used to execute the callable later.
        """
        session_id = _get_session_id()

        with self._lock:
            # Generate a unique placeholder ID for this deferred callable
            # Expected: a new placeholder ID is created on every script rerun.
            file_id = uuid.uuid4().hex

            # Store the callable with its metadata
            self._deferred_callables[file_id] = cast(
                "DeferredCallableEntry",
                {
                    "callable": data_callable,
                    "mimetype": mimetype,
                    "filename": file_name,
                    "coordinates": coordinates,
                },
            )

            # Track this deferred file by session and coordinate
            self._files_by_session_and_coord[session_id][coordinates] = file_id

            return file_id

    def execute_deferred(self, file_id: str) -> str:
        """Execute a deferred callable and return the URL to the generated file.

        This method retrieves the callable registered with add_deferred(),
        executes it, stores the result, and returns a URL to access it.

        Safe to call from any thread.

        Parameters
        ----------
        file_id : str
            The placeholder file_id returned by add_deferred().

        Returns
        -------
        str
            The URL that can be used to download the generated file.

        Raises
        ------
        MediaFileStorageError
            If the file_id is not found or if the callable execution fails.
        """
        # Retrieve deferred callable metadata while holding lock
        with self._lock:
            if file_id not in self._deferred_callables:
                raise MediaFileStorageError(f"Deferred file {file_id} not found")

            deferred = self._deferred_callables[file_id]

        # Execute callable outside lock to avoid blocking other operations
        try:
            data = deferred["callable"]()
        except Exception as e:
            raise MediaFileStorageError(f"Callable execution failed: {e}") from e

        # Convert data to bytes and infer mimetype if needed
        data_as_bytes, inferred_mime_type = convert_data_to_bytes_and_infer_mime(
            data,
            unsupported_error=MediaFileStorageError(
                f"Callable returned unsupported type: {type(data)}"
            ),
        )

        # Use provided mimetype if available, otherwise use inferred mimetype
        mime_type: str = deferred["mimetype"] or inferred_mime_type

        # Store the generated data and get the actual file_id
        with self._lock:
            actual_file_id = self._storage.load_and_get_id(
                data_as_bytes,
                mime_type,
                MediaFileKind.DOWNLOADABLE,
                deferred["filename"],
            )

            # Create metadata for the actual file
            metadata = MediaFileMetadata(kind=MediaFileKind.DOWNLOADABLE)
            self._file_metadata[actual_file_id] = metadata

            # Keep the deferred callable so users can download multiple times
            # It will be cleaned up when clear_session_refs() is called on rerun

            # We leave actual_file_id unmapped so repeat clicks rerun the callable.
            # Cleanup prunes the stored file once no session references it.

            # Return the URL to access the file
            return self._storage.get_url(actual_file_id)
