#!/usr/bin/env python3
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import sys
import base64
import itertools
import threading
import uuid
import heapq
from collections import namedtuple
from dataclasses import dataclass
from typing import Callable, Optional, Any, Iterable

from google.protobuf.message import Message
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

from snowflake.snowpark.version import VERSION

# TODO(SNOW-1791994): Enable pyright type checks for this file.


# The current AST version number (generated by the DSL), read from the generated protobuf module's
# __Version__.MAX_VERSION. AstBatch.to_request stamps it on every request as client_ast_version.
CLIENT_AST_VERSION = proto.__Version__.MAX_VERSION

# Descriptor names of the Snowpark AST protobuf message types that point at another AST entity through
# their integer "id" field. get_dependent_bind_ids records the id of any message whose descriptor name is
# listed here instead of descending into it.
REF_DESCRIPTOR_NAMES = {
    "DataframeRef",
    "FnIdRefExpr",
    "FnRef",
    "IndirectTableFnIdRef",
    "RelationalGroupedDataframeRef",
}


def get_dependent_bind_ids(ast: Any) -> set[int]:
    """
    Collect the Bind IDs that an AST object depends on.

    Args:
        ast: A protobuf AST message, or a (possibly nested) non-string iterable of values such as a
            repeated field.

    Returns:
        The ``id`` of every reference message found while walking ``ast``. A reference message is one whose
        descriptor name is in REF_DESCRIPTOR_NAMES, and it is not searched any further. Any other message is
        searched field by field. Values that are neither non-string iterables nor messages contribute nothing.
    """

    def is_container(value: Any) -> bool:
        # str/bytes/bytearray are iterable, but they never hold AST nodes.
        return isinstance(value, Iterable) and not isinstance(
            value, (str, bytes, bytearray)
        )

    result: set[int] = set()

    # Collections (e.g. repeated fields): union the dependencies of every element.
    if is_container(ast):
        for element in ast:
            result |= get_dependent_bind_ids(element)
        return result

    # Neither a collection nor a protobuf message: nothing to collect.
    if not hasattr(ast, "DESCRIPTOR"):
        return result

    descriptor = ast.DESCRIPTOR
    if descriptor.name in REF_DESCRIPTOR_NAMES:
        result.add(ast.id)  # type: ignore[union-attr]
        return result

    for field in descriptor.fields:
        value = getattr(ast, field.name)
        # Only descend into fields that hold something: a non-empty iterable or a non-empty sub-message.
        # Scalars and empty values are skipped.
        has_content = (
            is_container(value) and len(value) > 0  # type: ignore[arg-type]
        ) or (isinstance(value, Message) and value.ByteSize() > 0)
        if has_content:
            result |= get_dependent_bind_ids(value)
    return result


@dataclass
class TrackedCallable:
    """
    Pairs a user-supplied callable with the integer ID under which the AST refers to it.

    Snowpark's stored-procedure and user-defined-function APIs accept callables as arguments. Each distinct
    callable object gets its own ID. The same callable must not end up with several IDs, because other parts of
    the system cannot handle that.
    """

    # ID under which the AST references ``func``.
    bind_id: int
    # The tracked callable object itself.
    func: Callable


# Value returned by AstBatch.flush: the tied-off request's ID (as a string) and the base64-encoded,
# serialized request.
SerializedBatch = namedtuple("SerializedBatch", "request_id batch")


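# AstBatch guards its own state with an internal re-entrant lock, so individual method calls are serialized with
# each other. Sequences of calls that must be atomic as a group (e.g. eval followed by flush) still need external
# synchronization.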
class AstBatch:
    """
    A batch of AST statements. This class is used to generate AST requests.

    The core statement types are:
    - Bind: Creates a new variable and assigns a value to it.
    - Eval: Evaluates a variable.

    Bind statements are cached across requests. Each request carries the Bind statements created or evaluated
    since the last flush, together with every cached Bind statement they transitively depend on.
    """

    # Function used to generate request IDs. This is overridden in some tests.
    generate_request_id = uuid.uuid4

    # Class variable for generating globally unique IDs (within a session) for Bind statements.
    # It is shared by every AstBatch instance, so UIDs increase in creation order until _reset_id_gen is called.
    # NOTE: itertools.count and its __next__ method are thread-safe and atomic in CPython (with GIL).
    __id_gen = itertools.count(start=1)

    def __init__(
        self,
        session: "snowflake.snowpark.Session",  # type: ignore[name-defined]  # noqa: F821
    ) -> None:
        """
        Initializes a new AST batch.

        Args:
            session: The Snowpark session.
        """
        self._session = session
        # Re-entrant so that a method holding the lock (e.g. flush) can call other locking methods
        # (eval, to_request, cur_stmts_closure).
        self._lock = threading.RLock()

        self._init_batch()

        # Track callables in this dict (memory id -> TrackedCallable).
        self._callables: dict[int, TrackedCallable] = {}

        # Track all generated Bind statements by their UIDs. _init_batch does not clear this, so later requests
        # can resend Bind statements created by earlier ones.
        self._bind_stmt_cache: dict[int, proto.Stmt] = {}

        # Cache the dependencies (referenced Bind UIDs) of each Bind statement.
        self._dependency_cache: dict[int, set[int]] = {}

    def _init_batch(self) -> None:
        """
        Reset the AST batch by initializing a new request ID and clearing the current statements.
        """
        with self._lock:
            # Generate a new unique ID.
            self._request_id = AstBatch.generate_request_id()
            # Maintain a priority queue of Bind IDs generated or referenced for the current request.
            self._cur_request_bind_id_q: list[int] = []
            # Maintain a set of Bind IDs generated or referenced in the current request.
            self._cur_request_bind_ids: set[int] = set()
            # Maintain a set of Bind IDs referenced by Eval statements in the current request.
            self._eval_ids: set[int] = set()

    def _reset_id_gen(self) -> None:
        """
        THIS METHOD IS FOR TESTING PURPOSES ONLY. DO NOT USE IN PRODUCTION CODE.

        Restarts the class-level Bind UID counter at 1. The counter is shared by all AstBatch instances.
        """
        with self._lock:
            AstBatch.__id_gen = itertools.count(start=1)

    def bind(self, symbol: Optional[str] = None) -> proto.Bind:
        """
        Creates a new Bind statement and registers it with the current request.

        Args:
            symbol: An optional symbol to name the new variable. Non-string values are stored as "".

        Returns:
            The new Bind message.
        """
        stmt = proto.Stmt()
        stmt.bind.symbol.value = symbol if isinstance(symbol, str) else ""
        with self._lock:
            stmt.bind.uid = next(AstBatch.__id_gen)
            stmt.bind.first_request_id = self._request_id.bytes
            heapq.heappush(self._cur_request_bind_id_q, stmt.bind.uid)
            self._cur_request_bind_ids.add(stmt.bind.uid)
            # The returned message is part of this cached Stmt. Fields the caller fills in afterwards are
            # therefore seen when dependencies are computed and when the request is built.
            self._bind_stmt_cache[stmt.bind.uid] = stmt
        return stmt.bind

    def eval(self, target: proto.Bind) -> None:
        """
        Adds the ID of the target Bind statement to the current request and marks it for evaluation.

        Args:
            target: The variable to evaluate.
        """
        with self._lock:
            self._eval_ids.add(target.uid)
            # The target may have been bound in an earlier request. Queue it so the current request resends its
            # cached Bind statement (and, via cur_stmts_closure, its dependencies).
            if target.uid not in self._cur_request_bind_ids:
                heapq.heappush(self._cur_request_bind_id_q, target.uid)
                self._cur_request_bind_ids.add(target.uid)

    def cur_stmts_closure(self) -> list[int]:
        """
        Computes the transitive closure of the current request.

        Returns:
            A min heap (built with heapq) of every Bind ID in the current request plus all Bind IDs they depend
            on, directly or transitively. Each ID appears exactly once. Retrieve them in ascending order
            (generation order) with heapq.heappop.
        """
        with self._lock:
            # All Bind IDs (including dependencies) that must be sent, kept as a min heap so Bind statements
            # are unparsed in the order they were generated.
            full_request_bind_ids: list[int] = []

            # Bind IDs already added to full_request_bind_ids.
            visited_bind_ids: set[int] = set()

            # Work queue, copied from self._cur_request_bind_id_q, which is already a valid min heap.
            queue_bind_ids = list(self._cur_request_bind_id_q)
            while queue_bind_ids:
                # Process the earliest generated Bind statement first.
                bind_id = heapq.heappop(queue_bind_ids)
                # An ID can be queued more than once when a dependency has a larger UID than some of its
                # dependents (UIDs not in creation order, e.g. after _reset_id_gen). Skip repeats so no Bind
                # statement is emitted twice.
                if bind_id in visited_bind_ids:
                    continue
                visited_bind_ids.add(bind_id)
                heapq.heappush(full_request_bind_ids, bind_id)

                # A previously seen Bind ID has its dependencies cached. Otherwise compute them from the
                # cached Bind statement and remember them.
                dependent_bind_ids = self._dependency_cache.get(bind_id)
                if dependent_bind_ids is None:
                    dependent_bind_ids = get_dependent_bind_ids(
                        self._bind_stmt_cache[bind_id]
                    )
                    self._dependency_cache[bind_id] = dependent_bind_ids

                # Queue dependencies we have not emitted yet, one by one to keep the heap invariant.
                for dep_id in dependent_bind_ids - visited_bind_ids:
                    heapq.heappush(queue_bind_ids, dep_id)

            return full_request_bind_ids

    def to_request(
        self,
    ) -> proto.Request:
        """Create fully contained AST request with all dependent AST objects."""

        # Create new request to send the batch of statements and their dependencies.
        request = proto.Request()

        # Set the client version and language.
        (major, minor, patch) = VERSION
        request.client_version.major = major
        request.client_version.minor = minor
        request.client_version.patch = patch

        # Set the Python version.
        (major, minor, micro, releaselevel, _serial) = sys.version_info
        request.client_language.python_language.version.major = major
        request.client_language.python_language.version.minor = minor
        request.client_language.python_language.version.patch = micro
        request.client_language.python_language.version.label = releaselevel

        # Set the AST version.
        request.client_ast_version = CLIENT_AST_VERSION

        with self._lock:
            # Convert UUID to bytes for the request.
            request.id = self._request_id.bytes
            # Compute the transitive closure of all Bind IDs in the current request.
            full_request_bind_ids = self.cur_stmts_closure()

            # Add all the Bind and Eval statements to the request body, in ascending Bind ID order.
            while full_request_bind_ids:
                bind_id = heapq.heappop(full_request_bind_ids)
                # Add the Bind statement to the request.
                request.body.append(self._bind_stmt_cache[bind_id])
                # An Eval statement immediately follows the Bind statement it evaluates.
                if bind_id in self._eval_ids:
                    stmt = request.body.add()
                    stmt.eval.bind_id = bind_id

        return request

    def flush(
        self,
        target: Optional[proto.Bind] = None,
    ) -> SerializedBatch:
        """
        Ties off a batch and starts a new one. Returns the tied-off batch.

        Args:
            target: Optional Bind statement that must be evaluated in the tied-off batch.

        Returns:
            A SerializedBatch holding the tied-off request's ID (as a string) and the base64-encoded serialized
            request.
        """
        with self._lock:
            # Ensure the target is evaluated in this batch. This handles the race caused by the lock being
            # released between a caller's eval() and flush().
            if target is not None and target.uid not in self._eval_ids:
                self.eval(target)
            # Get the current request ID and batch before resetting the batch.
            req_id = str(self._request_id)
            request = self.to_request()
            # Reset the current AstBatch instance for the next request.
            self._init_batch()

            # Only filenames are interned, flush the lookup table as part of the request.
            from snowflake.snowpark._internal.ast.utils import fill_interned_value_table

            fill_interned_value_table(request.interned_value_table)

            batch = str(base64.b64encode(request.SerializeToString()), "utf-8")
            return SerializedBatch(req_id, batch)

    # TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
    def register_callable(self, func: Callable) -> int:  # pragma: no cover
        """
        Tracks client-side an actual callable and returns an ID.

        IDs are assigned sequentially from 0 in registration order, per AstBatch instance. Registering the same
        object again returns its existing ID. The TrackedCallable keeps ``func`` alive, so its id() cannot be
        reused by another object while it is registered.
        """
        with self._lock:
            key = id(func)

            tracked = self._callables.get(key)
            if tracked is not None:
                return tracked.bind_id

            next_id = len(self._callables)
            self._callables[key] = TrackedCallable(bind_id=next_id, func=func)
            return next_id

    def clear(self) -> None:
        """
        Clears the current instance of AstBatch of the following:
        - The current request's:
            - Request ID.
            - Bind statement queue and set of Bind IDs.
            - Eval statement Bind ID set.
        - The cache of callable objects.
        - The cache of Bind statements.
        - The cache of dependencies between Bind statements.

        The class-level Bind UID counter is not reset.
        """
        with self._lock:
            self._init_batch()
            self._callables.clear()
            self._bind_stmt_cache.clear()
            self._dependency_cache.clear()
