import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.value_ranges import ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    do_bench,
    free_symbol_startswith,
    IndentedBuffer,
    sympy_dot,
    sympy_index_symbol,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V

if TYPE_CHECKING:
    from ..ir import TensorBox

schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    if schedule_log.isEnabledFor(logging.DEBUG):
        schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    # Size of the workspace in bytes (may be symbolic).
    nbytes: sympy.Expr
    # Whether zero-filling was requested; KernelArgs.workspace ORs this flag
    # across all requests that share the workspace.
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    """A tensor (pointer) argument of a generated kernel."""

    # Kernel-side parameter name (e.g. an inner name such as "in_out_ptr0").
    name: str
    # Graph buffer name bound to this parameter at the call site.
    buffer: str
    dtype: torch.dtype
    # Symbolic offset into the buffer; defaults to 0.
    # NOTE(review): units (elements vs. bytes) are not established here — confirm.
    offset: sympy.Expr = sympy.Integer(0)


@dataclasses.dataclass
class SizeArg:
    """A symbolic size argument of a generated kernel."""

    name: str
    expr: sympy.Expr


@dataclasses.dataclass
class DeviceCodegen:
    """Codegen classes registered for one device via register_backend_for_device."""

    scheduling: type
    wrapper_codegen: type


# Every kind of argument a generated kernel may take.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

# Registry: device string -> its registered DeviceCodegen.
device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    """Per-device operation hooks.

    The base class implements none of them: every hook raises
    ``NotImplementedError``. Backends register a concrete subclass with
    ``register_device_op_overrides``.
    """

    def import_get_raw_stream_as(self, name):
        """Hook for importing the raw-stream getter under the alias ``name``."""
        raise NotImplementedError

    def set_device(self, device_idx):
        """Hook for selecting device ``device_idx``."""
        raise NotImplementedError

    def synchronize(self):
        """Hook for device synchronization."""
        raise NotImplementedError

    def device_guard(self, device_idx):
        """Hook for a device guard on ``device_idx``."""
        raise NotImplementedError


# Registry filled by register_device_op_overrides, keyed by device string.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# Any new backend that integrates with Inductor must customize both of these parts
# to generate its backend-specific code.
#
# Kernel code generation is driven by a backend-specific Scheduling class. Consequently, a new
# backend needs to provide a custom Scheduling for its kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str, device_scheduling: type, device_wrapper_codegen: type
):
    """Register the Scheduling and wrapper-codegen classes to use for ``device``.

    A later registration for the same device replaces the earlier one.
    """
    entry = DeviceCodegen(device_scheduling, device_wrapper_codegen)
    device_codegens[device] = entry


def get_scheduling_for_device(device: str):
    """Return the Scheduling class registered for ``device``, or None if unregistered."""
    if device not in device_codegens:
        return None
    return device_codegens[device].scheduling


def get_wrapper_codegen_for_device(device: str):
    """Return the wrapper-codegen class registered for ``device``, or None if unregistered."""
    if device not in device_codegens:
        return None
    return device_codegens[device].wrapper_codegen


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return a copy of ``index`` with one extra term appended.

    The extra term is the dot product of ``index_vars`` with the contiguous
    strides for ``sizes``. Adding it keeps the loop order from being reordered.
    """
    from ..ir import FlexibleLayout

    extended = list(index)
    # The added contiguous index prevents reordering.
    extended.append(sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes)))
    return extended


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Install ``device_op_overrides`` as the hooks for ``device``, replacing any prior ones."""
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    """Return the DeviceOpOverrides registered for ``device``.

    If the registry is still empty, ``.cuda.device_op_overrides`` is imported
    first (presumably so that it registers its overrides as an import side
    effect). An unregistered device gets a fresh base ``DeviceOpOverrides``.
    """
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        from .cuda import device_op_overrides  # noqa: F401

    if device not in device_op_overrides_dict:
        return DeviceOpOverrides()
    return device_op_overrides_dict[device]


