# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities related to layer/model functionality."""

import copy
import functools
import re
import weakref

import numpy as np
import tensorflow.compat.v2 as tf

from tf_keras.src import initializers
from tf_keras.src.utils import io_utils

# isort: off
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.utils.get_source_inputs")
def get_source_inputs(tensor, layer=None, node_index=None):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will always be a list of tensors
    (potentially with 1 element).

    Args:
        tensor: The tensor to start from.
        layer: Origin layer of the tensor. Will be
            determined via tensor._keras_history if not provided.
        node_index: Origin node index of the tensor.

    Returns:
        List of input tensors.
    """
    if not hasattr(tensor, "_keras_history"):
        return tensor

    if layer is None or node_index:
        layer, node_index, _ = tensor._keras_history
    if not layer._inbound_nodes:
        return [tensor]
    else:
        node = layer._inbound_nodes[node_index]
        if node.is_input:
            # Reached an Input layer, stop recursion.
            return tf.nest.flatten(node.input_tensors)
        else:
            source_tensors = []
            for layer, node_index, _, tensor in node.iterate_inbound():
                previous_sources = get_source_inputs(tensor, layer, node_index)
                # Avoid input redundancy.
                for x in previous_sources:
                    if all(x is not t for t in source_tensors):
                        source_tensors.append(x)
            return source_tensors


def validate_string_arg(
    input_data,
    allowable_strings,
    layer_name,
    arg_name,
    allow_none=False,
    allow_callables=False,
):
    """Validates the correctness of a string-based arg."""
    if allow_none and input_data is None:
        return
    elif allow_callables and callable(input_data):
        return
    elif isinstance(input_data, str) and input_data in allowable_strings:
        return
    else:
        allowed_args = "`None`, " if allow_none else ""
        allowed_args += "a `Callable`, " if allow_callables else ""
        allowed_args += f"or one of the following values: {allowable_strings}"
        if allow_callables:
            callable_note = (
                f"If restoring a model and `{arg_name}` is a custom callable, "
                "please ensure the callable is registered as a custom object. "
                "See "
                "https://www.tensorflow.org/guide/keras/save_and_serialize"
                "#registering_the_custom_object for details. "
            )
        else:
            callable_note = ""
        raise ValueError(
            f"Unkown value for `{arg_name}` argument of layer {layer_name}. "
            f"{callable_note}Allowed values are: {allowed_args}. Received: "
            f"{input_data}"
        )


def count_params(weights):
    """Count the total number of scalars composing the weights.

    Args:
        weights: An iterable containing the weights on which to compute params

    Returns:
        The total number of scalars composing the weights
    """
    unique_weights = {id(w): w for w in weights}.values()
    # Ignore TrackableWeightHandlers, which will not have a shape defined.
    unique_weights = [w for w in unique_weights if hasattr(w, "shape")]
    weight_shapes = [w.shape.as_list() for w in unique_weights]
    standardized_weight_shapes = [
        [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
    ]
    return int(sum(np.prod(p) for p in standardized_weight_shapes))


def weight_memory_size(weights):
    """Calculate the memory footprint for weights based on their dtypes.

    Args:
        weights: An iterable contains the weights to compute weight size.

    Returns:
        The total memory size (in Bytes) of the weights.
    """
    unique_weights = {id(w): w for w in weights}.values()

    total_memory_size = 0
    for w in unique_weights:
        # Ignore TrackableWeightHandlers, which will not have a shape defined.
        if not hasattr(w, "shape"):
            continue
        elif None in w.shape.as_list():
            continue
        weight_shape = np.prod(w.shape.as_list())
        per_param_size = w.dtype.size
        total_memory_size += weight_shape * per_param_size
    return total_memory_size


def dtensor_variable_summary(weights):
    """Group and calculate DTensor based weights memory size.

    Since DTensor weights can be sharded across multiple device, the result
    will be grouped by the layout/sharding spec for the variables, so that
    the accurate per-device memory size can be calculated.

    Args:
        weights: An iterable contains the weights to compute weight size.

    Returns:
        total_weight_count, total_memory_size and per_sharing_spec_result which
        is a dict with normalized layout spec as key and tuple of weight count
        and weight size as value.
    """
    unique_weights = {id(w): w for w in weights}.values()
    total_weight_count = 0
    total_memory_size = 0
    per_sharing_spec_result = {}
    for w in unique_weights:
        # Ignore TrackableWeightHandlers, which will not have a shape defined.
        if not hasattr(w, "shape"):
            continue
        if not isinstance(w, tf.experimental.dtensor.DVariable):
            continue
        layout = w.layout
        # Remove all the duplication axis, and sort the column name.
        # 1D replicated and 2D replicated variable will still be fully
        # replicated, and [batch, model] sharding will have same memory
        # footprint as the [model, batch] layout.
        reduced_sharding_spec = list(sorted(set(layout.sharding_specs)))
        if tf.experimental.dtensor.UNSHARDED in reduced_sharding_spec:
            reduced_sharding_spec.remove(tf.experimental.dtensor.UNSHARDED)
        reduced_sharding_spec = tuple(reduced_sharding_spec)  # For dict key
        weight_count, memory_size = per_sharing_spec_result.get(
            reduced_sharding_spec, (0, 0)
        )
        reduced_weight_shape = np.prod(w.shape.as_list())
        per_param_size = w.dtype.size
        weight_count += reduced_weight_shape
        memory_size += reduced_weight_shape * per_param_size
        per_sharing_spec_result[reduced_sharding_spec] = (
            weight_count,
            memory_size,
        )
        total_weight_count += reduced_weight_shape
        total_memory_size += reduced_weight_shape * per_param_size
    return total_weight_count, total_memory_size, per_sharing_spec_result


def print_dtensor_variable_summary(model, print_fn, line_length):
    if getattr(model, "_layout_map", None) is not None:
        mesh = model._layout_map.get_default_mesh()
    elif hasattr(model, "distribute_strategy") and hasattr(
        model.distribute_strategy, "_mesh"
    ):
        mesh = model.distribute_strategy._mesh
    else:
        # Not running with DTensor
        mesh = None
    if mesh:
        (
            total_weight_count,
            total_memory_size,
            per_sharing_spec_result,
        ) = dtensor_variable_summary(model.weights)
        total_per_device_memory_size = 0
        for sharding_spec in sorted(per_sharing_spec_result.keys()):
            count, memory_size = per_sharing_spec_result[sharding_spec]
            if len(sharding_spec) == 0:
                print_fn(
                    f"{count} / {total_weight_count} params "
                    f"({readable_memory_size(memory_size)}) "
                    "are fully replicated"
                )
                per_device_size = memory_size
            else:
                sharding_factor = np.prod(
                    [mesh.dim_size(s) for s in sharding_spec]
                )
                per_device_size = memory_size / sharding_factor
                print_fn(
                    f"{count} / {total_weight_count} params "
                    f"({readable_memory_size(memory_size)}) are sharded based "
                    f"on spec '{sharding_spec}' and across {sharding_factor} "
                    f"devices."
                )
            total_per_device_memory_size += per_device_size
        print_fn(
            "Overall per device memory usage: "
            f"{readable_memory_size(total_per_device_memory_size)}"
        )
        print_fn(
            "Overall sharding factor: {:.2f}".format(
                total_memory_size / total_per_device_memory_size
            )
        )
        print_fn("_" * line_length)


def readable_memory_size(weight_memory_size):
    """Convert the weight memory size (Bytes) to a readable string."""
    units = ["Byte", "KB", "MB", "GB", "TB", "PB"]
    scale = 1024
    for unit in units:
        if weight_memory_size / scale < 1:
            return "{:.2f} {}".format(weight_memory_size, unit)
        else:
            weight_memory_size /= scale
    return "{:.2f} {}".format(weight_memory_size, units[-1])


def get_layer_index_bound_by_layer_name(model, layer_range=None):
    """Get the layer indexes from the model based on layer names.

    The layer indexes can be used to slice the model into sub models for
    display.

    Args:
        model: `tf.keras.Model` instance.
        layer_names: a list or tuple of 2 strings, the starting layer name and
            ending layer name (both inclusive) for the result. All layers will
            be included when `None` is provided.

    Returns:
        The index value of layer based on its unique name (layer_names).
        Output will be [first_layer_index, last_layer_index + 1].
    """
    if layer_range is not None:
        if len(layer_range) != 2:
            raise ValueError(
                "layer_range must be a list or tuple of length 2. Received: "
                f"layer_range = {layer_range} of length {len(layer_range)}"
            )
        if not isinstance(layer_range[0], str) or not isinstance(
            layer_range[1], str
        ):
            raise ValueError(
                "layer_range should contain string type only. "
                f"Received: {layer_range}"
            )
    else:
        return [0, len(model.layers)]

    lower_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[0], layer.name)
    ]
    upper_index = [
        idx
        for idx, layer in enumerate(model.layers)
        if re.match(layer_range[1], layer.name)
    ]

    if not lower_index or not upper_index:
        raise ValueError(
            "Passed layer_names do not match the layer names in the model. "
            f"Received: {layer_range}"
        )

    if min(lower_index) > max(upper_index):
        return [min(upper_index), max(lower_index) + 1]
    return [min(lower_index), max(upper_index) + 1]


def print_summary(
    model,
    line_length=None,
    positions=None,
    print_fn=None,
    expand_nested=False,
    show_trainable=False,
    layer_range=None,
):
    """Prints a summary of a model.

    Args:
        model: TF-Keras model instance.
        line_length: Total length of printed lines
            (e.g. set this to adapt the display to different
            terminal window sizes).
        positions: Relative or absolute positions of log elements in each line.
            If not provided, defaults to `[0.3, 0.6, 0.70, 1.]`.
        print_fn: Print function to use.
            It will be called on each line of the summary.
            You can set it to a custom function
            in order to capture the string summary.
            When `None`, uses `print` (prints to stdout).
            Defaults to `None`.
        expand_nested: Whether to expand the nested models.
            Defaults to `False`.
        show_trainable: Whether to show if a layer is trainable.
            Defaults to `False`.
        layer_range: List or tuple containing two strings,
            the starting layer name and ending layer name (both inclusive),
            indicating the range of layers to be printed in the summary. The
            strings could also be regexes instead of an exact name. In this
             case, the starting layer will be the first layer that matches
            `layer_range[0]` and the ending layer will be the last element that
            matches `layer_range[1]`. By default (`None`) all
            layers in the model are included in the summary.
    """
    if print_fn is None:
        print_fn = io_utils.print_msg

    if model.__class__.__name__ == "Sequential":
        sequential_like = True
    elif not model._is_graph_network:
        # We treat subclassed models as a simple sequence of layers, for logging
        # purposes.
        sequential_like = True
    else:
        sequential_like = True
        nodes_by_depth = model._nodes_by_depth.values()
        nodes = []
        for v in nodes_by_depth:
            if (len(v) > 1) or (
                len(v) == 1 and len(tf.nest.flatten(v[0].keras_inputs)) > 1
            ):
                # if the model has multiple nodes
                # or if the nodes have multiple inbound_layers
                # the model is no longer sequential
                sequential_like = False
                break
            nodes += v
        if sequential_like:
            # search for shared layers
            for layer in model.layers:
                flag = False
                for node in layer._inbound_nodes:
                    if node in nodes:
                        if flag:
                            sequential_like = False
                            break
                        else:
                            flag = True
                if not sequential_like:
                    break

    if sequential_like:
        line_length = line_length or 65
        positions = positions or [0.45, 0.85, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # header names for the different log elements
        to_display = ["Layer (type)", "Output Shape", "Param #"]
    else:
        line_length = line_length or 98
        positions = positions or [0.3, 0.6, 0.70, 1.0]
        if positions[-1] <= 1:
            positions = [int(line_length * p) for p in positions]
        # header names for the different log elements
        to_display = ["Layer (type)", "Output Shape", "Param #", "Connected to"]
        relevant_nodes = []
        for v in model._nodes_by_depth.values():
            relevant_nodes += v

    if show_trainable:
        line_length += 11
        positions.append(line_length)
        to_display.append("Trainable")

    layer_range = get_layer_index_bound_by_layer_name(model, layer_range)

    def print_row(fields, positions, nested_level=0):
        left_to_print = [str(x) for x in fields]
        while any(left_to_print):
            line = ""
            for col in range(len(left_to_print)):
                if col > 0:
                    start_pos = positions[col - 1]
                else:
                    start_pos = 0
                end_pos = positions[col]
                # Leave room for 2 spaces to delineate columns
                # we don't need any if we are printing the last column
                space = 2 if col != len(positions) - 1 else 0
                cutoff = end_pos - start_pos - space
                # Except for last col, offset by one to align the start of col
                if col != len(positions) - 1:
                    cutoff -= 1
                if col == 0:
                    cutoff -= nested_level
                fit_into_line = left_to_print[col][:cutoff]
                # For nicer formatting we line-break on seeing end of
                # tuple/dict etc.
                line_break_conditions = ("),", "},", "],", "',")
                candidate_cutoffs = [
                    fit_into_line.find(x) + len(x)
                    for x in line_break_conditions
                    if fit_into_line.find(x) >= 0
                ]
                if candidate_cutoffs:
                    cutoff = min(candidate_cutoffs)
                    fit_into_line = fit_into_line[:cutoff]

                if col == 0:
                    line += "|" * nested_level + " "
                line += fit_into_line
                line += " " * space if space else ""
                left_to_print[col] = left_to_print[col][cutoff:]

                # Pad out to the next position
                # Make space for nested_level for last column
                if nested_level and col == len(positions) - 1:
                    line += " " * (positions[col] - len(line) - nested_level)
                else:
                    line += " " * (positions[col] - len(line))
            line += "|" * nested_level
            print_fn(line)

    print_fn(f'Model: "{model.name}"')
    print_fn("_" * line_length)
    print_row(to_display, positions)
    print_fn("=" * line_length)

    def print_layer_summary(layer, nested_level=0):
        """Prints a summary for a single layer.

        Args:
            layer: target layer.
            nested_level: level of nesting of the layer inside its parent layer
              (e.g. 0 for a top-level layer, 1 for a nested layer).
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        except RuntimeError:  # output_shape unknown in Eager mode.
            output_shape = "?"
        name = layer.name
        cls_name = layer.__class__.__name__
        if not layer.built and not getattr(layer, "_is_graph_network", False):
            # If a subclassed model has a layer that is not called in
            # Model.call, the layer will not be built and we cannot call
            # layer.count_params().
            params = "0 (unused)"
        else:
            params = layer.count_params()
        fields = [name + " (" + cls_name + ")", output_shape, params]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")

        print_row(fields, positions, nested_level)

    def print_layer_summary_with_connections(layer, nested_level=0):
        """Prints a summary for a single layer (including its connections).

        Args:
            layer: target layer.
            nested_level: level of nesting of the layer inside its parent layer
              (e.g. 0 for a top-level layer, 1 for a nested layer).
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        connections = []
        for node in layer._inbound_nodes:
            if relevant_nodes and node not in relevant_nodes:
                # node is not part of the current network
                continue

            for (
                inbound_layer,
                node_index,
                tensor_index,
                _,
            ) in node.iterate_inbound():
                connections.append(
                    f"{inbound_layer.name}[{node_index}][{tensor_index}]"
                )

        name = layer.name
        cls_name = layer.__class__.__name__
        fields = [
            name + " (" + cls_name + ")",
            output_shape,
            layer.count_params(),
            connections,
        ]

        if show_trainable:
            fields.append("Y" if layer.trainable else "N")

        print_row(fields, positions, nested_level)

    def print_layer(layer, nested_level=0, is_nested_last=False):
        if sequential_like:
            print_layer_summary(layer, nested_level)
        else:
            print_layer_summary_with_connections(layer, nested_level)

        if expand_nested and hasattr(layer, "layers") and layer.layers:
            print_fn(
                "|" * (nested_level + 1)
                + "¯" * (line_length - 2 * nested_level - 2)
                + "|" * (nested_level + 1)
            )

            nested_layer = layer.layers
            is_nested_last = False
            for i in range(len(nested_layer)):
                if i == len(nested_layer) - 1:
                    is_nested_last = True
                print_layer(nested_layer[i], nested_level + 1, is_nested_last)

            print_fn(
                "|" * nested_level
                + "¯" * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

        if not is_nested_last:
            print_fn(
                "|" * nested_level
                + " " * (line_length - 2 * nested_level)
                + "|" * nested_level
            )

    for layer in model.layers[layer_range[0] : layer_range[1]]:
        print_layer(layer)
    print_fn("=" * line_length)

    if hasattr(model, "_collected_trainable_weights"):
        trainable_count = count_params(model._collected_trainable_weights)
        trainable_memory_size = weight_memory_size(
            model._collected_trainable_weights
        )
    else:
        trainable_count = count_params(model.trainable_weights)
        trainable_memory_size = weight_memory_size(model.trainable_weights)

    non_trainable_count = count_params(model.non_trainable_weights)
    non_trainable_memory_size = weight_memory_size(model.non_trainable_weights)

    total_memory_size = trainable_memory_size + non_trainable_memory_size

    print_fn(
        f"Total params: {trainable_count + non_trainable_count} "
        f"({readable_memory_size(total_memory_size)})"
    )
    print_fn(
        f"Trainable params: {trainable_count} "
        f"({readable_memory_size(trainable_memory_size)})"
    )
    print_fn(
        f"Non-trainable params: {non_trainable_count} "
        f"({readable_memory_size(non_trainable_memory_size)})"
    )
    print_fn("_" * line_length)

    print_dtensor_variable_summary(model, print_fn, line_length)


def convert_dense_weights_data_format(
    dense, previous_feature_map_shape, target_data_format="channels_first"
):
    """Utility useful when changing a convnet's `data_format`.

    When porting the weights of a convnet from one data format to the other,
    if the convnet includes a `Flatten` layer
    (applied to the last convolutional feature map)
    followed by a `Dense` layer, the weights of that `Dense` layer
    should be updated to reflect the new dimension ordering.

    Args:
        dense: The target `Dense` layer.
        previous_feature_map_shape: A shape tuple of 3 integers,
            e.g. `(512, 7, 7)`. The shape of the convolutional
            feature map right before the `Flatten` layer that
            came before the target `Dense` layer.
        target_data_format: One of "channels_last", "channels_first".
            Set it "channels_last"
            if converting a "channels_first" model to "channels_last",
            or reciprocally.
    """
    assert target_data_format in {"channels_last", "channels_first"}
    kernel, bias = dense.get_weights()
    for i in range(kernel.shape[1]):
        if target_data_format == "channels_first":
            c, h, w = previous_feature_map_shape
            original_fm_shape = (h, w, c)
            ki = kernel[:, i].reshape(original_fm_shape)
            ki = np.transpose(ki, (2, 0, 1))  # last -> first
        else:
            h, w, c = previous_feature_map_shape
            original_fm_shape = (c, h, w)
            ki = kernel[:, i].reshape(original_fm_shape)
            ki = np.transpose(ki, (1, 2, 0))  # first -> last
        kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
    dense.set_weights([kernel, bias])


def is_builtin_layer(layer):
    if not getattr(layer, "_keras_api_names", None):
        return False

    # Subclasses of `Layer` that are not exported inherit the export name
    # of the base layer class.
    return layer._keras_api_names != (
        "keras.layers.Layer",
    ) and layer._keras_api_names_v1 != ("keras.layers.Layer",)


def cached_per_instance(f):
    """Lightweight decorator for caching lazily constructed properties.

    When to use:
    This decorator provides simple caching with minimal overhead. It is designed
    for properties which are expensive to compute and static over the life of a
    class instance, and provides no mechanism for cache invalidation. Thus it is
    best suited for lazily exposing derived properties of other static data.

    For classes with custom getattr / setattr behavior (such as trackable
    objects), storing cache results as object attributes is not performant.
    Instead, a specialized cache can significantly reduce property lookup
    overhead. (While still allowing the decorated property to be lazily
    computed.) Consider the following class:

    ```
    class MyClass:
      def __setattr__(self, key, value):
        # Some expensive class specific code
        # ...
        # ...

        super(MyClass, self).__setattr__(key, value)

      @property
      def thing(self):
        # `thing` is expensive to compute (and may not even be requested), so we
        # want to lazily compute it and then cache it.
        output = getattr(self, '_thing', None)
        if output is None:
          self._thing = output = compute_thing(self)
        return output
    ```

    It's also worth noting that ANY overriding of __setattr__, even something as
    simple as:
    ```
      def __setattr__(self, key, value):
        super(MyClass, self).__setattr__(key, value)
    ```

    Slows down attribute assignment by nearly 10x.

    By contrast, replacing the definition of `thing` with the following
    sidesteps the expensive __setattr__ altogether:

    '''
    @property
    @tracking.cached_per_instance
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      return compute_thing(self)
    '''

    Performance:
    The overhead for this decorator is ~0.4 us / call. A much lower overhead
    implementation (~0.085 us / call) can be achieved by using a custom dict
    type:

    ```
    def dict_based_cache(f):
      class Cache(dict):
        __slots__ = ()
        def __missing__(self, key):
          self[key] = output = f(key)
          return output

      return property(Cache().__getitem__)
    ```

    However, that implementation holds class instances as keys, and as a result
    blocks garbage collection. (And modifying it to use weakref's as keys raises
    the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
    implementation below turns out to be more prudent.

    Args:
      f: The function to cache.

    Returns:
      f decorated with simple caching behavior.
    """

    cache = weakref.WeakKeyDictionary()

    @functools.wraps(f)
    def wrapped(item):
        output = cache.get(item)
        if output is None:
            cache[item] = output = f(item)
        return output

    wrapped.cache = cache
    return wrapped


def filter_empty_layer_containers(layer_list):
    """Filter out empty Layer-like containers and uniquify."""
    # TODO(b/130381733): Make this an attribute in base_layer.Layer.
    existing = set()
    to_visit = layer_list[::-1]
    while to_visit:
        obj = to_visit.pop()
        if id(obj) in existing:
            continue
        existing.add(id(obj))
        if hasattr(obj, "_is_layer") and not isinstance(obj, type):
            yield obj
        else:
            sub_layers = getattr(obj, "layers", None) or []

            # Trackable data structures will not show up in ".layers" lists, but
            # the layers they contain will.
            to_visit.extend(sub_layers[::-1])


class CallFunctionSpec:
    """Caches the spec and provides utilities for handling call function
    args."""

    def __init__(self, full_argspec):
        """Initialies a `CallFunctionSpec`.

        Args:
          full_argspec: the FullArgSpec of a call function of a layer.
        """
        self._full_argspec = full_argspec

        self._arg_names = list(self._full_argspec.args)
        # Scrub `self` that appears if a decorator was applied.
        if self._arg_names and self._arg_names[0] == "self":
            self._arg_names = self._arg_names[1:]
        self._arg_names += self._full_argspec.kwonlyargs or []

        call_accepts_kwargs = self._full_argspec.varkw is not None
        self._expects_training_arg = (
            "training" in self._arg_names or call_accepts_kwargs
        )
        self._expects_mask_arg = (
            "mask" in self._arg_names or call_accepts_kwargs
        )

        call_fn_defaults = self._full_argspec.defaults or []
        defaults = dict()
        # The call arg defaults are an n-tuple of the last n elements of the
        # args list. (n = # of elements that have a default argument)
        for i in range(-1 * len(call_fn_defaults), 0):
            defaults[self._arg_names[i]] = call_fn_defaults[i]
        # The default training arg will be any (non-None) default specified in
        # the method signature, or None if no value is specified.
        defaults.update(self._full_argspec.kwonlydefaults or {})
        self._default_training_arg = defaults.get("training")

    @property
    def full_argspec(self):
        """Returns the FullArgSpec of the call function."""
        return self._full_argspec

    @property
    def arg_names(self):
        """List of names of args and kwonlyargs."""
        # `arg_names` is not accurate if the layer has variable positional args.
        return self._arg_names

    @arg_names.setter
    def arg_names(self, value):
        self._arg_names = value

    @property
    @cached_per_instance
    def arg_positions(self):
        """Returns a dict mapping arg names to their index positions."""
        # `arg_positions` is not accurate if the layer has variable positional
        # args.
        call_fn_arg_positions = dict()
        for pos, arg in enumerate(self._arg_names):
            call_fn_arg_positions[arg] = pos
        return call_fn_arg_positions

    @property
    def expects_training_arg(self):
        """Whether the call function uses 'training' as a parameter."""
        return self._expects_training_arg

    @expects_training_arg.setter
    def expects_training_arg(self, value):
        self._expects_training_arg = value

    @property
    def expects_mask_arg(self):
        """Whether the call function uses `mask` as a parameter."""
        return self._expects_mask_arg

    @expects_mask_arg.setter
    def expects_mask_arg(self, value):
        self._expects_mask_arg = value

    @property
    def default_training_arg(self):
        """The default value given to the "training" argument."""
        return self._default_training_arg

    def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
        """Returns true if argument is present in `args` or `kwargs`.

        Args:
          arg_name: String name of the argument to find.
          args: Tuple of args passed to the call function.
          kwargs: Dictionary of kwargs  passed to the call function.
          inputs_in_args: Whether the input argument (the first argument in the
            call function) is included in `args`. Defaults to `False`.

        Returns:
          True if argument with `arg_name` is present in `args` or `kwargs`.
        """
        # Performance optimization: do no work in most common case.
        if not args and not kwargs:
            return False

        if arg_name in kwargs:
            return True
        call_fn_args = self._arg_names
        if not inputs_in_args:
            # Ignore `inputs` arg.
            call_fn_args = call_fn_args[1:]
        return arg_name in dict(zip(call_fn_args, args))

    def get_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
        """Retrieves the value for the argument with name `arg_name`.

        Args:
          arg_name: String name of the argument to find.
          args: Tuple of args passed to the call function.
          kwargs: Dictionary of kwargs  passed to the call function.
          inputs_in_args: Whether the input argument (the first argument in the
            call function) is included in `args`. Defaults to `False`.

        Returns:
          The value of the argument with name `arg_name`, extracted from `args`
          or `kwargs`.

        Raises:
          KeyError if the value of `arg_name` cannot be found.
        """
        if arg_name in kwargs:
            return kwargs[arg_name]
        call_fn_args = self._arg_names
        if not inputs_in_args:
            # Ignore `inputs` arg.
            call_fn_args = call_fn_args[1:]
        args_dict = dict(zip(call_fn_args, args))
        return args_dict[arg_name]

    def set_arg_value(
        self,
        arg_name,
        new_value,
        args,
        kwargs,
        inputs_in_args=False,
        pop_kwarg_if_none=False,
    ):
        """Sets the value of an argument into the given args/kwargs.

        Args:
          arg_name: String name of the argument to find.
          new_value: New value to give to the argument.
          args: Tuple of args passed to the call function.
          kwargs: Dictionary of kwargs  passed to the call function.
          inputs_in_args: Whether the input argument (the first argument in the
            call function) is included in `args`. Defaults to `False`.
          pop_kwarg_if_none: If the new value is `None`, and this is `True`,
            then the argument is deleted from `kwargs`.

        Returns:
          The updated `(args, kwargs)`.
        """
        if self.full_argspec.varargs:
            try:
                arg_pos = self.full_argspec.args.index(arg_name)
                if self.full_argspec.args[0] == "self":
                    arg_pos -= 1
            except ValueError:
                arg_pos = None
        else:
            arg_pos = self.arg_positions.get(arg_name, None)

        if arg_pos is not None:
            if not inputs_in_args:
                # Ignore `inputs` arg.
                arg_pos = arg_pos - 1
            if len(args) > arg_pos:
                args = list(args)
                args[arg_pos] = new_value
                return tuple(args), kwargs
        if new_value is None and pop_kwarg_if_none:
            kwargs.pop(arg_name, None)
        else:
            kwargs[arg_name] = new_value
        return args, kwargs

    def split_out_first_arg(self, args, kwargs):
        """Splits (args, kwargs) into (inputs, args, kwargs)."""
        # Grab the argument corresponding to the first argument in the
        # layer's `call` method spec. This will either be the first positional
        # argument, or it will be provided as a keyword argument.
        if args:
            inputs = args[0]
            args = args[1:]
        elif self._arg_names[0] in kwargs:
            kwargs = copy.copy(kwargs)
            inputs = kwargs.pop(self._arg_names[0])
        else:
            raise ValueError(
                "The first argument to `Layer.call` must always be passed."
            )
        return inputs, args, kwargs


@keras_export("keras.utils.warmstart_embedding_matrix")
def warmstart_embedding_matrix(
    base_vocabulary,
    new_vocabulary,
    base_embeddings,
    new_embeddings_initializer="uniform",
):
    """Warm start embedding matrix with changing vocab.

    This util can be used to warmstart the embedding layer matrix when
    vocabulary changes between previously saved checkpoint and model.
    Vocabulary change could mean, the size of the new vocab is different or the
    vocabulary is reshuffled or new vocabulary has been added to old vocabulary.
    If the vocabulary size changes, size of the embedding layer matrix also
    changes. This util remaps the old vocabulary embeddings to the new embedding
    layer matrix.

    Example:
    Here is an example that demonstrates how to use the
    `warmstart_embedding_matrix` util.
    >>> import tf_keras.src as keras
    >>> vocab_base = tf.convert_to_tensor(["unk", "a", "b", "c"])
    >>> vocab_new = tf.convert_to_tensor(
    ...        ["unk", "unk", "a", "b", "c", "d", "e"])
    >>> vectorized_vocab_base = np.random.rand(vocab_base.shape[0], 3)
    >>> vectorized_vocab_new = np.random.rand(vocab_new.shape[0], 3)
    >>> warmstarted_embedding_matrix = warmstart_embedding_matrix(
    ...       base_vocabulary=vocab_base,
    ...       new_vocabulary=vocab_new,
    ...       base_embeddings=vectorized_vocab_base,
    ...       new_embeddings_initializer=keras.initializers.Constant(
    ...         vectorized_vocab_new))

    Here is an example that demonstrates how to get vocabulary and embedding
    weights from layers, use the `warmstart_embedding_matrix` util to remap the
    layer embeddings and continue with model training.
    ```
    # get old and new vocabulary by using layer.get_vocabulary()
    # for example assume TextVectorization layer is used
    base_vocabulary = old_text_vectorization_layer.get_vocabulary()
    new_vocabulary = new_text_vectorization_layer.get_vocabulary()
    # get previous embedding layer weights
    embedding_weights_base = model.get_layer('embedding').get_weights()[0]
    warmstarted_embedding = keras.utils.warmstart_embedding_matrix(
                                  base_vocabulary,
                                  new_vocabulary,
                                  base_embeddings=embedding_weights_base,
                                  new_embeddings_initializer="uniform")
    updated_embedding_variable = tf.Variable(warmstarted_embedding)

    # update embedding layer weights
    model.layers[1].embeddings = updated_embedding_variable
    model.fit(..)
    # continue with model training

    ```

    Args:
        base_vocabulary: The list of vocabulary terms that
          the preexisting embedding matrix `base_embeddings` represents.
          It can be either a 1D array/tensor or a tuple/list of vocabulary
          terms (strings), or a path to a vocabulary text file. If passing a
           file path, the file should contain one line per term in the
           vocabulary.
        new_vocabulary: The list of vocabulary terms for the new vocabulary
           (same format as above).
        base_embeddings: NumPy array or tensor representing the preexisting
          embedding matrix.
        new_embeddings_initializer: Initializer for embedding vectors for
          previously unseen terms to be added to the new embedding matrix (see
          `keras.initializers`). new_embedding matrix
          needs to be specified with "constant" initializer.
          matrix. None means "uniform". Default value is None.

    Returns:
      tf.tensor of remapped embedding layer matrix

    """
    # convert vocab to list
    base_vocabulary = convert_vocab_to_list(base_vocabulary)
    new_vocabulary = convert_vocab_to_list(new_vocabulary)

    # Initialize the new embedding layer matrix
    new_embeddings_initializer = initializers.get(new_embeddings_initializer)
    new_embedding = new_embeddings_initializer(
        shape=(len(new_vocabulary), base_embeddings.shape[1]),
        dtype=base_embeddings.dtype,
    )

    # create mapping dict {vocab:index}
    base_vocabulary_dict = dict(
        zip(base_vocabulary, range(len(base_vocabulary)))
    )

    indices_base_vocabulary = []
    indices_new_vocabulary = []
    for index, key in enumerate(new_vocabulary):
        if key in base_vocabulary_dict:
            indices_base_vocabulary.append(base_vocabulary_dict[key])
            indices_new_vocabulary.append(int(index))

    # update embedding matrix
    if indices_base_vocabulary:
        values_to_update = tf.gather(base_embeddings, indices_base_vocabulary)
        new_embedding = tf.tensor_scatter_nd_update(
            new_embedding,
            tf.expand_dims(indices_new_vocabulary, axis=1),
            values_to_update,
        )
    return new_embedding


def convert_vocab_to_list(vocab):
    """Convert input vacabulary to list."""
    vocab_list = []
    if tf.is_tensor(vocab):
        vocab_list = list(vocab.numpy())
    elif isinstance(vocab, (np.ndarray, tuple, list)):
        vocab_list = list(vocab)
    elif isinstance(vocab, str):
        if not tf.io.gfile.exists(vocab):
            raise ValueError(f"Vocabulary file {vocab} does not exist.")
        with tf.io.gfile.GFile(vocab, "r") as vocabulary_file:
            vocab_list = vocabulary_file.read().splitlines()
    else:
        raise ValueError(
            "Vocabulary is expected to be either a NumPy array, "
            "list, 1D tensor or a vocabulary text file. Instead type "
            f"{type(vocab)} was received."
        )
    if len(vocab_list) == 0:
        raise ValueError(
            "Vocabulary is expected to be either a NumPy array, "
            "list, 1D tensor or a vocabulary text file with at least one token."
            " Received 0 instead."
        )
    return vocab_list

