from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast

import polars as pl

from narwhals._polars.utils import (
    BACKEND_VERSION,
    PolarsAnyNamespace,
    PolarsCatNamespace,
    PolarsDateTimeNamespace,
    PolarsListNamespace,
    PolarsStringNamespace,
    PolarsStructNamespace,
    extract_args_kwargs,
    extract_native,
    narwhals_to_native_dtype,
)
from narwhals._utils import Implementation, no_default, requires

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from narwhals._compliant.typing import Accessor
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._polars.dataframe import Method
    from narwhals._polars.namespace import PolarsNamespace
    from narwhals._polars.series import PolarsSeries
    from narwhals._typing import NoDefault
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, ModeKeepStrategy


class PolarsExpr:
    # CompliantExpr
    _implementation: Implementation = Implementation.POLARS
    _version: Version
    _native_expr: pl.Expr
    _evaluate_output_names: Any
    _alias_output_names: Any
    __call__: Any

    @classmethod
    def _from_series(cls, series: PolarsSeries) -> Self:
        return cls(series.native, version=series._version)  # type: ignore[arg-type]

    # CompliantExpr + builtin descriptor
    # TODO @dangotbanned: Remove in #2713
    @classmethod
    def from_column_names(cls, *_: Any, **__: Any) -> Self:
        raise NotImplementedError

    @classmethod
    def from_column_indices(cls, *_: Any, **__: Any) -> Self:
        raise NotImplementedError

    def __narwhals_expr__(self) -> Self:  # pragma: no cover
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:  # pragma: no cover
        from narwhals._polars.namespace import PolarsNamespace

        return PolarsNamespace(version=self._version)

    def __init__(self, expr: pl.Expr, version: Version) -> None:
        self._native_expr = expr
        self._version = version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @property
    def native(self) -> pl.Expr:
        return self._native_expr

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsExpr"

    def _with_native(self, expr: pl.Expr) -> Self:
        return self.__class__(expr, self._version)

    def broadcast(self) -> Self:
        # Let Polars do its thing.
        return self

    @property
    def _metadata(self) -> ExprMetadata:
        assert self._opt_metadata is not None  # noqa: S101
        return cast("ExprMetadata", self._opt_metadata)

    def __getattr__(self, attr: str) -> Any:
        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._with_native(getattr(self.native, attr)(*pos, **kwds))

        return func

    def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
        name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
        return {name: min_samples}

    def cast(self, dtype: IntoDType) -> Self:
        dtype_pl = narwhals_to_native_dtype(dtype, self._version)
        return self._with_native(self.native.cast(dtype_pl))

    def clip_lower(self, lower_bound: PolarsExpr) -> Self:
        lower_native = extract_native(lower_bound)
        return self._with_native(self.native.clip(lower_native))

    def clip_upper(self, upper_bound: PolarsExpr) -> Self:
        upper_native = extract_native(upper_bound)
        return self._with_native(self.native.clip(None, upper_native))

    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self:
        native = self.native.ewm_mean(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            ignore_nulls=ignore_nulls,
            **self._renamed_min_periods(min_samples),
        )
        if self._backend_version < (1,):  # pragma: no cover
            native = pl.when(~self.native.is_null()).then(native).otherwise(None)
        return self._with_native(native)

    def is_nan(self) -> Self:
        if self._backend_version >= (1, 18):
            native = self.native.is_nan()
        else:  # pragma: no cover
            native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
        return self._with_native(native)

    def is_finite(self) -> Self:
        if self._backend_version >= (1, 18):
            native = self.native.is_finite()
        else:  # pragma: no cover
            native = pl.when(self.native.is_not_null()).then(self.native.is_finite())
        return self._with_native(native)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        # Use `pl.repeat(1, pl.len())` instead of `pl.lit(1)` to avoid issues for
        # non-numeric types: https://github.com/pola-rs/polars/issues/24756.
        pl_partition_by = partition_by or pl.repeat(1, pl.len())
        if self._backend_version < (1, 9):
            if order_by:
                msg = "`order_by` in Polars requires version 1.10 or greater"
                raise NotImplementedError(msg)
            native = self.native.over(pl_partition_by)
        else:
            native = self.native.over(pl_partition_by, order_by=order_by or None)
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_var(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_std(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def map_batches(
        self,
        function: Callable[[Any], Any],
        return_dtype: IntoDType | None,
        *,
        returns_scalar: bool,
    ) -> Self:
        pl_version = self._backend_version
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype is not None
            else None
            if pl_version < (1, 32)
            else pl.self_dtype()
        )
        kwargs = {} if pl_version < (0, 20, 31) else {"returns_scalar": returns_scalar}
        native = self.native.map_batches(function, return_dtype_pl, **kwargs)
        return self._with_native(native)

    @requires.backend_version((1,))
    def replace_strict(
        self,
        default: PolarsExpr | NoDefault,
        old: Sequence[Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype
            else None
        )
        extra_kwargs = (
            {} if default is no_default else {"default": extract_native(default)}
        )
        native = self.native.replace_strict(
            old, new, return_dtype=return_dtype_pl, **extra_kwargs
        )
        return self._with_native(native)

    def __eq__(self, other: PolarsExpr) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__eq__(extract_native(other)))

    def __ne__(self, other: PolarsExpr) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__ne__(extract_native(other)))

    def __ge__(self, other: Any) -> Self:
        return self._with_native(self.native.__ge__(extract_native(other)))

    def __gt__(self, other: Any) -> Self:
        return self._with_native(self.native.__gt__(extract_native(other)))

    def __le__(self, other: Any) -> Self:
        return self._with_native(self.native.__le__(extract_native(other)))

    def __lt__(self, other: Any) -> Self:
        return self._with_native(self.native.__lt__(extract_native(other)))

    def __and__(self, other: PolarsExpr) -> Self:
        return self._with_native(self.native.__and__(extract_native(other)))

    def __or__(self, other: PolarsExpr) -> Self:
        return self._with_native(self.native.__or__(extract_native(other)))

    def __add__(self, other: Any) -> Self:
        return self._with_native(self.native.__add__(extract_native(other)))

    def __sub__(self, other: Any) -> Self:
        return self._with_native(self.native.__sub__(extract_native(other)))

    def __mul__(self, other: Any) -> Self:
        return self._with_native(self.native.__mul__(extract_native(other)))

    def __pow__(self, other: Any) -> Self:
        return self._with_native(self.native.__pow__(extract_native(other)))

    def __truediv__(self, other: Any) -> Self:
        return self._with_native(self.native.__truediv__(extract_native(other)))

    def __floordiv__(self, other: Any) -> Self:
        return self._with_native(self.native.__floordiv__(extract_native(other)))

    def __rfloordiv__(self, other: Any) -> Self:
        native = self.native
        result = native.__rfloordiv__(extract_native(other))
        if self._backend_version < (1, 10, 0):
            # Polars 1.9.0 and earlier returns 0 for division by 0 in rfloordiv.
            result = pl.when(native != 0).then(result).otherwise(None)
        return self._with_native(result)

    def __mod__(self, other: Any) -> Self:
        return self._with_native(self.native.__mod__(extract_native(other)))

    def __invert__(self) -> Self:
        return self._with_native(self.native.__invert__())

    def cum_count(self, *, reverse: bool) -> Self:
        return self._with_native(self.native.cum_count(reverse=reverse))

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        result = self.native.mode()
        return self._with_native(result.first() if keep == "any" else result)

    def any_value(self, *, ignore_nulls: bool) -> Self:
        return self.drop_nulls().first() if ignore_nulls else self.first()

    @property
    def dt(self) -> PolarsExprDateTimeNamespace:
        return PolarsExprDateTimeNamespace(self)

    @property
    def str(self) -> PolarsExprStringNamespace:
        return PolarsExprStringNamespace(self)

    @property
    def cat(self) -> PolarsExprCatNamespace:
        return PolarsExprCatNamespace(self)

    @property
    def name(self) -> PolarsExprNameNamespace:
        return PolarsExprNameNamespace(self)

    @property
    def list(self) -> PolarsExprListNamespace:
        return PolarsExprListNamespace(self)

    @property
    def struct(self) -> PolarsExprStructNamespace:
        return PolarsExprStructNamespace(self)

    # Polars
    abs: Method[Self]
    all: Method[Self]
    any: Method[Self]
    alias: Method[Self]
    arg_max: Method[Self]
    arg_min: Method[Self]
    arg_true: Method[Self]
    ceil: Method[Self]
    count: Method[Self]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    fill_nan: Method[Self]
    first: Method[Self]
    floor: Method[Self]
    last: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_between: Method[Self]
    is_duplicated: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_unique: Method[Self]
    kurtosis: Method[Self]
    len: Method[Self]
    log: Method[Self]
    max: Method[Self]
    mean: Method[Self]
    median: Method[Self]
    min: Method[Self]
    n_unique: Method[Self]
    null_count: Method[Self]
    quantile: Method[Self]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    sin: Method[Self]
    skew: Method[Self]
    sqrt: Method[Self]
    std: Method[Self]
    sum: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    var: Method[Self]
    __rsub__: Method[Self]
    __rmod__: Method[Self]
    __rpow__: Method[Self]
    __rtruediv__: Method[Self]