@functools.lru_cache(None)
def boolean_ops():
    """Return the names of ops whose dtype is deduced as ``torch.bool``.

    The result is cached, so every call returns the same tuple object.
    """
    value_tests = ("is_inf", "is_nan")
    # NOTE(review): bitwise_xor on integer operands produces an integer, not a
    # bool; confirm this entry is only reached with boolean inputs.
    bit_and_logic = ("bitwise_xor", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return value_tests + bit_and_logic + comparisons


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


class DataTypePropagation:
    """Deduce a dtype for every node in the FX graphs of a LoopBody.

    Each result is stored in ``node.meta[OptimizationContext.key].dtype``. The
    value is ``None`` when no dtype could be deduced.
    """

    def __init__(self, body) -> None:
        self.body = body
        # "root" -> the root block's graph; each sub-block key -> that block's graph.
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Return the promoted dtype of ``node``'s non-placeholder inputs.

        Returns None if there are no such inputs, or if any of them has no
        dtype recorded yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the sub-graph named by ``node.target`` and return its output dtype."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype of ``node``, or None if it cannot be deduced.

        Rules are tried in order:
        - boolean-producing ops;
        - placeholders (None);
        - multi-arg outputs (None);
        - ops carrying an explicit dtype argument;
        - random ops;
        - ``get_index``;
        - loads/stores (the buffer's dtype);
        - ``getitem``;
        - reductions;
        - constants;
        - masked sub-blocks;
        - finally, promotion over the inputs.
        """
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # An output node can only be inferred when it has exactly one arg;
            # in that case fall through to the remaining rules.
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # The requested dtype is the last positional argument.
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        # "index_expr" is intentionally not repeated here: the explicit-dtype
        # branch above already returns for it.
        if node.target == "get_index":
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Assign a dtype to every node of ``graph``, in graph order.

        Returns the dtype deduced for the graph's ``output`` node; for a
        masked_subblock this value represents the sub-graph's dtype. Returns
        None if the graph has no output node or its dtype is unknown.
        """
        assert graph.nodes
        graph_dtype = None
        for node in graph.nodes:
            # Reuse an existing OptimizationContext so other fields on it survive.
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Propagate dtypes through the root graph; sub-graphs are reached via masked_subblock nodes."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Propagate dtypes through the LoopBody of a SchedulerNode."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        # Use ``cls`` so subclasses' overrides are honored.
        cls.propagate_loopbody(node._body)


class ExprPrinter(Printer):
    """Base sympy printer that renders index/size expressions as source text.

    Some operations (e.g. ``_print_FloorDiv``) are left unimplemented here and
    must be provided by subclasses.
    """

    @staticmethod
    def paren(string):
        """Return ``string`` wrapped in parentheses unless it is already safe as-is.

        The following are returned unchanged:
        - CSE variables;
        - strings made only of letters, digits, ``_`` and ``.``;
        - a single paren group with no inner ``)``;
        - the empty string;
        - any string fully enclosed by one matching pair of parentheses.
        """

        def _enclosed_by_one_pair(text):
            # True iff text opens with "(" whose matching ")" is the final char.
            if text[0] != "(" or len(text) < 2:
                return False
            depth = 1
            last_pos = len(text) - 1
            for pos in range(1, len(text)):
                ch = text[pos]
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and pos != last_pos:
                    return False
            assert depth == 0
            return True

        leave_as_is = (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        )
        # Strings already wrapped in a single pair of parens need no extra ones.
        if leave_as_is or _enclosed_by_one_pair(string):
            return string
        return f"({string})"

    def _join_printed_args(self, separator, expr):
        # Print every argument of expr, parenthesize as needed, then join.
        return separator.join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return self._join_printed_args(f" {expr.rel_op} ", expr)

    def _print_Mul(self, expr):
        return self._join_printed_args("*", expr)

    def _print_Add(self, expr):
        return self._join_printed_args(" + ", expr)

    def _print_Mod(self, expr):
        return self._join_printed_args(" % ", expr)

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        # CleanDiv is printed exactly like FloorDiv.
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # In sympy, GreaterThan means ">=" (StrictGreaterThan is ">").
        return self._join_printed_args(" >= ", expr)

    def _print_align(self, expr):
        assert len(expr.args) == 1
        (arg,) = expr.args
        return f"align({self._print(arg)})"


class PythonPrinter(ExprPrinter):
    """Render sympy expressions as Python source, using ``math`` for functions."""

    def _format_single_arg_call(self, func_name, expr):
        # Shared by the one-argument printers: renders "func_name(<arg>)".
        assert len(expr.args) == 1
        (arg,) = expr.args
        return f"{func_name}({self._print(arg)})"

    def _format_variadic_call(self, func_name, expr):
        # Shared by max/min: renders "func_name(a, b, ...)"; needs >= 2 args.
        assert len(expr.args) >= 2
        printed = ", ".join(self._print(arg) for arg in expr.args)
        return f"{func_name}({printed})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x, div, mod = (self.paren(self.doprint(part)) for part in (x, div, mod))
        # (x // div) % mod; the division is omitted when div prints as "1".
        quotient = x if div == "1" else f"({x} // {div})"
        return f"{quotient} % {mod}"

    def _print_FloorDiv(self, expr):
        numerator, denominator = expr.args
        numerator = self.paren(self.doprint(numerator))
        denominator = self.paren(self.doprint(denominator))
        return f"({numerator} // {denominator})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_Pow(self, expr):
        # Pow() confuses triton, so integer powers are expanded into products.
        base, exp = expr.args
        # NB: this is sizevar computation, where floating-point exponents are
        # not expected. Rather than adding general floating-point pow support
        # here, upstream should retranslate such sympy expressions into tensor
        # expressions earlier.
        if exp == 0.5:
            return self._helper_sqrt(base)
        if exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp == 0:
            return "1"
        product = "*".join([self.paren(base)] * abs(exp))
        if exp > 0:
            return product
        return "1/" + self.paren(product)

    def _print_floor(self, expr):
        return self._format_single_arg_call("math.floor", expr)

    def _print_ceiling(self, expr):
        return self._format_single_arg_call("math.ceil", expr)

    def _print_Abs(self, expr):
        return self._format_single_arg_call("abs", expr)

    def _print_Max(self, expr):
        return self._format_variadic_call("max", expr)

    def _print_Min(self, expr):
        return self._format_variadic_call("min", expr)

    def _print_cos(self, expr):
        return self._format_single_arg_call("math.cos", expr)

    def _print_cosh(self, expr):
        return self._format_single_arg_call("math.cosh", expr)

    def _print_acos(self, expr):
        return self._format_single_arg_call("math.acos", expr)

    def _print_sin(self, expr):
        return self._format_single_arg_call("math.sin", expr)

    def _print_sinh(self, expr):
        return self._format_single_arg_call("math.sinh", expr)

    def _print_asin(self, expr):
        return self._format_single_arg_call("math.asin", expr)

    def _print_tan(self, expr):
        return self._format_single_arg_call("math.tan", expr)

    def _print_tanh(self, expr):
        return self._format_single_arg_call("math.tanh", expr)

    def _print_atan(self, expr):
        return self._format_single_arg_call("math.atan", expr)

    def _print_Round(self, expr):
        return self._format_single_arg_call("round", expr)

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        # Only integer digit counts are supported.
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


class OpOverrides:
    """Default op implementations shared by codegen backends.

    Most methods return source strings built from their operands. Attributes
    not defined here are looked up on the wrapped ``parent`` handler.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Only reached when normal lookup fails: delegate to the parent handler.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # Returning the value unchanged is used to trigger CSE.
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        operand = ExprPrinter.paren(x)
        return f"~{operand}"

    @staticmethod
    def logical_not(a):
        operand = ExprPrinter.paren(a)
        return f"{operand} == 0"

    @staticmethod
    def bitwise_and(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} & {rhs}"

    @staticmethod
    def bitwise_or(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} | {rhs}"

    @staticmethod
    def bitwise_xor(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} ^ {rhs}"

    @staticmethod
    def bitwise_left_shift(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} << {rhs}"

    @staticmethod
    def bitwise_right_shift(x, y):
        lhs, rhs = ExprPrinter.paren(x), ExprPrinter.paren(y)
        return f"{lhs} >> {rhs}"

    @staticmethod
    def remainder(a, b):
        rem = ops.mod(a, b)
        # Add b to a nonzero result whose sign differs from b's, so the result
        # takes the sign of the divisor.
        needs_fix = f"(({rem} != 0) & (({rem} < 0) != ({b} < 0)))"
        return ops.where(needs_fix, ops.add(rem, b), rem)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install a static method for each pointwise_overrides_data entry.

        Only entries that have a string template for ``target`` are installed.
        A template that mentions ``{y}`` produces a two-argument op; any other
        template produces a one-argument op.
        """
        assert target in {"triton", "cpp", "cppvec"}, target

        def unary_factory(template):
            def func(x):
                return template.format(x=x)

            return func

        def binary_factory(template):
            def func(x, y):
                return template.format(x=x, y=y)

            return func

        # Keyed by argument count; extend with more arities as needed.
        factories = {1: unary_factory, 2: binary_factory}
        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if not isinstance(impl, str):
                continue
            arity = 2 if "{y}" in impl else 1
            setattr(cls, funcname, staticmethod(factories[arity](impl)))


@dataclasses.dataclass
class OverridesData:
    """Per-backend code templates for one pointwise op.

    Templates are ``str.format`` patterns over ``{x}`` (plus ``{y}`` for
    two-argument ops). OpOverrides._initialize_pointwise_overrides reads the
    field named after its ``target`` backend.
    """

    # Op name (presumably the corresponding torch/aten op name — confirm).
    name: str
    # C++ template; required for every entry.
    cpp: str
    triton: Optional[str] = None  # None when not impl in libdevice/triton
    cppvec: Optional[str] = None  # None when not impl in aten/.../vec
    # Elementwise type-promotion kind for this op (not consumed in this chunk).
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


# Table of pointwise ops with per-backend code templates. Each key becomes a
# static method on OpOverrides via _initialize_pointwise_overrides. In each
# template, "{x}"/"{y}" are replaced by the operand expressions; a template
# containing "{y}" yields a two-argument op. Bare comments listing op names
# (e.g. "# erf, erfc") mark ops that deliberately have no entry in this table.
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j0_forward({x})",
        triton="libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j1_forward({x})",
        triton="libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y0_forward({x})",
        triton="libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y1_forward({x})",
        triton="libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_digamma({x})",
        cppvec="{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_erfcx({x})",
        triton="libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0({x})",
        triton="libdevice.cyl_bessel_i0({x})",
        cppvec="{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0e({x})",
        cppvec="{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1({x})",
        triton="libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i0_forward({x})",
        triton="libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i1_forward({x})",
        triton="libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtri({x})",
        name="special_ndtri",
    ),
    # NOTE(review): the template passes {y} first and {x} second — presumably
    # to match calc_polygamma's parameter order; confirm before changing.
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="zeta({x}, {y})",
        name="special_zeta",
    ),
    # Orthogonal polynomials: all two-argument ops with C++ templates only.
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Static-only check: mypy verifies that OpOverrides structurally implements the
# OpsHandler[str] protocol. At runtime this merely returns its argument.
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that is dropped if buffer ``name`` ends up removed.

    Calling the line yields its text, or None when ``name`` appears in the
    ``removed_buffers`` or ``inplaced_to_remove`` sets of either ``V.graph``
    or ``V.kernel``.
    """

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        # Nesting deferred lines is not supported.
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        if any(self.name in names for names in removal_sets):
            return None
        return self.line

    def _new_line(self, line):
        # A derived line is gated on the same buffer name.
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose ``indent`` context also writes C-style braces."""

    def indent(self, offset=1):
        """Return a context manager that opens or closes ``abs(offset)`` brace levels.

        A positive ``offset`` writes ``{`` and indents on entry, then dedents
        and writes ``}`` on exit. A negative ``offset`` does the reverse: it
        closes ``-offset`` levels on entry and reopens them on exit.

        The exit half runs in ``finally``, so the indentation level and brace
        balance are restored even if the body raises.
        """

        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            try:
                yield
            finally:
                for _ in range(-offset):
                    self.writeline("{")
                    self._indent += 1
                for _ in range(offset):
                    self._indent -= 1
                    self.writeline("}")

        return ctx()


