# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Final,
    Literal,
    TypeAlias,
    TypedDict,
    Union,
    cast,
    overload,
)

from typing_extensions import Required

from streamlit import type_util
from streamlit.deprecation_util import (
    make_deprecated_name_warning,
    show_deprecation_warning,
)
from streamlit.elements.lib.form_utils import current_form_id
from streamlit.elements.lib.layout_utils import (
    Height,
    LayoutConfig,
    Width,
    validate_height,
    validate_width,
)
from streamlit.elements.lib.policies import check_widget_policies
from streamlit.elements.lib.streamlit_plotly_theme import (
    configure_streamlit_plotly_theme,
)
from streamlit.elements.lib.utils import Key, compute_and_register_element_id, to_key
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.proto.PlotlyChart_pb2 import PlotlyChart as PlotlyChartProto
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
from streamlit.runtime.state import WidgetCallback, register_widget
from streamlit.util import AttributeDictionary

if TYPE_CHECKING:
    from collections.abc import Iterable

    import matplotlib as mpl
    import plotly.graph_objs as go
    from plotly.basedatatypes import BaseFigure

    from streamlit.delta_generator import DeltaGenerator

# We need to configure the Plotly theme before any Plotly figures are created:
# (this call runs as a side effect of importing this module).
configure_streamlit_plotly_theme()

# A single Plotly figure or trace collection. The names are quoted because
# ``go`` is only imported under ``TYPE_CHECKING``.
_AtomicFigureOrData: TypeAlias = Union[
    "go.Figure",
    "go.Data",
]
# Everything accepted as the ``figure_or_data`` argument of
# ``PlotlyMixin.plotly_chart``.
FigureOrData: TypeAlias = Union[
    _AtomicFigureOrData,
    list[_AtomicFigureOrData],
    # It is kind of hard to figure out exactly what kind of dict is supported
    # here, as plotly hasn't embraced typing yet. This version is chosen to
    # align with the docstring.
    dict[str, _AtomicFigureOrData],
    "BaseFigure",
    "mpl.figure.Figure",
]

# The selection modes a user may request. ``_SELECTION_MODES`` is the runtime
# counterpart of the ``SelectionMode`` literal type and is what
# ``parse_selection_mode`` validates against; keep the two in sync.
SelectionMode: TypeAlias = Literal["lasso", "points", "box"]
_SELECTION_MODES: Final[set[SelectionMode]] = {"lasso", "points", "box"}

# Module-level logger, named after this module.
_LOGGER: Final = get_logger(__name__)


class PlotlySelectionState(TypedDict, total=False):
    """
    The schema for the Plotly chart selection state.

    The selection state is stored in a dictionary-like object that supports both
    key and attribute notation. Selection states cannot be programmatically
    changed or set through Session State.

    Attributes
    ----------
    points : list[dict[str, Any]]
        The selected data points in the chart, including the data points
        selected by the box and lasso mode. The data includes the values
        associated with each point and a point index used to populate
        ``point_indices``. If additional information has been assigned to your
        points, such as size or legend group, this is also included.

    point_indices : list[int]
        The numerical indices of all selected data points in the chart. The
        details of each identified point are included in ``points``.

    box : list[dict[str, Any]]
        The metadata related to the box selection. This includes the
        coordinates of the selected area.

    lasso : list[dict[str, Any]]
        The metadata related to the lasso selection. This includes the
        coordinates of the selected area.

    Example
    -------
    When working with more complicated graphs, the ``points`` attribute
    displays additional information. Try selecting points in the following
    example:

    >>> import plotly.express as px
    >>> import streamlit as st
    >>>
    >>> df = px.data.iris()
    >>> fig = px.scatter(
    ...     df,
    ...     x="sepal_width",
    ...     y="sepal_length",
    ...     color="species",
    ...     size="petal_length",
    ...     hover_data=["petal_width"],
    ... )
    >>>
    >>> event = st.plotly_chart(fig, key="iris", on_select="rerun")
    >>>
    >>> event.selection

    .. output::
        https://doc-chart-events-plotly-selection-state.streamlit.app
        height: 600px

    This is an example of the selection state when selecting a single point:

    >>> {
    >>>   "points": [
    >>>     {
    >>>       "curve_number": 2,
    >>>       "point_number": 9,
    >>>       "point_index": 9,
    >>>       "x": 3.6,
    >>>       "y": 7.2,
    >>>       "customdata": [
    >>>         2.5
    >>>       ],
    >>>       "marker_size": 6.1,
    >>>       "legendgroup": "virginica"
    >>>     }
    >>>   ],
    >>>   "point_indices": [
    >>>     9
    >>>   ],
    >>>   "box": [],
    >>>   "lasso": []
    >>> }

    """

    # The class is declared with ``total=False``, but every key below is
    # wrapped in ``Required[...]``, so type checkers still treat all four keys
    # as mandatory.
    points: Required[list[dict[str, Any]]]
    point_indices: Required[list[int]]
    box: Required[list[dict[str, Any]]]
    lasso: Required[list[dict[str, Any]]]


