# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Utilities for DTensor related API."""

import inspect

import tensorflow.compat.v2 as tf

from tf_keras.src.dtensor import dtensor_api as dtensor

# All the weight variable names used by the built-in Keras layers. We match
# these against the args of each `__init__` method to find the corresponding
# layout args. See allow_initializer_layout() for more details.
KERAS_VARIABLE_NAMES = [
    "alpha",
    "beta",
    "bias",
    "depthwise",
    "embeddings",
    "gamma",
    "kernel",
    "moving_mean",
    "moving_variance",
    "pointwise",
    "recurrent",
]
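
# For example, a layer whose `__init__` accepts a `kernel_initializer` arg
# will, once decorated with `allow_initializer_layout` below, also accept a
# `kernel_layout` keyword and store it as `self.kernel_layout`.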


def allow_initializer_layout(init_method):
    """A decorator for injecting layout information to layer.__init__.

    Layout will be a new param for any of the weights for all the keras layers.
    Adding the param to all the __init__ method will be a big/duplicated work.

    This decorator is design to reduce and code duplication and make it easy to
    add/remove the dtensor feature if needed.

    Sample usage:
    ```python
    class Dense(tf.keras.layers.Layer):

      @allow_initializer_layout
      def __init__(self, units,
                   kernel_initializer='zeros',
                   bias_initializer='zeros',
                   **kwargs):
         super().__init__(**kwargs)

    d = Dense(units=8, kernel_layout=layout1, bias_layout=layout2)
    d.kernel_layout == layout1
    d.bias_layout == layout2
    ```

    By adding this annotation, it will:

    1. Filter the kwargs based on certain keywords. E.g. if
       'kernel_initializer' appears in the method signature, it will try to
       pop 'kernel_layout' if it is present. Same for "bias" and
       "recurrent_kernel", etc. This makes sure the layout-related params
       are not passed to `BaseLayer.__init__`, which would raise an error
       about unexpected keyword args.
    2. Set the `self.kernel_layout`/`self.bias_layout` attributes after the
       `__init__` method is called. The TF-Keras framework will use those
       fields to create the weights downstream.

    Args:
      init_method: the `__init__` method of the TF-Keras layer to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _wrap_function(layer_instance, *args, **kwargs):
        signature = inspect.signature(init_method)
        layout_args = {}
        # Check args like 'kernel_initializer' and pop the corresponding
        # 'kernel_layout' if it is present.
        for variable_name in KERAS_VARIABLE_NAMES:
            if variable_name + "_initializer" in signature.parameters:
                layout = kwargs.pop(variable_name + "_layout", None)
                if layout:
                    layout_args[variable_name + "_layout"] = layout

        init_method(layer_instance, *args, **kwargs)

        # Inject the layout parameter after the invocation of __init__()
        for layout_param_name, layout in layout_args.items():
            setattr(layer_instance, layout_param_name, layout)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )


def inject_mesh(init_method):
    """Inject DTensor mesh information to an object.

    This is useful for Keras objects like `Metric` and `Optimizer` which need
    a DTensor mesh to create their weights, but whose current public API
    interface we don't want to change.

    This is for temporary usage; eventually the mesh/layout information will
    become public arguments in the `__init__` method.

    Sample usage:
    ```python
    class Accuracy(tf.keras.metrics.Metric):

      @inject_mesh
      def __init__(self, name='accuracy', dtype=None):
         super().__init__(name=name, dtype=dtype)

    acc = Accuracy(mesh=mesh)
    assert acc._mesh == mesh
    ```

    Args:
      init_method: the `__init__` method of the TF-Keras class to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _wrap_function(instance, *args, **kwargs):
        mesh = kwargs.pop("mesh", None)
        # Note that the injection of _mesh needs to happen before the
        # invocation of __init__, since the class might need the mesh to
        # create weights in __init__.
        if mesh is not None:
            instance._mesh = mesh
        init_method(instance, *args, **kwargs)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )


def call_with_layout(fn, layout, *args, **kwargs):
    """Invoke the function with inputs and relayout the result.

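    Sample usage (a minimal sketch; assumes `layout` is a `dtensor.Layout`
    created elsewhere for the target mesh):
    ```python
    # Run a Keras initializer under the layout's mesh, then relayout the
    # result so it becomes a DTensor with the requested layout.
    init = tf.keras.initializers.GlorotUniform()
    value = call_with_layout(init, layout, shape=(4, 8))
    ```
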
    Args:
      fn: the function to invoke.
      layout: if not None, the output of `fn` will be relaid out with this
        layout.
      *args: positional arguments to be called with fn.
      **kwargs: keyword arguments to be called with fn.

    Returns:
      The output of `fn`, relaid out with the given layout when one is
      specified.
    """
    if layout:
        with dtensor.default_mesh(layout.mesh):
            result = fn(*args, **kwargs)
            return dtensor.relayout(result, layout)
    return fn(*args, **kwargs)


def running_with_dtensor_strategy():
    """Check whether running with a `Strategy` that is backed by DTensor.

    In DTensor-based training, all tensors are in the global context, which
    is different from the local, per-replica context. Some Keras components
    need to behave differently, e.g. BatchNormalization and
    SyncBatchNormalization, as well as optimizers.

    This check helps those layers branch their logic and keep the correct
    behavior between the two contexts.
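
    Sample usage (a minimal sketch; assumes `strategy` is a DTensor-backed
    `tf.distribute.Strategy` created elsewhere):
    ```python
    with strategy.scope():
        if running_with_dtensor_strategy():
            # All tensors are global DTensors; take the global-context path.
            ...
    ```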
    """
    if not tf.distribute.has_strategy():
        return False
    strategy = tf.distribute.get_strategy()
    # TODO(scottzhu): Finalize the strategy API to check if a strategy is backed
    # by DTensor.
    return getattr(strategy, "_mesh", None) is not None

