# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import warnings

import tensorflow.compat.v2 as tf

from tf_keras.src import backend
from tf_keras.src.engine import base_layer_utils
from tf_keras.src.engine import base_layer_v1 as base_layer
from tf_keras.src.legacy_tf_layers import variable_scope_shim
from tf_keras.src.mixed_precision import policy
from tf_keras.src.utils import tf_contextlib

# isort: off
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util.tf_export import keras_export

_KERAS_STYLE_SCOPE = False


@keras_export(
    v1=["keras.__internal__.legacy.layers.experimental.keras_style_scope"]
)
@tf_contextlib.contextmanager
def keras_style_scope():
    """Use Keras-style variable management.

    All tf.layers and tf RNN cells created in this scope use Keras-style
    variable management.  Creating such layers with a scope= argument is
    disallowed, and reuse=True is disallowed.

    The purpose of this scope is to allow users of existing layers to
    slowly transition to a TF-Keras layers API without breaking existing
    functionality.

    One example of this is when using TensorFlow's RNN classes with Keras
    Models or Networks.  Because TF-Keras models do not properly set variable
    scopes, users of RNNs may either accidentally share scopes between two
    different models, or get errors about variables that already exist.

    Example:

    ```python
    class RNNModel(tf.keras.Model):

      def __init__(self, name):
        super(RNNModel, self).__init__(name=name)
        self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
          [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

      def call(self, input, state):
        return self.rnn(input, state)

    model_1 = RNNModel("model_1")
    model_2 = RNNModel("model_2")

    # OK
    output_1, next_state_1 = model_1(input, state)
    # Raises an error about trying to create an already existing variable.
    output_2, next_state_2 = model_2(input, state)
    ```

    The solution is to wrap the model construction and execution in a
    keras-style scope:

    ```python
    with keras_style_scope():
      model_1 = RNNModel("model_1")
      model_2 = RNNModel("model_2")

      # model_1 and model_2 are guaranteed to create their own variables.
      output_1, next_state_1 = model_1(input, state)
      output_2, next_state_2 = model_2(input, state)

      assert len(model_1.weights) > 0
      assert len(model_2.weights) > 0
      assert(model_1.weights != model_2.weights)
    ```

    Yields:
      A keras layer style scope.
    """
    global _KERAS_STYLE_SCOPE
    stack = _KERAS_STYLE_SCOPE
    _KERAS_STYLE_SCOPE = True
    try:
        yield
    finally:
        _KERAS_STYLE_SCOPE = stack


@keras_export(
    v1=["keras.__internal__.legacy.layers.experimental.set_keras_style"]
)
def set_keras_style():
    """Use Keras-style variable management.

    All tf.layers and tf RNN cells created after keras style has been enabled
    use Keras-style variable management.  Creating such layers with a
    scope= argument is disallowed, and reuse=True is disallowed.

    The purpose of this function is to allow users of existing layers to
    slowly transition to TF-Keras layers API without breaking existing
    functionality.

    For more details, see the documentation for `keras_style_scope`.

    Note, once keras style has been set, it is set globally for the entire
    program and cannot be unset.

    Example:

    ```python
    set_keras_style()

    model_1 = RNNModel(name="model_1")
    model_2 = RNNModel(name="model_2")

    # model_1 and model_2 are guaranteed to create their own variables.
    output_1, next_state_1 = model_1(input, state)
    output_2, next_state_2 = model_2(input, state)

    assert len(model_1.weights) > 0
    assert len(model_2.weights) > 0
    assert(model_1.weights != model_2.weights)
    ```
    """
    global _KERAS_STYLE_SCOPE
    _KERAS_STYLE_SCOPE = True


def _is_in_keras_style_scope():
    global _KERAS_STYLE_SCOPE
    return _KERAS_STYLE_SCOPE


@keras_export(v1=["keras.__internal__.legacy.layers.Layer"])
class Layer(base_layer.Layer):
    """Base layer class.

    It is considered legacy, and we recommend the use of `tf.keras.layers.Layer`
    instead.

    Args:
      trainable: Boolean, whether the layer's variables should be trainable.
      name: String name of the layer.
      dtype: Default dtype of the layer's weights (default of `None` means use
        the type of the first input).

    Read-only properties:
      name: The name of the layer (string).
      dtype: Default dtype of the layer's weights (default of `None` means use
        the type of the first input).
      trainable_variables: List of trainable variables.
      non_trainable_variables: List of non-trainable variables.
      variables: List of all variables of this layer, trainable and
        non-trainable.
      updates: List of update ops of this layer.
      losses: List of losses added by this layer.
      trainable_weights: List of variables to be included in backprop.
      non_trainable_weights: List of variables that should not be
        included in backprop.
      weights: The concatenation of the lists trainable_weights and
        non_trainable_weights (in this order).

    Mutable properties:
      trainable: Whether the layer should be trained (boolean).
      input_spec: Optional (list of) `InputSpec` object(s) specifying the
        constraints on inputs that can be accepted by the layer.
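
    Example (an illustrative sketch; `MyDense` below is a hypothetical
    subclass used only for demonstration):

    ```python
    class MyDense(tf.compat.v1.layers.Layer):

      def __init__(self, units, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        self.units = units

      def build(self, input_shape):
        self.kernel = self.add_weight(
            "kernel", shape=[int(input_shape[-1]), self.units])
        self.built = True

      def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    with tf.Graph().as_default():
      x = tf.compat.v1.placeholder(tf.float32, shape=[None, 8])
      y = MyDense(4)(x)  # Variables are created under the layer's scope.
    ```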
    """

    def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
        # For backwards compatibility, legacy layers do not use
        # `ResourceVariable` by default.
        self._use_resource_variables = False
        scope = kwargs.pop("_scope", None)
        self._reuse = kwargs.pop("_reuse", None)

        # Avoid an incorrect lint error
        self._trainable_weights = []
        self.built = False

        if dtype is None:
            # Indicates to infer dtype from inputs. When the V2 dtype behavior
            # is enabled, TF-Keras layers default their dtype to floatx instead,
            # so we pass an "_infer" policy to keep the old V1 behavior.
            dtype = policy.Policy("_infer")

        if "autocast" not in kwargs:
            kwargs["autocast"] = False

        # Mark that legacy layers should not be instrumented as TF-Keras usage
        self._disable_keras_instrumentation = True

        super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs)

        if _is_in_keras_style_scope():
            if scope is not None:
                raise ValueError(
                    "scope argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(scope)
                )
            if self._reuse is not None:
                raise ValueError(
                    "reuse argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(self._reuse)
                )
            self._keras_style = True
        else:
            self._keras_style = False

        self._call_has_scope_arg = "scope" in self._call_spec.arg_names
        if scope:
            with tf.compat.v1.variable_scope(scope) as captured_scope:
                self._scope = captured_scope
        else:
            self._scope = None
        self._current_scope = None

    def apply(self, *args, **kwargs):
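        """Applies the layer to an input. This is an alias of `__call__`."""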
        return self(*args, **kwargs)

    # We no longer track graph in tf.layers layers. This property is only kept
    # to maintain API backward compatibility.
    @property
    def graph(self):
        warnings.warn(
            "`Layer.graph` is deprecated and "
            "will be removed in a future version. "
            "Please stop using this property because tf.layers layers no "
            "longer track their graph.",
            stacklevel=2,
        )
        if tf.executing_eagerly():
            raise RuntimeError(
                "Layer.graph not supported when executing eagerly."
            )
        return None

    def _init_set_name(self, name):
        # Determine layer name (non-unique).
        if isinstance(name, tf.compat.v1.VariableScope):
            base_name = name.name
            self._name, _ = self._make_unique_name()
        else:
            base_name = name
            self._name = name
        if not name:
            self._name, base_name = self._make_unique_name()
        self._base_name = base_name

    def _make_unique_name(
        self,
        name_uid_map=None,
        avoid_names=None,
        namespace="",
        zero_based=False,
    ):
        base_name = base_layer.to_snake_case(self.__class__.__name__)
        name = backend.unique_object_name(
            base_name,
            name_uid_map=name_uid_map,
            avoid_names=avoid_names,
            namespace=namespace,
            zero_based=zero_based,
        )
        return (name, base_name)

    @property
    def scope_name(self):
        if not self._scope:
            raise ValueError(
                'No name available for layer scope because the layer "'
                + self._name
                + '" has not been used yet. The scope name is determined '
                + "the first time the layer instance is called. You must "
                + "therefore call the layer before querying `scope_name`."
            )
        return self._scope.name

    def add_loss(self, losses, inputs=None):
        previous_losses_length = len(self._losses)
        previous_callable_losses_length = len(self._callable_losses)
        super().add_loss(losses, inputs=inputs)
        if not tf.executing_eagerly():
            # TODO(fchollet): deprecate collection below.
            new_losses = self._losses[previous_losses_length:]
            new_callable_losses = self._callable_losses[
                previous_callable_losses_length:
            ]
            for regularizer in new_callable_losses:
                loss_tensor = regularizer()
                if loss_tensor is not None:
                    new_losses.append(loss_tensor)
            _add_elements_to_collection(
                new_losses, tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES
            )

    def _name_scope(self):
        """Determines op naming for the Layer."""
        if self._keras_style:
            return super()._name_scope()
        return self._current_scope.original_name_scope

    def _set_scope(self, scope=None):
        if self._scope is None:
            # If constructed with _scope=None, set the scope lazily here.
            if self._reuse:
                with tf.compat.v1.variable_scope(
                    scope if scope is not None else self._base_name
                ) as captured_scope:
                    self._scope = captured_scope
            else:
                with tf.compat.v1.variable_scope(
                    scope, default_name=self._base_name
                ) as captured_scope:
                    self._scope = captured_scope

    def add_weight(
        self,
        name,
        shape,
        dtype=None,
        initializer=None,
        regularizer=None,
        trainable=None,
        constraint=None,
        use_resource=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.compat.v1.VariableAggregation.NONE,
        partitioner=None,
        **kwargs
    ):
        """Adds a new variable to the layer, or gets an existing one; returns it

        Args:
          name: variable name.
          shape: variable shape.
          dtype: The type of the variable. Defaults to `self.dtype` or
            `float32`.
          initializer: initializer instance (callable).
          regularizer: regularizer instance (callable).
          trainable: whether the variable should be part of the layer's
            "trainable_variables" (e.g. weights, biases)
            or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
            Note, if the current variable scope is marked as non-trainable
            then this parameter is ignored and any added variables are also
            marked as non-trainable. Defaults to `True`, unless
            `synchronization` is set to `ON_READ`, in which case it
            defaults to `False`.
          constraint: constraint instance (callable).
          use_resource: Whether to use `ResourceVariable`.
          synchronization: Indicates when a distributed variable will be
            aggregated. Accepted values are constants defined in the class
            `tf.VariableSynchronization`. By default the synchronization is set
            to `AUTO` and the current `DistributionStrategy` chooses when to
            synchronize. If `synchronization` is set to `ON_READ`, `trainable`
            must not be set to `True`.
          aggregation: Indicates how a distributed variable will be aggregated.
            Accepted values are constants defined in the class
            `tf.VariableAggregation`.
          partitioner: (optional) partitioner instance (callable).  If
            provided, when the requested variable is created it will be split
            into multiple partitions according to `partitioner`.  In this case,
            an instance of `PartitionedVariable` is returned.  Available
            partitioners include `tf.compat.v1.fixed_size_partitioner` and
            `tf.compat.v1.variable_axis_size_partitioner`.  For more details,
            see the documentation of `tf.compat.v1.get_variable` and the
            "Variable Partitioners and Sharding" section of the API guide.
          **kwargs: Additional keyword arguments.

        Returns:
          The created variable.  Usually either a `Variable` or
          `ResourceVariable` instance.  If `partitioner` is not `None`, a
          `PartitionedVariable` instance is returned.

        Raises:
          RuntimeError: If called with partitioned variable regularization and
            eager execution is enabled.
          ValueError: When `trainable` has been set to `True` with
            `synchronization` set to `ON_READ`.
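
        Example (an illustrative sketch; assumes it is called from the `build`
        method of a hypothetical subclass with a `units` attribute):

        ```python
        def build(self, input_shape):
          self.kernel = self.add_weight(
              name="kernel",
              shape=[int(input_shape[-1]), self.units],
              initializer=tf.compat.v1.glorot_uniform_initializer(),
              regularizer=tf.keras.regularizers.l2(1e-4),
              trainable=True)
          self.built = True
        ```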
        """
        for kwarg in kwargs:
            if kwarg != "experimental_autocast":
                raise TypeError("Unknown keyword argument: %s" % (kwarg,))
        if self._keras_style:
            return super().add_weight(
                name=name,
                shape=shape,
                dtype=dtype,
                initializer=initializer,
                regularizer=regularizer,
                trainable=trainable and self.trainable,
                constraint=constraint,
                use_resource=use_resource,
                synchronization=tf.VariableSynchronization.AUTO,
                aggregation=tf.compat.v1.VariableAggregation.NONE,
                partitioner=partitioner,
                **kwargs
            )

        if synchronization == tf.VariableSynchronization.ON_READ:
            if trainable:
                raise ValueError(
                    "Synchronization value can be set to "
                    "VariableSynchronization.ON_READ only for non-trainable "
                    "variables. You have specified trainable=True and "
                    "synchronization=VariableSynchronization.ON_READ."
                )
            else:
                # Set trainable to be false when variable is to be synced on
                # read.
                trainable = False
        elif trainable is None:
            trainable = True

        def _should_add_regularizer(variable, existing_variable_set):
            if base_layer_utils.is_split_variable(variable):
                for var in variable:
                    if var in existing_variable_set:
                        return False
                return True
            else:
                return variable not in existing_variable_set

        init_graph = None
        if not tf.executing_eagerly():
            default_graph = tf.compat.v1.get_default_graph()
            if default_graph.building_function:
                with tf.init_scope():
                    # Retrieve the variables from the graph into which variables
                    # will be lifted; if initialization ops will be lifted into
                    # the eager context, then there is nothing to retrieve,
                    # since variable collections are not supported when eager
                    # execution is enabled.
                    if not tf.executing_eagerly():
                        init_graph = tf.compat.v1.get_default_graph()
                        existing_variables = set(
                            tf.compat.v1.global_variables()
                        )
            else:
                # Initialization ops will not be lifted out of the default
                # graph.
                init_graph = default_graph
                existing_variables = set(tf.compat.v1.global_variables())

        if dtype is None:
            dtype = self.dtype or tf.float32

        self._set_scope(None)
        reuse = self.built or self._reuse
        prev_len_trainable = len(self._trainable_weights)
        with tf.compat.v1.variable_scope(
            self._scope, reuse=reuse, auxiliary_name_scope=False
        ) as scope:
            self._current_scope = scope
            with backend.name_scope(self._name_scope()):
                use_resource = (
                    use_resource
                    or self._use_resource_variables
                    or scope.use_resource
                )
                if initializer is None:
                    initializer = scope.initializer
                variable = super().add_weight(
                    name,
                    shape,
                    dtype=tf.as_dtype(dtype),
                    initializer=initializer,
                    trainable=trainable and self.trainable,
                    constraint=constraint,
                    partitioner=partitioner,
                    use_resource=use_resource,
                    synchronization=synchronization,
                    aggregation=aggregation,
                    getter=tf.compat.v1.get_variable,
                    **kwargs
                )

                if regularizer:
                    if (
                        tf.compat.v1.executing_eagerly_outside_functions()
                        or _should_add_regularizer(variable, existing_variables)
                    ):
                        self._handle_weight_regularization(
                            name, variable, regularizer
                        )
                        var_store = vs._get_default_variable_store()
                        # When the shim that makes variable scopes work in TF2
                        # is used, we need to explicitly make the shim track
                        # the regularization losses, as the collections will
                        # not be accessible.
                        if hasattr(var_store, "add_regularizer"):
                            var_store.add_regularizer(variable, regularizer)

                if init_graph is not None:
                    # Handle edge case where a custom getter has overridden
                    # `trainable`.  There is one known occurrence of this, in
                    # unit test testBasicRNNCellNotTrainable in
                    # contrib.rnn.python.kernel_tests.core_rnn_cell_test
                    with init_graph.as_default():
                        trainable_variables = tf.compat.v1.trainable_variables()
                    if (
                        trainable
                        and self.trainable
                        and variable not in trainable_variables
                    ):
                        # A custom getter / variable scope overrode the
                        # trainable flag.
                        extra_trainable_vars = self._trainable_weights[
                            prev_len_trainable:
                        ]
                        self._trainable_weights = self._trainable_weights[
                            :prev_len_trainable
                        ]
                        self._non_trainable_weights += extra_trainable_vars
        return variable

    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

        Args:
          inputs: input tensor(s).
          *args: additional positional arguments to be passed to `self.call`.
          **kwargs: additional keyword arguments to be passed to `self.call`.
            **Note**: kwarg `scope` is reserved for use by the layer.

        Returns:
          Output tensor(s).

        Note:
          - If the layer's `call` method takes a `scope` keyword argument, this
            argument will be automatically set to the current variable scope.
          - If the layer's `call` method takes a `mask` argument (as some Keras
            layers do), its default value will be set to the mask generated
            for `inputs` by the previous layer (if `inputs` did come from
            a layer that generated a corresponding mask, i.e. if it came from
            a TF-Keras layer with masking support).

        Raises:
          ValueError: if the layer's `call` method returns None (an invalid
            value).
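
        Example (an illustrative sketch; `MyDense` is a hypothetical legacy
        layer subclass):

        ```python
        layer = MyDense(4)
        y1 = layer(x1)  # First call builds the layer and creates variables.
        y2 = layer(x2)  # Later calls reuse the same variables (reuse=True).
        ```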
        """
        scope = kwargs.pop("scope", None)

        if self._keras_style:
            if scope is not None:
                raise ValueError(
                    "scope argument not allowed when keras style layers are "
                    "enabled, but saw: {}".format(scope)
                )
            return super().__call__(inputs, *args, **kwargs)

        self._set_scope(scope)

        if self.built:
            try:
                # Some classes which inherit from Layer do not use its
                # constructor, so rather than initializing to None we check for
                # an AttributeError.
                scope_context_manager = self._always_reuse_variable_scope
            except AttributeError:
                scope_context_manager = None

            if scope_context_manager is None:
                # From this point we will always set reuse=True, so create a
                # "final" variable scope with this setting. We avoid re-creating
                # variable scopes after this point as an optimization.
                scope_context_manager = tf.compat.v1.variable_scope(
                    self._scope, reuse=True, auxiliary_name_scope=False
                )

                # Do not cache variable scopes if Eager mode is enabled: the
                # cached scope might be from a FuncGraph or Eager scope we
                # are no longer in.
                if not tf.compat.v1.executing_eagerly_outside_functions():
                    self._always_reuse_variable_scope = scope_context_manager
        else:
            scope_context_manager = tf.compat.v1.variable_scope(
                self._scope, reuse=self._reuse, auxiliary_name_scope=False
            )

        with scope_context_manager as scope:
            self._current_scope = scope

            try:
                call_has_scope_arg = self._call_has_scope_arg
            except AttributeError:
                self._call_spec.arg_names = variable_scope_shim.fn_args(
                    self.call
                )
                self._call_has_scope_arg = "scope" in self._call_spec.arg_names
                call_has_scope_arg = self._call_has_scope_arg
            if call_has_scope_arg:
                kwargs["scope"] = scope

            # Actually call layer
            outputs = super().__call__(inputs, *args, **kwargs)

        if not tf.executing_eagerly():
            # Update global default collections.
            _add_elements_to_collection(
                self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS
            )
        return outputs

    def __deepcopy__(self, memo):
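        """Deep-copies the layer, sharing graph/thread-local state and
        tensors, and shallow-copying cached variable scopes."""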
        no_copy = set(["_graph", "_thread_local", "_metrics_lock"])
        shallow_copy = set(["_scope", "_always_reuse_variable_scope"])
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in no_copy:
                setattr(result, k, v)
            elif k in shallow_copy:
                setattr(result, k, copy.copy(v))
            elif base_layer.is_tensor_or_tensor_list(v):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

    def __setattr__(self, name, value):
        # Bypass the automatic dependency tracking performed by the parent
        # Layer.
        super(tf.__internal__.tracking.Trackable, self).__setattr__(name, value)

    @property
    def _is_legacy_layer(self):
        """Used by keras to check compatibility. This should not be
        overridden."""
        return True


def _add_elements_to_collection(elements, collection_list):
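    """Adds `elements` to each collection in `collection_list`, skipping any
    element that is already present (graph mode only)."""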
    if tf.executing_eagerly():
        raise RuntimeError(
            "Using collections from Layers not supported in Eager "
            "mode. Tried to add %s to %s" % (elements, collection_list)
        )
    elements = tf.nest.flatten(elements)
    collection_list = tf.nest.flatten(collection_list)
    for name in collection_list:
        collection = tf.compat.v1.get_collection_ref(name)
        collection_set = {id(e) for e in collection}
        for element in elements:
            if id(element) not in collection_set:
                collection.append(element)