class PlotlyState(TypedDict, total=False):
    """
    The schema for the Plotly chart event state.

    The event state is stored in a dictionary-like object that supports both
    key and attribute notation. Event states cannot be programmatically
    changed or set through Session State.

    Only selection events are supported at this time.

    Attributes
    ----------
    selection : dict
        The state of the ``on_select`` event. This attribute returns a
        dictionary-like object that supports both key and attribute notation.
        The attributes are described by the ``PlotlySelectionState`` dictionary
        schema.

    Example
    -------
    Try selecting points by any of the three available methods (direct click,
    box, or lasso). The current selection state is available through Session
    State or as the output of the chart function.

    >>> import plotly.express as px
    >>> import streamlit as st
    >>>
    >>> df = px.data.iris()
    >>> fig = px.scatter(df, x="sepal_width", y="sepal_length")
    >>>
    >>> event = st.plotly_chart(fig, key="iris", on_select="rerun")
    >>>
    >>> event

    .. output::
        https://doc-chart-events-plotly-state.streamlit.app
        height: 600px

    """

    # ``Required`` makes this key mandatory for type checkers even though the
    # class is declared with ``total=False``.
    selection: Required[PlotlySelectionState]


@dataclass
class PlotlyChartSelectionSerde:
    """Serializer/deserializer pair for the Plotly chart selection state.

    The state is exchanged as a JSON string: ``serialize`` produces that
    string and ``deserialize`` turns it back into a dictionary-like object.
    """

    def deserialize(self, ui_value: str | None) -> PlotlyState:
        """Decode the JSON selection payload into a ``PlotlyState``.

        Parameters
        ----------
        ui_value : str or None
            The JSON-encoded selection state, or ``None`` if there is no
            value to decode.

        Returns
        -------
        PlotlyState
            The decoded state wrapped in an ``AttributeDictionary``. An empty
            selection is returned when ``ui_value`` is ``None`` or when the
            decoded payload has no ``"selection"`` key.
        """
        fallback_state: PlotlyState = {
            "selection": {
                "points": [],
                "point_indices": [],
                "box": [],
                "lasso": [],
            },
        }

        # Nothing to decode: hand back the empty selection.
        if ui_value is None:
            return cast("PlotlyState", AttributeDictionary(fallback_state))

        decoded_state = AttributeDictionary(json.loads(ui_value))

        # A payload without a "selection" entry is treated like "no selection".
        if "selection" not in decoded_state:
            return cast("PlotlyState", AttributeDictionary(fallback_state))

        return cast("PlotlyState", AttributeDictionary(decoded_state))

    def serialize(self, selection_state: PlotlyState) -> str:
        """Encode the selection state as a JSON string.

        Parameters
        ----------
        selection_state : PlotlyState
            The selection state to encode.

        Returns
        -------
        str
            The JSON document. Values the JSON encoder cannot handle natively
            are converted with ``str``.
        """
        encoded_state = json.dumps(selection_state, default=str)
        return encoded_state