class InplacedBuffer(NamedTuple):
    """One kernel argument shared by several buffer names updated in place."""

    # Kernel-side argument name; KernelArgs.make_inplace creates "in_out_ptr<N>".
    inner_name: str
    # Graph buffer names aliasing this argument; KernelArgs uses the last one
    # as the call-site (outer) name.
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        """Return the kernel-side alias for ``name`` in ``odict``, creating it on first use.

        A new alias is ``prefix`` followed by the number of entries ``odict``
        held just before the insertion.
        """
        assert isinstance(name, (str, sympy.Symbol))
        return odict.setdefault(name, f"{prefix}{len(odict)}")

    def __init__(self, sizevars=None):
        # Outer (graph) buffer name -> kernel-side name, or an InplacedBuffer
        # for inplace_buffers.
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        # A falsy sizevars (None or empty) is replaced by a fresh dict.
        self.sizevars = sizevars if sizevars else {}
        # Created lazily by ``workspace``.
        self.workspace_arg = None

    def __repr__(self):
        # Shows the four argument tables; workspace_arg is not included.
        tables = (
            self.input_buffers,
            self.output_buffers,
            self.inplace_buffers,
            self.sizevars,
        )
        return f"KernelArgs({', '.join(repr(table) for table in tables)})"

    def _buffer_is_marked_removed(self, name):
        # True only for string entries starting with "REMOVED" (presumably
        # written elsewhere when a buffer is dropped — confirm). Non-string
        # values such as an InplacedBuffer are never considered removed.
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        """Return the kernel-side name for reading buffer ``name``.

        When a scheduler exists, ``name`` is first mapped through its
        ``mutation_real_name``. A buffer already registered as an output or
        as an in-place argument reuses that alias. Otherwise a ``seed*`` alias
        (for names starting with "seed") or an ``in_ptr*`` alias is allocated.
        """
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        prefix = "seed" if name.startswith("seed") else "in_ptr"
        return self._lookup(prefix, self.input_buffers, name)

    def output(self, name):
        """Return the kernel-side name for writing buffer ``name``.

        When a scheduler exists, ``name`` is first mapped through its
        ``mutation_real_name``. An in-place argument reuses its inner name;
        otherwise an ``out_ptr*`` alias is allocated.
        """
        scheduler = V.graph.scheduler
        if scheduler:
            name = scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name not in self.inplace_buffers:
            return self._lookup("out_ptr", self.output_buffers, name)
        return self.inplace_buffers[name].inner_name

    def make_inplace(self, input_name, output_name):
        """Make ``output_name`` share one in-place kernel argument with ``input_name``.

        If ``input_name`` already belongs to an InplacedBuffer, ``output_name``
        is appended to that buffer's names. Otherwise a new ``in_out_ptr<N>``
        argument is created for both names. N is the length of
        ``unique(...)`` over the existing entries, presumably the number of
        distinct in-place arguments so far.
        """
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            # The shared InplacedBuffer is mutated, so every alias sees the new name.
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        """Reserve ``nbytes`` in this kernel's single workspace buffer.

        Returns ``("ws_ptr", offset)``, where ``offset`` is the start of the new
        region. Successive requests are laid out back to back. The workspace is
        zero-filled if any request asked for it.
        """
        current = self.workspace_arg
        if current is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0
        offset = current.nbytes
        self.workspace_arg = WorkspaceArg(
            offset + nbytes, zero_fill or current.zero_fill
        )
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        """Return the size-var name for ``value``, registering it under ``name`` if new.

        If ``name`` is already used by another entry, it is suffixed with the
        count of existing size-var names that start with ``name``.
        """
        if value in self.sizevars:
            return self.sizevars[value]
        taken = self.sizevars.values()
        if name in taken:
            name = f"{name}{sum(1 for v in taken if v.startswith(name))}"
        self.sizevars[value] = name
        return name

    def size(self, name):
        """Return the kernel argument name for size symbol ``name``.

        A name that stringifies to "seed" is registered under the key "seed"
        and passed through unchanged; anything else gets a ``ks*`` alias.
        """
        if str(name) != "seed":
            return self._lookup("ks", self.sizevars, name)
        self.sizevars["seed"] = "seed"
        return "seed"

    def call_names(self):
        """Lazily iterate outer names: input buffers, then output buffers, then size vars."""
        return chain.from_iterable(
            (self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys())
        )

    def wrap_ptr_arg(self, buf, dtype):
        # Call-site form of a pointer argument. The base class returns the
        # buffer name unchanged and ignores ``dtype`` (presumably a hook for
        # subclasses — confirm).
        return buf

    def wrap_size_arg(self, size):
        # Call-site form of a size argument: its ``str()`` rendering.
        return str(size)

    def cpp_argdefs(self):
        """Build the C++ kernel parameter list and the matching call arguments.

        Returns ``(arg_defs, call_args, arg_types)`` as parallel lists, ordered:
        in-place buffers, read-only inputs, outputs, then size variables.
        Entries marked removed are skipped, as are inputs/outputs already
        covered by an in-place argument. Each size variable is also passed to
        ``ensure_size_computed`` on the wrapper code, when one exists.
        """
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []

        def add_pointer(outer, inner, qualifier):
            # Append one pointer parameter; qualifier is "" or "const ".
            dtype = V.graph.get_dtype(outer)
            pointer_type = f"{qualifier}{DTYPE_TO_CPP[dtype]}*"
            arg_defs.append(f"{pointer_type} {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(pointer_type)

        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            # The last aliasing name is used as the call-site buffer.
            add_pointer(inplaced.other_names[-1], inplaced.inner_name, "")
        for outer, inner in self.input_buffers.items():
            if outer not in self.inplace_buffers:
                add_pointer(outer, inner, "const ")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            add_pointer(outer, inner, "")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        """Build the Python-side signature pieces for this kernel.

        Returns ``(arg_defs, call_args, precompile_args)``:

        - kernel parameter names;
        - the caller-side names passed for them;
        - a ``TensorArg`` / ``SizeArg`` / ``WorkspaceArg`` descriptor per
          parameter.

        Parameters are ordered as in-place buffers, then inputs and outputs
        (skipping names that are in-place or removed), then size variables.
        The ``ws_ptr`` workspace comes last, if one was requested.
        """
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []

        def add(inner, outer, descriptor):
            # Keep the three lists in lock-step.
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(descriptor)

        for buf in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(buf):
                continue
            outer = buf.other_names[-1]
            add(
                buf.inner_name,
                outer,
                TensorArg(
                    name=buf.inner_name,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                ),
            )

        tensor_items = chain(self.input_buffers.items(), self.output_buffers.items())
        for outer, inner in tensor_items:
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            add(
                inner,
                outer,
                TensorArg(name=inner, buffer=outer, dtype=V.graph.get_dtype(outer)),
            )

        for outer, inner in self.sizevars.items():
            add(inner, outer, SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)

        if self.workspace_arg is not None:
            add("ws_ptr", "workspace", self.workspace_arg)

        return arg_defs, call_args, precompile_args

    def aliases(self):
        """Yield ``(io_inner_name, inplace_inner_name)`` pairs for in-place aliasing.

        Consider each in-place buffer that is not marked removed. For each of
        its outer names that is not listed in ``V.graph.inplaced_to_remove``
        or ``V.kernel.inplaced_to_remove``, pair the inner name of the
        matching input (then output) with the in-place buffer's inner name.
        """
        for buf in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(buf):
                continue
            for alias in buf.other_names:
                dropped = (
                    alias in V.graph.inplaced_to_remove
                    or alias in V.kernel.inplaced_to_remove
                )
                if dropped:
                    continue
                if alias in self.input_buffers:
                    yield self.input_buffers[alias], buf.inner_name
                if alias in self.output_buffers:
                    yield self.output_buffers[alias], buf.inner_name

    def is_removed(self, name):
        """Return True unless ``name`` is a live output or in-place buffer of this kernel.

        ``name`` counts as removed for a table when it is absent from that
        table or its entry is marked removed. Both the output and in-place
        tables must agree.
        """
        for table in (self.output_buffers, self.inplace_buffers):
            if name in table and not self._buffer_is_marked_removed(table[name]):
                return False
        return True

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        """Return the set of outer buffer names this kernel actually writes.

        The set contains the last outer name of every in-place buffer that is
        not marked removed, plus every output that is neither in-place nor
        marked removed.
        """
        live_outs = {
            buf.other_names[-1]
            for buf in unique(self.inplace_buffers.values())
            if not self._buffer_is_marked_removed(buf)
        }
        live_outs.update(
            outer
            for outer, inner in self.output_buffers.items()
            if not (
                outer in self.inplace_buffers
                or self._buffer_is_marked_removed(inner)
            )
        )
        return live_outs


class CSEVariable:
    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`
    The "CSEVariable.update_on_args" method gives you a hook for annotations
    See example of TritonCSEVariable in triton.py
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass


class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant whose call-site expressions target the C++ wrapper."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if not config.abi_compatible:
            # Cast the tensor's raw data pointer to the element type.
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        return f"{size}"


class CSE:
    """Common subexpression elimination.

    Maps each emitted expression to the ``CSEVariable`` holding its value.
    An expression seen again reuses the existing variable instead of being
    emitted a second time.

    Also holds:

    - ``store_cache``: the value last stored to each buffer;
    - ``reduction_cache``;
    - ``varname_map``: variable name -> ``CSEVariable``.

    ``clone()`` returns a CSE with a fresh expression cache. It shares the
    variable counter, ``store_cache`` and ``varname_map`` with the original.
    """

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        # Test for None explicitly. With `x or {}`, an *empty* dict passed in
        # (e.g. by clone()) was silently replaced by a new one, so whether a
        # CSE and its clone shared state depended on the dict being non-empty.
        self.store_cache = {} if store_cache is None else store_cache
        self.reduction_cache = {} if reduction_cache is None else reduction_cache
        self.iter_buffer_ids = (
            itertools.count() if iter_buffers is None else iter_buffers
        )
        self.invalidated_stores = set()
        self.varname_map = {} if varname_map is None else varname_map

    def invalidate(self, keep_vars: Set[CSEVariable]):
        """Forget cached values that are not in ``keep_vars``.

        Dropped store-cache entries have their buffer name recorded in
        ``invalidated_stores``. Expression-cache entries are filtered the same
        way.
        """
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        """Return a CSE that shares counters and caches but has an empty expression cache."""
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return the variable holding ``expr``, emitting its definition on first use.

        Args:
            buffer: where the defining line(s) are written.
            expr: the code to CSE. An ``OpsValue`` is unwrapped first. An
                existing ``CSEVariable`` is returned as-is, with its bounds
                tightened.
            bounds: the value range known for the result; it is used to
                tighten the bounds of an existing variable.
            write: if False, the cache entry (and variable) is created but
                nothing is written to ``buffer``.
            assignment: if False, no variable is created. The expression is
                written as a bare statement and ``None`` is returned.
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if var is None:
            # First sighting (or a previous non-assignment, which cached None).
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        """Create and register a fresh ``<name_prefix><n>`` variable via ``V.kernel.create_cse_var``."""
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var


class IndirectAssertLine(DeferredLineBase):
    """A bounds-check line for an indirect index, rendered lazily.

    ``line`` is a format string with ``{assert_fn}``, ``{cond}`` and
    ``{cond_print}`` placeholders. The size to check against is read from
    ``size_map[(var, mask)]`` only when the line is rendered, so any later
    update to that entry is honoured. If ``var``'s value-range bounds
    already prove ``0 <= var < size``, rendering returns ``None``.
    """

    def __init__(self, line, assert_fn, var, mask, size_map):
        self.var = var
        self.mask = mask
        self.line = line
        self.assert_fn = assert_fn
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]
        bounds = self.var.bounds

        # Each side needs a runtime check unless it was proven statically.
        check_lower = (bounds.lower >= 0) != sympy.true
        check_upper = (bounds.upper < size) != sympy.true

        if check_lower and check_upper:
            # Parenthesized because `&` binds tighter than comparisons in
            # Python; and/or/not would be less error-prone and is supported
            # by triton.
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif check_lower:
            cond = cond_print = f"0 <= {self.var}"
        elif check_upper:
            cond = cond_print = f"{self.var} < {size_str}"
        else:
            # Both bounds proven: no assertion needed.
            return None

        if self.mask:
            # Wherever the mask is false, the condition is forced true.
            cond = f"({cond}) | ~{self.mask}"
        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        # Same check, different line text.
        return IndirectAssertLine(
            line=line,
            assert_fn=self.assert_fn,
            var=self.var,
            mask=self.mask,
            size_map=self.size_map,
        )


class CodeGen:
    """Base for code generators usable as context managers.

    Each instance owns a ``contextlib.ExitStack``. Contexts pushed onto
    ``self.exit_stack`` while the generator is entered are unwound when it
    exits.
    """

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        stack = self.exit_stack
        stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        stack = self.exit_stack
        # The ExitStack's return value is not propagated, so exiting a
        # CodeGen never suppresses an in-flight exception.
        stack.__exit__(exc_type, exc_val, exc_tb)


class Kernel(CodeGen):
    """Base class for a single generated kernel.

    A kernel owns:

    - its argument bookkeeping (``self.args``);
    - separate code buffers for loads, compute and stores;
    - a ``CSE`` instance.

    Entering the kernel (``with kernel:``) installs a ``CSEProxy`` as the ops
    handler and registers this kernel as the active kernel handler. Backends
    subclass this class and implement ``load``, ``store``,
    ``store_reduction``, ``reduction``, ``scan``, ``bucketize``,
    ``assert_function`` and ``index_to_str``.
    """

    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        # Upper bounds for indirect_indexing and their str representation
        # NB: None, None is never stored in map, but it is the assumed
        # "not set" value for the dict
        self.indirect_max_sizes: Dict[
            Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
        ] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #   the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        """Make ``node`` the current node for the duration of the context.

        Also loads ``node``'s value-range bounds into ``self.node_to_bounds``,
        which ``CSEProxy`` consults. Only ``current_node`` is restored on exit;
        ``node_to_bounds`` keeps the bounds of the last node set.
        """
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """Temporarily redirect code emission to other buffers.

        Within the context:

        - loads go to ``lb``;
        - compute goes to ``cb`` (defaults to ``lb``);
        - stores go to ``sb`` (used as given, even if None);
        - a clone of the current CSE is active.

        The original buffers and CSE are restored on exit.
        """
        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = cse.clone()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError()

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        raise NotImplementedError()

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError()

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        raise NotImplementedError()

    def scan(
        self,
        dtype: torch.dtype,
        combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
        value: CSEVariable,
        init: int,
    ) -> CSEVariable:
        raise NotImplementedError()

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError()

    @property
    def assert_function(self) -> str:
        raise NotImplementedError()

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError()

    def __enter__(self):
        """Install a ``CSEProxy`` ops handler and this kernel as the kernel handler.

        ``CSEProxy`` forwards generic ops to
        ``self.overrides(<previous ops handler>)`` and CSEs each result into
        ``self.compute``. Loads, stores, reductions, scans, bucketize and
        indirect indexing are routed to this kernel's own methods.
        """

        # TODO: hoist this to top level
        class CSEProxy:
            # A plain class attribute. This used to read `self.name = ...`.
            # Class bodies can read the enclosing function's locals, so that
            # line overwrote the *kernel's* `name` attribute instead.
            name = "CSEProxy"

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    # TritonTemplateKernel has no current_node
                    buf_bounds = ValueRanges.unknown()
                    if hasattr(V.interpreter, "current_node"):
                        fx_node = V.interpreter.current_node
                        assert isinstance(self.node_to_bounds, dict)
                        buf_bounds = self.node_to_bounds.get(
                            fx_node, ValueRanges.unknown()
                        )

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    # An op may return several values; CSE each leaf.
                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: sympy.Expr, check: bool = True
            ):
                """Turn the loaded value ``var`` into an index symbol for a dim of ``size``.

                If ``var`` may be negative, it is wrapped: ``var + size`` where
                ``var < 0``. When ``generate_assert(check)`` holds, a bounds
                assertion is emitted, or the recorded size is tightened if one
                was already emitted.
                """
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, "0")
                        stm = ops.where(lt, stm, var)
                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        line = (
                            '{assert_fn}({cond}, "index out of bounds: {cond_print}")'
                        )
                        self.compute.writeline(
                            IndirectAssertLine(
                                line,
                                self.assert_function,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return sympy_index_symbol(str(var))

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_startswith(index, "tmp"):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    # Reuse the value stored earlier in this kernel.
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtype: torch.dtype,
                combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
                value: CSEVariable,
                init: int,
            ) -> CSEVariable:
                return self.scan(dtype, combine_fn, value, init)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        # Captured by CSEProxy.__getattr__ through the closure.
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        """Whether to emit an indirect-indexing bounds assertion.

        Requires ``config.assert_indirect_indexing``, and either ``check`` or
        ``config.debug_index_asserts``.
        """
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var) -> str:
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        # Sorted so that kernel args are registered in a deterministic order.
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if x.name.startswith(("s", "u", "ps"))
            or (x.name.startswith("i") and not x.name.startswith("idx"))
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        """Factory for CSE variables; backends override to return a subclass."""
        return CSEVariable(*args, **kwargs)


@dataclasses.dataclass
class OptimizationContext:
    """Optimization hints attached to a node under the name ``key``.

    The hints are consulted by the backends; which backends read which flag
    is not visible from this class. NOTE(review): confirm usages per backend.
    """

    key: ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = dataclasses.field(default=False)

    dtype: Optional[torch.dtype] = dataclasses.field(default=None)
    ops_name: str = dataclasses.field(default="")

    # Load uint8/int8 value as float32
    is_load_int8_as_float: bool = dataclasses.field(default=False)


@functools.lru_cache(None)
def jinja2_env():
    """Return a cached ``jinja2.Environment`` using ``StrictUndefined``, or None.

    ``None`` is returned (and cached) when jinja2 cannot be imported.
    """
    try:
        import jinja2

        env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    except ImportError:
        env = None
    return env


# Value type allowed in ChoiceCaller.info_dict(): a scalar (int/float/bool/str)
# or a list of such scalars.
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]


class ChoiceCaller:
    """
    One candidate implementation considered during autotuning (autotune_process.py).

    Autotuning first calls ``benchmark()`` to time the candidate. If the
    candidate is selected, ``output_node()`` is called to obtain its output
    node.

    Children classes: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        """Time ``self.to_callable()`` called as ``fn(*args, out=out)`` using ``do_bench``."""
        fn = self.to_callable()

        def run_once():
            return fn(*args, out=out)

        return do_bench(run_once)

    def call_name(self) -> str:
        raise NotImplementedError()

    def to_callable(self):
        raise NotImplementedError()

    def hash_key(self) -> str:
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":
        raise NotImplementedError()

    def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
        """Information returned here is logged to the autotune log file when that is enabled."""
        return dict()


class KernelTemplate:
    """
    Base class for defining kernel templates.

    Subclasses implement ``generate()``. ``maybe_append_choice()`` adds the
    generated ChoiceCaller to a list, unless ``generate()`` raises
    NotImplementedError.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def _template_from_string(source):
        """Compile ``source`` as a jinja2 template, or return None if jinja2 is unavailable."""
        env = jinja2_env()
        return None if env is None else env.from_string(source)

    @staticmethod
    def _fake_get_dtype(fake_out):
        """Return a ``get_dtype(name)`` replacement that also knows ``fake_out``.

        A lookup of ``fake_out``'s name returns ``fake_out.get_dtype()``. Every
        other name is delegated to the ``V.graph.get_dtype`` captured when this
        helper is called.
        """
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            return (
                fake_out.get_dtype()
                if name == fake_out.get_name()
                else real_get_dtype(name)
            )

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """
        # A template that cannot handle these arguments signals it with
        # NotImplementedError; that candidate is simply skipped.
        with contextlib.suppress(NotImplementedError):
            choices.append(self.generate(**kwargs))

    def generate(self, **kwargs) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """
        raise NotImplementedError()