class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
    def __init__(self, expr: PolarsExpr) -> None:
        self._expr = expr

    @property
    def compliant(self) -> PolarsExpr:
        return self._expr

    @property
    def native(self) -> pl.Expr:
        return self._expr.native


class PolarsExprDateTimeNamespace(
    PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
): ...


class PolarsExprStringNamespace(
    PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
):
    def to_titlecase(self) -> PolarsExpr:
        native_expr = self.native

        if BACKEND_VERSION < (1, 35):
            native_result = (
                native_expr.str.to_lowercase()
                .str.extract_all(r"[a-z]*[^a-z]*")
                .list.eval(pl.element().str.to_titlecase())
                .list.join("")
            )
        else:  # pragma: no cover
            native_result = native_expr.str.to_titlecase()

        return self.compliant._with_native(native_result)

    @requires.backend_version((0, 20, 5))
    def zfill(self, width: int) -> PolarsExpr:
        backend_version = self.compliant._backend_version
        native_result = self.native.str.zfill(width)

        if backend_version <= (1, 30, 0):
            length = self.native.str.len_chars()
            less_than_width = length < width
            plus = "+"
            starts_with_plus = self.native.str.starts_with(plus)
            native_result = (
                pl.when(starts_with_plus & less_than_width)
                .then(
                    self.native.str.slice(1, length)
                    .str.zfill(width - 1)
                    .str.pad_start(width, plus)
                )
                .otherwise(native_result)
            )

        return self.compliant._with_native(native_result)

    def replace(
        self, value: PolarsExpr, pattern: str, *, literal: bool, n: int
    ) -> PolarsExpr:
        value_native = extract_native(value)
        return self.compliant._with_native(
            self.native.str.replace(pattern, value_native, literal=literal, n=n)
        )

    def replace_all(
        self, value: PolarsExpr, pattern: str, *, literal: bool
    ) -> PolarsExpr:
        value_native = extract_native(value)
        return self.compliant._with_native(
            self.native.str.replace_all(pattern, value_native, literal=literal)
        )


class PolarsExprCatNamespace(
    PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
): ...


class PolarsExprNameNamespace(PolarsExprNamespace):
    _accessor: ClassVar[Accessor] = "name"
    keep: Method[PolarsExpr]
    map: Method[PolarsExpr]
    prefix: Method[PolarsExpr]
    suffix: Method[PolarsExpr]
    to_lowercase: Method[PolarsExpr]
    to_uppercase: Method[PolarsExpr]


class PolarsExprListNamespace(
    PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
):
    def len(self) -> PolarsExpr:
        native_expr = self.native
        native_result = native_expr.list.len()

        if self.compliant._backend_version < (1, 16):  # pragma: no cover
            native_result = (
                pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32())
            )
        elif self.compliant._backend_version < (1, 17):  # pragma: no cover
            native_result = native_result.cast(pl.UInt32())

        return self.compliant._with_native(native_result)

    def contains(self, item: Any) -> PolarsExpr:
        if self.compliant._backend_version < (1, 28):
            result: pl.Expr = pl.when(self.native.is_not_null()).then(
                self.native.list.contains(item)
            )
        else:
            result = self.native.list.contains(item)
        return self.compliant._with_native(result)


class PolarsExprStructNamespace(
    PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
): ...