def parse_selection_mode(
    selection_mode: SelectionMode | Iterable[SelectionMode],
) -> set[PlotlyChartProto.SelectionMode.ValueType]:
    """Validate the user-provided selection mode(s) and map them to proto values.

    Parameters
    ----------
    selection_mode : SelectionMode or Iterable of SelectionMode
        Either a single mode given as a string or an iterable of modes.

    Returns
    -------
    set
        The ``PlotlyChartProto.SelectionMode`` values corresponding to the
        requested modes, without duplicates.

    Raises
    ------
    StreamlitAPIException
        If any requested mode is not one of ``_SELECTION_MODES``.
    """
    # A bare string is one mode; anything else is treated as a collection.
    requested_modes = (
        {selection_mode} if isinstance(selection_mode, str) else set(selection_mode)
    )

    if not requested_modes <= _SELECTION_MODES:
        raise StreamlitAPIException(
            f"Invalid selection mode: {selection_mode}. "
            f"Valid options are: {_SELECTION_MODES}"
        )

    # Every requested mode is known to be valid at this point, so the lookup
    # below cannot miss.
    proto_modes = PlotlyChartProto.SelectionMode
    proto_by_name = {
        "points": proto_modes.POINTS,
        "lasso": proto_modes.LASSO,
        "box": proto_modes.BOX,
    }
    return {proto_by_name[mode_name] for mode_name in requested_modes}


def _resolve_content_width(width: Width, figure: Any) -> Width:
    """Resolve "content" width by inspecting the figure's layout width.

    For content width, we check if the plotly figure has an explicit width
    in its layout. If so, we use that as a pixel width. If not, we default
    to 700 pixels which matches the plotly.js default width.

    Args
    ----
    width : Width
        The original width parameter
    figure : Any
        The plotly figure object (Figure, dict, or other supported formats)

    Returns
    -------
    Width
        The resolved width (either original width, figure width as pixels, or 700)
    """

    if width != "content":
        return width

    # Extract width from the figure's layout
    # plotly.tools.mpl_to_plotly() returns Figure objects with .layout attribute
    # plotly.tools.return_figure_from_figure_or_data() returns dictionaries
    figure_width = None
    if isinstance(figure, dict):
        figure_width = figure.get("layout", {}).get("width")
    else:
        # Handle Figure objects from matplotlib conversion
        try:
            figure_width = figure.layout.width
        except (AttributeError, TypeError):
            _LOGGER.debug("Could not parse width from figure")

    if (
        figure_width is not None
        and isinstance(figure_width, (int, float))
        and figure_width > 0
    ):
        return int(figure_width)

    # Default to 700 pixels (plotly.js default) when no width is specified
    return 700


def _resolve_content_height(height: Height, figure: Any) -> Height:
    """Resolve "content" height by inspecting the figure's layout height.

    For content height, we check if the plotly figure has an explicit height
    in its layout. If so, we use that as a pixel height. If not, we default
    to 450 pixels which matches the plotly.js default height.

    Args
    ----
    height : Height
        The original height parameter
    figure : Any
        The plotly figure object (Figure, dict, or other supported formats)

    Returns
    -------
    Height
        The resolved height (either original height, figure height as pixels, or 450)
    """

    if height != "content":
        return height

    # Extract height from the figure's layout
    # plotly.tools.mpl_to_plotly() returns Figure objects with .layout attribute
    # plotly.tools.return_figure_from_figure_or_data() returns dictionaries
    figure_height = None
    if isinstance(figure, dict):
        figure_height = figure.get("layout", {}).get("height")
    else:
        # Handle Figure objects from matplotlib conversion
        try:
            figure_height = figure.layout.height
        except (AttributeError, TypeError):
            _LOGGER.debug("Could not parse height from figure")

    if (
        figure_height is not None
        and isinstance(figure_height, (int, float))
        and figure_height > 0
    ):
        return int(figure_height)

    # Default to 450 pixels (plotly.js default) when no height is specified
    return 450


