# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping, MutableMapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Final, cast
from urllib import parse

from streamlit.errors import StreamlitAPIException, StreamlitQueryParamDictValueError
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx

if TYPE_CHECKING:
    from _typeshed import SupportsKeysAndGetItem

QueryParamValue = str | Iterable[str]
QueryParamsInput = Mapping[str, QueryParamValue] | Iterable[tuple[str, QueryParamValue]]


# Reserved query-param keys that control embedding. Elsewhere in this module
# they are compared against ``key.lower()`` (so matching is case-insensitive),
# hidden from the st.query_params mapping interface, and rejected when set.
EMBED_QUERY_PARAM: Final[str] = "embed"
EMBED_OPTIONS_QUERY_PARAM: Final[str] = "embed_options"
EMBED_QUERY_PARAMS_KEYS: Final[list[str]] = [
    EMBED_QUERY_PARAM, EMBED_OPTIONS_QUERY_PARAM
]


@dataclass
class QueryParams(MutableMapping[str, str]):
    """A lightweight wrapper of a dict that sends forwardMsgs when state changes.

    It stores str keys with str and list[str] values. The reserved embed keys
    (``embed`` / ``embed_options``, matched case-insensitively) are kept in the
    underlying dict but are hidden from the mapping interface.
    """

    # Raw storage: key -> a single string, or a list of strings for a
    # repeated query param.
    _query_params: dict[str, list[str] | str] = field(default_factory=dict)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the non-embed keys."""
        self._ensure_single_query_api_used()
        # A generator expression is already an iterator; no iter() wrapper needed.
        return (
            key
            for key in self._query_params
            if key.lower() not in EMBED_QUERY_PARAMS_KEYS
        )

    def __getitem__(self, key: str) -> str:
        """Retrieves a value for a given key in query parameters.

        Returns the last item in a list or an empty string if empty.
        If the key is not present (or is an embed key), raise KeyError.
        """
        self._ensure_single_query_api_used()
        if key.lower() in EMBED_QUERY_PARAMS_KEYS:
            raise KeyError(missing_key_error_message(key))

        try:
            value = self._query_params[key]
        except KeyError:
            raise KeyError(missing_key_error_message(key)) from None

        if isinstance(value, list):
            # Return the last value to mimic Tornado's behavior
            # https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.get_query_argument
            return value[-1] if value else ""
        return value

    def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
        """Validate and store ``value`` under ``key``, then notify the frontend."""
        self._ensure_single_query_api_used()
        self._set_item_internal(key, value)
        self._send_query_param_msg()

    def _set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
        """Validate and store ``value`` under ``key`` without sending a message."""
        _set_item_in_dict(self._query_params, key, value)

    def __delitem__(self, key: str) -> None:
        """Remove ``key`` and notify the frontend.

        Raises:
            KeyError: if ``key`` is absent or is an embed key.
        """
        self._ensure_single_query_api_used()
        if key.lower() in EMBED_QUERY_PARAMS_KEYS:
            raise KeyError(missing_key_error_message(key))
        # Keep the try body to the dict operation only, so a KeyError raised
        # while sending the message is not misreported as a missing key.
        try:
            del self._query_params[key]
        except KeyError:
            raise KeyError(missing_key_error_message(key)) from None
        self._send_query_param_msg()

    def update(
        self,
        other: Iterable[tuple[str, str | Iterable[str]]]
        | SupportsKeysAndGetItem[str, str | Iterable[str]] = (),
        /,
        **kwds: str,
    ) -> None:
        """Set several params at once, sending a single ForwardMsg at the end.

        ``other`` may be mapping-like (has ``keys`` and ``__getitem__``) or an
        iterable of ``(key, value)`` pairs; ``kwds`` are applied afterwards.
        """
        # This overrides the `update` provided by MutableMapping
        # to ensure only one ForwardMsg is sent.
        self._ensure_single_query_api_used()
        if hasattr(other, "keys") and hasattr(other, "__getitem__"):
            other = cast("SupportsKeysAndGetItem[str, str | Iterable[str]]", other)
            for key in other.keys():  # noqa: SIM118
                self._set_item_internal(key, other[key])
        else:
            for key, value in other:
                self._set_item_internal(key, value)
        for key, value in kwds.items():
            self._set_item_internal(key, value)
        self._send_query_param_msg()

    def get_all(self, key: str) -> list[str]:
        """Return every value stored for ``key`` as a new list.

        Returns an empty list if ``key`` is absent or is an embed key.
        """
        self._ensure_single_query_api_used()
        if key not in self._query_params or key.lower() in EMBED_QUERY_PARAMS_KEYS:
            return []
        value = self._query_params[key]
        # Return a copy: handing out the internal list would let callers
        # mutate state without any ForwardMsg being sent.
        return list(value) if isinstance(value, list) else [value]

    def __len__(self) -> int:
        """Number of non-embed keys."""
        self._ensure_single_query_api_used()
        # Dict keys are unique, so counting is enough; no need to build a set.
        return sum(
            1
            for key in self._query_params
            if key.lower() not in EMBED_QUERY_PARAMS_KEYS
        )

    def __str__(self) -> str:
        """String form of the raw storage (embed keys included)."""
        self._ensure_single_query_api_used()
        return str(self._query_params)

    def _send_query_param_msg(self) -> None:
        """Push the current params to the frontend.

        Encodes all stored params (including any embed keys) as a query string,
        stores it on the script run context, and enqueues a page_info_changed
        ForwardMsg. Does nothing when there is no script run context.
        """
        ctx = get_script_run_ctx()
        if ctx is None:
            return
        self._ensure_single_query_api_used()

        msg = ForwardMsg()
        msg.page_info_changed.query_string = parse.urlencode(
            self._query_params, doseq=True
        )
        ctx.query_string = msg.page_info_changed.query_string
        ctx.enqueue(msg)

    def clear(self) -> None:
        """Remove all non-embed params and notify the frontend."""
        self._ensure_single_query_api_used()
        self.clear_with_no_forward_msg(preserve_embed=True)
        self._send_query_param_msg()

    def to_dict(self) -> dict[str, str]:
        """Return a plain dict of the non-embed params (last value per key)."""
        self._ensure_single_query_api_used()
        # return the last query param if multiple values are set
        return {
            key: self[key]
            for key in self._query_params
            if key.lower() not in EMBED_QUERY_PARAMS_KEYS
        }

    def from_dict(
        self,
        _dict: Iterable[tuple[str, str | Iterable[str]]]
        | SupportsKeysAndGetItem[str, str | Iterable[str]],
    ) -> None:
        """Replace all non-embed params with ``_dict`` in one ForwardMsg.

        If applying ``_dict`` fails for any reason, the previous params are
        restored before the exception propagates.
        """
        self._ensure_single_query_api_used()
        old_value = self._query_params.copy()
        self.clear_with_no_forward_msg(preserve_embed=True)
        try:
            self.update(_dict)
        except Exception:
            # Restore the original from before we made any changes. This covers
            # validation errors (StreamlitAPIException) as well as malformed
            # input such as pairs that do not unpack into (key, value).
            self._query_params = old_value
            raise

    def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
        """Store ``val`` verbatim (no validation or str coercion), without a message."""
        self._query_params[key] = val

    def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None:
        """Drop params without notifying the frontend.

        With ``preserve_embed=True`` the embed keys are kept; otherwise every
        key is removed.
        """
        self._query_params = {
            key: value
            for key, value in self._query_params.items()
            if key.lower() in EMBED_QUERY_PARAMS_KEYS and preserve_embed
        }

    def _ensure_single_query_api_used(self) -> None:
        """Record on the script run context that this query-params API was used.

        No-op outside a script run. Presumably this lets the context detect
        mixing with another query-params API; confirm in ScriptRunContext.
        """
        ctx = get_script_run_ctx()
        if ctx is None:
            return
        ctx.mark_production_query_params_used()


def missing_key_error_message(key: str) -> str:
    """Build the KeyError message used when ``key`` is absent from st.query_params."""
    message = f'st.query_params has no key "{key}".'
    return message


def _set_item_in_dict(
    target_dict: dict[str, list[str] | str], key: str, value: str | Iterable[str]
) -> None:
    """Validate one query-param entry and store it in ``target_dict``.

    A ``str`` value is stored as a string. Any other iterable is stored as a
    list of ``str(item)``. Anything else is coerced with ``str()``.

    Raises:
        StreamlitQueryParamDictValueError: if ``value`` is a dict.
        StreamlitAPIException: if ``key`` is an embed key (case-insensitive).
    """
    if isinstance(value, dict):
        raise StreamlitQueryParamDictValueError(key)

    is_reserved_key = key.lower() in EMBED_QUERY_PARAMS_KEYS
    if is_reserved_key:
        raise StreamlitAPIException(
            "Query param embed and embed_options (case-insensitive) cannot be set programmatically."
        )

    # Type-checked callers should pass strings, but untyped callers may pass
    # other objects, so every stored value is coerced with str().
    is_multi_value = not isinstance(value, str) and isinstance(value, Iterable)
    target_dict[key] = [str(item) for item in value] if is_multi_value else str(value)


def process_query_params(
    query_params: Iterable[tuple[str, str | Iterable[str]]]
    | SupportsKeysAndGetItem[str, str | Iterable[str]],
) -> str:
    """Convert query params into a URL-encoded query string.

    ``query_params`` may be mapping-like (has ``keys`` and ``__getitem__``) or
    an iterable of ``(key, value)`` pairs. With pairs, a repeated key collects
    all of its values in order. Every value is validated and serialized by
    ``_set_item_in_dict``, so both input shapes follow the same rules.

    Raises:
        StreamlitQueryParamDictValueError: if a value is a dict.
        StreamlitAPIException: if an embed key is supplied.
    """
    processed_params: dict[str, list[str] | str] = {}

    if hasattr(query_params, "keys") and hasattr(query_params, "__getitem__"):
        query_params = cast(
            "SupportsKeysAndGetItem[str, str | Iterable[str]]", query_params
        )
        for key in query_params.keys():  # noqa: SIM118
            _set_item_in_dict(processed_params, key, query_params[key])
    else:
        for key, value in query_params:
            if key not in processed_params:
                _set_item_in_dict(processed_params, key, value)
                continue

            # Repeated key: serialize the new value with the shared helper
            # (same validation and str coercion as the first occurrence), then
            # append it to the values accumulated so far.
            scratch: dict[str, list[str] | str] = {}
            _set_item_in_dict(scratch, key, value)
            new_val = scratch[key]

            current_val = processed_params[key]
            merged = current_val if isinstance(current_val, list) else [current_val]
            if isinstance(new_val, list):
                merged.extend(new_val)
            else:
                merged.append(new_val)
            processed_params[key] = merged

    return parse.urlencode(processed_params, doseq=True)
