#!/usr/bin/env python3
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import base64
import itertools
import sys
import uuid
from collections import namedtuple, deque
from typing import Optional, List, Iterable, Tuple

import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

from snowflake.snowpark.version import VERSION

# AST version this client was generated against (produced by the DSL); it is
# sent as `client_ast_version` on every request built by AstBuilder.to_request.
CLIENT_AST_VERSION = proto.__Version__.MAX_VERSION


# Return type of AstBuilder.to_batch: `request_id` is the request UUID rendered
# as a string and `batch` is the base64-encoded serialized Request message.
SerializedBatch = namedtuple("SerializedBatch", "request_id batch")


class AstBuilder:
    """
    A handler for AST statements that enables a Snowpark object (e.g. Dataframe, Table, etc.) to carry its own AST,
    and to generate AST requests.

    The core statement types are:
    - Bind: Creates a new variable and assigns a value to it.
    - Eval: Evaluates a variable.

    Each builder owns at most one Bind statement plus a list of other builders it depends on. ``to_request`` and
    ``to_batch`` emit the Bind statements of the whole dependency closure, with dependencies before dependents.
    """

    # NOTE: This class is not thread-safe.

    # Function used to generate request IDs. This is overridden in some tests.
    __generate_request_id = uuid.uuid4

    # Class variable for generating IDs for bind statements.
    # NOTE: itertools.count and its __next__ method are thread-safe and atomic in CPython (with GIL).
    __id_gen = itertools.count(start=1)

    def __init__(
        self,
    ) -> None:
        """
        Initializes a new AST builder with no bind statement and no dependencies.

        Use ``AstBuilder.bind`` to obtain an instance that carries a Bind statement.
        """
        # Set by ``bind``; stays None for instances constructed directly.
        self._bind: Optional[proto.Bind] = None
        # Builders whose Bind statements must precede this builder's in a request.
        self.__dependencies: List[AstBuilder] = []
        # Reverse edges of ``__dependencies``; maintained by ``depends_on``.
        self.__required_by: List[AstBuilder] = []

    @classmethod
    def bind(cls, symbol: Optional[str] = None) -> "AstBuilder":
        """
        Factory method to create an AstBuilder instance with a new bind statement.

        Args:
            symbol: An optional symbol to name the new variable in the bind statement. Any non-string value
                (including None) results in an empty symbol.
        Returns:
            An instance of AstBuilder whose bind statement has a fresh uid drawn from a class-level counter.
        """
        b = cls()
        b._bind = proto.Bind()
        b._bind.uid = next(AstBuilder.__id_gen)
        b._bind.symbol.value = symbol if isinstance(symbol, str) else ""
        return b

    def depends_on(self, dependency: "AstBuilder") -> None:
        """
        Adds a dependency to the current AstBuilder instance.
        This allows for tracking dependencies between different AST statements.

        Registering the same dependency more than once is a no-op. The reverse edge is recorded on ``dependency``.

        Args:
            dependency: An AstBuilder instance that is a dependency of the current instance.
        """
        if dependency not in self.__dependencies:
            self.__dependencies.append(dependency)
        if self not in dependency.__required_by:
            dependency.__required_by.append(self)

    def _closure(self) -> Iterable["AstBuilder"]:
        """
        Computes the transitive closure of this builder's dependencies in post-order.

        Every reachable builder appears exactly once. Each builder appears after all of the builders it depends on,
        except where dependencies form a cycle (a cycle is broken at the first revisit). ``self`` is always last.
        Sibling dependencies are visited in the order they were registered via ``depends_on``.

        Returns:
            A deque of AstBuilder instances, dependencies first.
        """
        # Iterative depth-first search. Each stack frame holds a builder and an iterator over its not-yet-examined
        # dependencies, so deep dependency chains do not hit Python's recursion limit.
        visited = {self}
        closure: deque[AstBuilder] = deque()
        stack = [(self, iter(self.__dependencies))]

        while stack:
            node, pending = stack[-1]
            for dep in pending:
                if dep not in visited:
                    # Mark on discovery so a builder reachable through several paths is expanded (and emitted) once.
                    visited.add(dep)
                    stack.append((dep, iter(dep.__dependencies)))
                    break
            else:
                # Every dependency of `node` has already been emitted; emit `node` after them.
                stack.pop()
                closure.append(node)

        return closure

    def to_request(
        self, eval: bool = False, assign_first_request_ids: bool = True
    ) -> "Tuple[uuid.UUID, proto.Request]":
        """
        Builds and returns a new AST request and ID with transitive closure of all dependencies.

        Args:
            eval: If True, append an Eval statement for this builder's bind uid after all Bind statements.
            assign_first_request_ids: If True, write the new request ID into ``first_request_id`` of every Bind in
                the closure that does not have one yet. This mutates the builders' Bind messages.
        Returns:
            A tuple of the freshly generated request UUID and the populated ``proto.Request``.
        """
        request = proto.Request()

        # Generate a new unique ID and convert to bytes for the request message field.
        current_request_id: uuid.UUID = AstBuilder.__generate_request_id()
        request.id = current_request_id.bytes

        # Initialize the request with the current client version, Python version, and AST version.
        (major, minor, patch) = VERSION
        request.client_version.major = major
        request.client_version.minor = minor
        request.client_version.patch = patch

        (major, minor, micro, releaselevel, serial) = sys.version_info
        request.client_language.python_language.version.major = major
        request.client_language.python_language.version.minor = minor
        request.client_language.python_language.version.patch = micro
        request.client_language.python_language.version.label = releaselevel

        request.client_ast_version = CLIENT_AST_VERSION

        # Get the transitive closure of the current AST builder and its dependencies (dependencies first, each
        # builder once) and add their Bind statements to the request body in that order.
        closure = self._closure()
        for ast_handle in closure:
            if assign_first_request_ids and not ast_handle._bind.first_request_id:
                # Assign the first request IDs for bind statements if not already set.
                ast_handle._bind.first_request_id = current_request_id.bytes

            # NOTE(review): this appends the Bind message itself, while the eval branch below adds a body element
            # and sets its `eval` field; confirm the element type of `request.body` accepts a Bind directly.
            request.body.append(ast_handle._bind)

        if eval:
            # Generate the eval statement for the current AstBuilder instance's Bind statement.
            stmt = request.body.add()
            stmt.eval.bind_id = self._bind.uid

        from snowflake.snowpark._internal.ast.utils import fill_interned_value_table

        fill_interned_value_table(request.interned_value_table)
        return current_request_id, request

    def to_batch(
        self, eval: bool = False, assign_first_request_ids: bool = True
    ) -> "SerializedBatch":
        """
        Builds and returns a serialized Request UUID and AST Request message with the transitive closure of all dependencies.

        Args:
            eval: Forwarded to ``to_request``; if True, an Eval statement for this builder is appended.
            assign_first_request_ids: Forwarded to ``to_request``.
        Returns:
            A SerializedBatch whose ``request_id`` is the string form of the request UUID and whose ``batch`` is the
            base64 encoding of the serialized Request message, decoded as a UTF-8 string.
        """
        request_id, request = self.to_request(eval, assign_first_request_ids)
        # Serialize the request to a base64-encoded string for transmission.
        batch = str(base64.b64encode(request.SerializeToString()), "utf-8")
        return SerializedBatch(str(request_id), batch)