class PlotlyMixin:
    """Mixin that provides the ``plotly_chart`` element command.

    The mixin expects to be combined with a ``DeltaGenerator``; see the ``dg``
    property, which casts ``self`` to that type.
    """

    @overload
    def plotly_chart(
        self,
        figure_or_data: FigureOrData,
        use_container_width: bool | None = None,
        *,
        width: Width = "stretch",
        height: Height = "content",
        theme: Literal["streamlit"] | None = "streamlit",
        key: Key | None = None,
        on_select: Literal["ignore"],  # No default value here to make it work with mypy
        selection_mode: SelectionMode | Iterable[SelectionMode] = (
            "points",
            "box",
            "lasso",
        ),
        config: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> DeltaGenerator: ...

    @overload
    def plotly_chart(
        self,
        figure_or_data: FigureOrData,
        use_container_width: bool | None = None,
        *,
        width: Width = "stretch",
        height: Height = "content",
        theme: Literal["streamlit"] | None = "streamlit",
        key: Key | None = None,
        on_select: Literal["rerun"] | WidgetCallback = "rerun",
        selection_mode: SelectionMode | Iterable[SelectionMode] = (
            "points",
            "box",
            "lasso",
        ),
        config: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> PlotlyState: ...

    @gather_metrics("plotly_chart")
    def plotly_chart(
        self,
        figure_or_data: FigureOrData,
        use_container_width: bool | None = None,
        *,
        width: Width = "stretch",
        height: Height = "content",
        theme: Literal["streamlit"] | None = "streamlit",
        key: Key | None = None,
        on_select: Literal["rerun", "ignore"] | WidgetCallback = "ignore",
        selection_mode: SelectionMode | Iterable[SelectionMode] = (
            "points",
            "box",
            "lasso",
        ),
        config: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> DeltaGenerator | PlotlyState:
        """Display an interactive Plotly chart.

        `Plotly <https://plot.ly/python>`_ is a charting library for Python.
        The arguments to this function closely follow the ones for Plotly's
        ``plot()`` function.

        To show Plotly charts in Streamlit, call ``st.plotly_chart`` wherever
        you would call Plotly's ``py.plot`` or ``py.iplot``.

        .. Important::
            You must install ``plotly>=4.0.0`` to use this command. Your app's
            performance may be enhanced by installing ``orjson`` as well. You
            can install all charting dependencies (except Bokeh) as an extra
            with Streamlit:

            .. code-block:: shell

               pip install streamlit[charts]

        Parameters
        ----------
        figure_or_data : plotly.graph_objs.Figure, plotly.graph_objs.Data,\
            or dict/list of plotly.graph_objs.Figure/Data

            The Plotly ``Figure`` or ``Data`` object to render. See
            https://plot.ly/python/ for examples of graph descriptions.

            .. note::
                If your chart contains more than 1000 data points, Plotly will
                use a WebGL renderer to display the chart. Different browsers
                have different limits on the number of WebGL contexts per page.
                If you have multiple WebGL contexts on a page, you may need to
                switch to SVG rendering mode. You can do this by setting
                ``render_mode="svg"`` within the figure. For example, the
                following code defines a Plotly Express line chart that will
                render in SVG mode when passed to ``st.plotly_chart``:
                ``px.line(df, x="x", y="y", render_mode="svg")``.

        width : "stretch", "content", or int
            The width of the chart element. This can be one of the following:

            - ``"stretch"`` (default): The width of the element matches the
              width of the parent container.
            - ``"content"``: The width of the element matches the width of its
              content, but doesn't exceed the width of the parent container.
            - An integer specifying the width in pixels: The element has a
              fixed width. If the specified width is greater than the width of
              the parent container, the width of the element matches the width
              of the parent container.

        height : "content", "stretch", or int
            The height of the chart element. This can be one of the following:

            - ``"content"`` (default): The height of the element matches the
              height of its content.
            - ``"stretch"``: The height of the element matches the height of
              its content or the height of the parent container, whichever is
              larger. If the element is not in a parent container, the height
              of the element matches the height of its content.
            - An integer specifying the height in pixels: The element has a
              fixed height. If the content is larger than the specified
              height, scrolling is enabled.

        use_container_width : bool or None
            Whether to override the figure's native width with the width of
            the parent container. This can be one of the following:

            - ``None`` (default): Streamlit will use the value of ``width``.
            - ``True``: Streamlit sets the width of the figure to match the
              width of the parent container.
            - ``False``: Streamlit sets the width of the figure to fit its
              contents according to the plotting library, up to the width of
              the parent container.

            .. deprecated::
               ``use_container_width`` is deprecated and will be removed in a
               future release. For ``use_container_width=True``, use
               ``width="stretch"``. For ``use_container_width=False``, use
               ``width="content"``.

        theme : "streamlit" or None
            The theme of the chart. If ``theme`` is ``"streamlit"`` (default),
            Streamlit uses its own design default. If ``theme`` is ``None``,
            Streamlit falls back to the default behavior of the library.

            The ``"streamlit"`` theme can be partially customized through the
            configuration options ``theme.chartCategoricalColors`` and
            ``theme.chartSequentialColors``. Font configuration options are
            also applied.

        key : str
            An optional string to use for giving this element a stable
            identity. If ``key`` is ``None`` (default), this element's identity
            will be determined based on the values of the other parameters.

            Additionally, if selections are activated and ``key`` is provided,
            Streamlit will register the key in Session State to store the
            selection state. The selection state is read-only.

        on_select : "ignore" or "rerun" or callable
            How the figure should respond to user selection events. This
            controls whether or not the figure behaves like an input widget.
            ``on_select`` can be one of the following:

            - ``"ignore"`` (default): Streamlit will not react to any selection
              events in the chart. The figure will not behave like an input
              widget.

            - ``"rerun"``: Streamlit will rerun the app when the user selects
              data in the chart. In this case, ``st.plotly_chart`` will return
              the selection data as a dictionary.

            - A ``callable``: Streamlit will rerun the app and execute the
              ``callable`` as a callback function before the rest of the app.
              In this case, ``st.plotly_chart`` will return the selection data
              as a dictionary.

        selection_mode : "points", "box", "lasso" or an Iterable of these
            The selection mode of the chart. This can be one of the following:

            - ``"points"``: The chart will allow selections based on individual
              data points.
            - ``"box"``: The chart will allow selections based on rectangular
              areas.
            - ``"lasso"``: The chart will allow selections based on freeform
              areas.
            - An ``Iterable`` of the above options: The chart will allow
              selections based on the modes specified.

            All selections modes are activated by default.

        config : dict or None
            A dictionary of Plotly configuration options. This is passed to
            Plotly's ``show()`` function. For more information about Plotly
            configuration options, see Plotly's documentation on `Configuration
            in Python <https://plotly.com/python/configuration-options/>`_.

        **kwargs
            Additional arguments accepted by Plotly's ``plot()`` function.

            Passing any of these only triggers a deprecation warning; they are
            not used when building the chart. To set Plotly configuration
            options, use the ``config`` parameter.

            .. deprecated::
               ``**kwargs`` are deprecated and will be removed in a future
               release. Use ``config`` instead.

        Returns
        -------
        element or dict
            If ``on_select`` is ``"ignore"`` (default), this command returns an
            internal placeholder for the chart element. Otherwise, this command
            returns a dictionary-like object that supports both key and
            attribute notation. The attributes are described by the
            ``PlotlyState`` dictionary schema.

        Examples
        --------
        **Example 1: Basic Plotly chart**

        The example below comes from the examples at https://plot.ly/python.
        Note that ``plotly.figure_factory`` requires ``scipy`` to run.

        >>> import plotly.figure_factory as ff
        >>> import streamlit as st
        >>> from numpy.random import default_rng as rng
        >>>
        >>> hist_data = [
        ...     rng(0).standard_normal(200) - 2,
        ...     rng(1).standard_normal(200),
        ...     rng(2).standard_normal(200) + 2,
        ... ]
        >>> group_labels = ["Group 1", "Group 2", "Group 3"]
        >>>
        >>> fig = ff.create_distplot(
        ...     hist_data, group_labels, bin_size=[0.1, 0.25, 0.5]
        ... )
        >>>
        >>> st.plotly_chart(fig)

        .. output::
           https://doc-plotly-chart.streamlit.app/
           height: 550px

        **Example 2: Plotly Chart with configuration**

        By default, Plotly charts have scroll zoom enabled. If you have a
        longer page and want to avoid conflicts between page scrolling and
        zooming, you can use Plotly's configuration options to disable scroll
        zoom. In the following example, scroll zoom is disabled, but the zoom
        buttons are still enabled in the modebar.

        >>> import plotly.graph_objects as go
        >>> import streamlit as st
        >>>
        >>> fig = go.Figure()
        >>> fig.add_trace(
        ...     go.Scatter(
        ...         x=[1, 2, 3, 4, 5],
        ...         y=[1, 3, 2, 5, 4]
        ...     )
        ... )
        >>>
        >>> st.plotly_chart(fig, config = {'scrollZoom': False})

        .. output::
           https://doc-plotly-chart-config.streamlit.app/
           height: 550px

        """
        if use_container_width is not None:
            show_deprecation_warning(
                make_deprecated_name_warning(
                    "use_container_width",
                    "width",
                    "2025-12-31",
                    "For `use_container_width=True`, use `width='stretch'`. "
                    "For `use_container_width=False`, use `width='content'`.",
                    include_st_prefix=False,
                ),
                show_in_browser=False,
            )
            # Translate the deprecated flag into the new ``width`` argument.
            # An explicit pixel width is kept when the flag is ``False``.
            if use_container_width:
                width = "stretch"
            elif not isinstance(width, int):
                width = "content"

        validate_width(width, allow_content=True)
        validate_height(height, allow_content=True)

        # Plotly is imported lazily, only when a chart is actually created.
        import plotly.io
        import plotly.tools

        # NOTE: "figure_or_data" is the name used in Plotly's .plot() method
        # for their main parameter. I don't like the name, but it's best to
        # keep it in sync with what Plotly calls it.

        if kwargs:
            # Extra keyword arguments are not used below; they only trigger
            # this deprecation warning.
            show_deprecation_warning(
                "Variable keyword arguments for `st.plotly_chart` have been "
                "deprecated and will be removed in a future release. Use the "
                "`config` argument instead to specify Plotly configuration "
                "options."
            )

        if theme not in ["streamlit", None]:
            raise StreamlitAPIException(
                f'You set theme="{theme}" while Streamlit charts only support '
                "theme=”streamlit” or theme=None to fallback to the default "
                "library theme."
            )

        if on_select not in ["ignore", "rerun"] and not callable(on_select):
            raise StreamlitAPIException(
                f"You have passed {on_select} to `on_select`. But only 'ignore', "
                "'rerun', or a callable is supported."
            )

        key = to_key(key)
        is_selection_activated = on_select != "ignore"

        if is_selection_activated:
            # Run some checks that are only relevant when selections are activated

            is_callback = callable(on_select)
            check_widget_policies(
                self.dg,
                key,
                on_change=cast("WidgetCallback", on_select) if is_callback else None,
                default_value=None,
                writes_allowed=False,
                enable_check_callback_rules=is_callback,
            )

        if type_util.is_type(figure_or_data, "matplotlib.figure.Figure"):
            # Convert matplotlib figure to plotly figure:
            figure = plotly.tools.mpl_to_plotly(figure_or_data)
        else:
            figure = plotly.tools.return_figure_from_figure_or_data(
                figure_or_data, validate_figure=True
            )

        plotly_chart_proto = PlotlyChartProto()
        plotly_chart_proto.theme = theme or ""
        plotly_chart_proto.form_id = current_form_id(self.dg)

        config = config or {}
        plotly_chart_proto.spec = plotly.io.to_json(figure, validate=False)
        plotly_chart_proto.config = json.dumps(config)

        ctx = get_script_run_ctx()

        # We are computing the widget id for all plotly uses
        # to also allow non-widget Plotly charts to keep their state
        # when the frontend component gets unmounted and remounted.
        plotly_chart_proto.id = compute_and_register_element_id(
            "plotly_chart",
            user_key=key,
            key_as_main_identity=False,
            dg=self.dg,
            plotly_spec=plotly_chart_proto.spec,
            plotly_config=plotly_chart_proto.config,
            selection_mode=selection_mode,
            is_selection_activated=is_selection_activated,
            theme=theme,
            width=width,
            height=height,
        )

        # Handle "content" width and height by inspecting the figure's natural dimensions
        final_width = _resolve_content_width(width, figure)
        final_height = _resolve_content_height(height, figure)
        # The layout is the same whether or not selections are activated.
        layout_config = LayoutConfig(width=final_width, height=final_height)

        if is_selection_activated:
            # Selections are activated, treat plotly chart as a widget:
            plotly_chart_proto.selection_mode.extend(
                parse_selection_mode(selection_mode)
            )

            serde = PlotlyChartSelectionSerde()

            widget_state = register_widget(
                plotly_chart_proto.id,
                on_change_handler=on_select if callable(on_select) else None,
                deserializer=serde.deserialize,
                serializer=serde.serialize,
                ctx=ctx,
                value_type="string_value",
            )

            self.dg._enqueue(
                "plotly_chart", plotly_chart_proto, layout_config=layout_config
            )
            # With selections activated the caller gets the selection state
            # instead of the element placeholder.
            return widget_state.value

        return self.dg._enqueue(
            "plotly_chart", plotly_chart_proto, layout_config=layout_config
        )

    @property
    def dg(self) -> DeltaGenerator:
        """Get our DeltaGenerator."""
        return cast("DeltaGenerator", self)
