# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import copy
import os

import tensorflow.compat.v2 as tf

import tf_keras.src as keras
from tf_keras.src import backend
from tf_keras.src import losses
from tf_keras.src import optimizers
from tf_keras.src.engine import base_layer_utils
from tf_keras.src.optimizers import optimizer_v1
from tf_keras.src.saving.legacy import serialization
from tf_keras.src.utils import version_utils
from tf_keras.src.utils.io_utils import ask_to_proceed_with_overwrite

# isort: off
from tensorflow.python.platform import tf_logging as logging


def extract_model_metrics(model):
    """Convert metrics from a TF-Keras model `compile` API to dictionary.

    This is used for converting TF-Keras models to Estimators and SavedModels.

    Args:
      model: A `tf.keras.Model` object.

    Returns:
      Dictionary mapping metric names to metric instances. May return `None` if
      the model does not contain any metrics.
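
    Example (a sketch; in the TF2 training flow `_compile_metrics` is
    typically unset, so this may return `None`):

    ```python
    metrics = extract_model_metrics(model)
    if metrics is not None:
        print(sorted(metrics))  # e.g. ["mean_absolute_error"]
    ```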
    """
    if getattr(model, "_compile_metrics", None):
        # TODO(psv/kathywu): use this implementation in model to estimator flow.
        # We are not using model.metrics here because we want to exclude the
        # metrics added using `add_metric` API.
        return {m.name: m for m in model._compile_metric_functions}
    return None


def model_call_inputs(model, keep_original_batch_size=False):
    """Inspect model to get its input signature.

    The model's input signature is a list with a single (possibly-nested)
    object. This is due to the Keras-enforced restriction that tensor inputs
    must be passed in as the first argument.

    For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
    will have input signature:
    [{'feature1': TensorSpec, 'feature2': TensorSpec}]

    Args:
      model: TF-Keras Model object.
      keep_original_batch_size: A boolean indicating whether we want to keep
        using the original batch size or set it to None. Default is `False`,
        which means that the batch dim of the returned input signature will
        always be set to `None`.

    Returns:
      A tuple `(args, kwargs)` of TensorSpecs matching the inputs to the
      model's call function. `kwargs` does not contain the `training`
      argument.
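
    Example (a sketch; assumes a model already built on 4-feature inputs):

    ```python
    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
    args, kwargs = model_call_inputs(model)
    # args is roughly ([TensorSpec(shape=(None, 4), dtype=tf.float32)],)
    # kwargs is {}
    ```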
    """
    input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
    if input_specs is None:
        return None, None
    input_specs = _enforce_names_consistency(input_specs)
    return input_specs


def raise_model_input_error(model):
    if isinstance(model, keras.models.Sequential):
        raise ValueError(
            f"Model {model} cannot be saved because the input shape is not "
            "available. Please specify an input shape either by calling "
            "`build(input_shape)` directly, or by calling the model on actual "
            "data using `Model()`, `Model.fit()`, or `Model.predict()`."
        )

    # If the model is not a `Sequential`, it is intended to be a subclassed
    # model.
    raise ValueError(
        f"Model {model} cannot be saved either because the input shape is not "
        "available or because the forward pass of the model is not defined."
        "To define a forward pass, please override `Model.call()`. To specify "
        "an input shape, either call `build(input_shape)` directly, or call "
        "the model on actual data using `Model()`, `Model.fit()`, or "
        "`Model.predict()`. If you have a custom training step, please make "
        "sure to invoke the forward pass in train step through "
        "`Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`."
    )


def trace_model_call(model, input_signature=None):
    """Trace the model call to create a tf.function for exporting a TF-Keras
    model.

    Args:
      model: A TF-Keras model.
      input_signature: optional, a list of tf.TensorSpec objects specifying the
        inputs to the model.

    Returns:
      A tf.function wrapping the model's call function with input signatures
      set.

    Raises:
      ValueError: if the input signature cannot be inferred from the model.
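
    Example (a sketch of typical use when exporting a SavedModel):

    ```python
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
    serving_fn = trace_model_call(model)
    tf.saved_model.save(model, "/tmp/exported_model", signatures=serving_fn)
    ```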
    """
    if input_signature is None:
        if isinstance(model.call, tf.__internal__.function.Function):
            input_signature = model.call.input_signature

    if input_signature:
        model_args = input_signature
        model_kwargs = {}
    else:
        model_args, model_kwargs = model_call_inputs(model)

        if model_args is None:
            raise_model_input_error(model)

    @tf.function
    def _wrapped_model(*args, **kwargs):
        """A concrete tf.function that wraps the model's call function."""
        args, kwargs = model._call_spec.set_arg_value(
            "training", False, args, kwargs, inputs_in_args=True
        )

        with base_layer_utils.call_context().enter(
            model, inputs=None, build_graph=False, training=False, saving=True
        ):
            outputs = model(*args, **kwargs)

        # The outputs always have to be returned as a flat dict.
        output_names = model.output_names  # Functional Model.
        if output_names is None:  # Subclassed Model.
            from tf_keras.src.engine import compile_utils

            output_names = compile_utils.create_pseudo_output_names(outputs)
        outputs = tf.nest.flatten(outputs)
        return dict(zip(output_names, outputs))

    return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)


def model_metadata(model, include_optimizer=True, require_config=True):
    """Returns a dictionary containing the model metadata."""
    from tf_keras.src import __version__ as keras_version
    from tf_keras.src.optimizers.legacy import optimizer_v2

    model_config = {"class_name": model.__class__.__name__}
    try:
        model_config["config"] = model.get_config()
    except NotImplementedError as e:
        if require_config:
            raise e

    metadata = dict(
        keras_version=str(keras_version),
        backend=backend.backend(),
        model_config=model_config,
    )
    if model.optimizer and include_optimizer:
        if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
            logging.warning(
                "TensorFlow optimizers do not make it possible to access "
                "optimizer attributes or optimizer state after "
                "instantiation. As a result, we cannot save the optimizer "
                "as part of the model save file. You will have to compile "
                "your model again after loading it. Prefer using a TF-Keras "
                "optimizer instead (see keras.io/optimizers)."
            )
        elif model._compile_was_called:
            training_config = model._get_compile_args(user_metrics=False)
            training_config.pop("optimizer", None)  # Handled separately.
            metadata["training_config"] = _serialize_nested_config(
                training_config
            )
            if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
                raise NotImplementedError(
                    "Optimizers loaded from a SavedModel cannot be saved. "
                    "If you are calling `model.save` or "
                    "`tf.keras.models.save_model`, "
                    "please set the `include_optimizer` option to `False`. For "
                    "`tf.saved_model.save`, "
                    "delete the optimizer from the model."
                )
            else:
                optimizer_config = {
                    "class_name": keras.utils.get_registered_name(
                        model.optimizer.__class__
                    ),
                    "config": model.optimizer.get_config(),
                }
            metadata["training_config"]["optimizer_config"] = optimizer_config
    return metadata


def should_overwrite(filepath, overwrite):
    """Returns whether the filepath should be overwritten."""
    # If file exists and should not be overwritten.
    if not overwrite and os.path.isfile(filepath):
        return ask_to_proceed_with_overwrite(filepath)
    return True


def compile_args_from_training_config(training_config, custom_objects=None):
    """Return model.compile arguments from training config."""
    if custom_objects is None:
        custom_objects = {}

    with keras.utils.CustomObjectScope(custom_objects):
        optimizer_config = training_config["optimizer_config"]
        optimizer = optimizers.deserialize(optimizer_config)

        # Recover losses.
        loss = None
        loss_config = training_config.get("loss", None)
        if loss_config is not None:
            loss = _deserialize_nested_config(losses.deserialize, loss_config)

        # Recover metrics.
        metrics = None
        metrics_config = training_config.get("metrics", None)
        if metrics_config is not None:
            metrics = _deserialize_nested_config(
                _deserialize_metric, metrics_config
            )

        # Recover weighted metrics.
        weighted_metrics = None
        weighted_metrics_config = training_config.get("weighted_metrics", None)
        if weighted_metrics_config is not None:
            weighted_metrics = _deserialize_nested_config(
                _deserialize_metric, weighted_metrics_config
            )

        # `training_config` is a dict, so check for the key with `.get`
        # rather than `hasattr` (which would always be False here).
        sample_weight_mode = training_config.get("sample_weight_mode", None)
        loss_weights = training_config["loss_weights"]

    return dict(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        weighted_metrics=weighted_metrics,
        loss_weights=loss_weights,
        sample_weight_mode=sample_weight_mode,
    )


def _deserialize_nested_config(deserialize_fn, config):
    """Deserializes arbitrary TF-Keras `config` using `deserialize_fn`."""

    def _is_single_object(obj):
        if isinstance(obj, dict) and "class_name" in obj:
            return True  # Serialized TF-Keras object.
        if isinstance(obj, str):
            return True  # Serialized function or string.
        return False

    if config is None:
        return None
    if _is_single_object(config):
        return deserialize_fn(config)
    elif isinstance(config, dict):
        return {
            k: _deserialize_nested_config(deserialize_fn, v)
            for k, v in config.items()
        }
    elif isinstance(config, (tuple, list)):
        return [
            _deserialize_nested_config(deserialize_fn, obj) for obj in config
        ]

    raise ValueError(
        "Saved configuration not understood. Configuration should be a "
        f"dictionary, string, tuple or list. Received: config={config}."
    )


def _serialize_nested_config(config):
    """Serialized a nested structure of TF-Keras objects."""

    def _serialize_fn(obj):
        if callable(obj):
            return serialization.serialize_keras_object(obj)
        return obj

    return tf.nest.map_structure(_serialize_fn, config)


def _deserialize_metric(metric_config):
    """Deserialize metrics, leaving special strings untouched."""
    from tf_keras.src import metrics as metrics_module

    if metric_config in ["accuracy", "acc", "crossentropy", "ce"]:
        # Do not deserialize accuracy and cross-entropy strings as we have
        # special case handling for these in compile, based on model output
        # shape.
        return metric_config
    return metrics_module.deserialize(metric_config)


def _enforce_names_consistency(specs):
    """Enforces that either all specs have names or none do."""

    def _has_name(spec):
        return spec is None or (hasattr(spec, "name") and spec.name is not None)

    def _clear_name(spec):
        spec = copy.deepcopy(spec)
        if hasattr(spec, "name"):
            spec._name = None
        return spec

    flat_specs = tf.nest.flatten(specs)
    name_inconsistency = any(_has_name(s) for s in flat_specs) and not all(
        _has_name(s) for s in flat_specs
    )

    if name_inconsistency:
        specs = tf.nest.map_structure(_clear_name, specs)
    return specs


def try_build_compiled_arguments(model):
    """Builds the model's compiled loss and metrics, if not yet built."""
    if (
        not version_utils.is_v1_layer_or_model(model)
        and model.outputs is not None
    ):
        try:
            if not model.compiled_loss.built:
                model.compiled_loss.build(model.outputs)
            if not model.compiled_metrics.built:
                model.compiled_metrics.build(model.outputs, model.outputs)
        except Exception:
            logging.warning(
                "Compiled the loaded model, but the compiled metrics have "
                "yet to be built. `model.compiled_metrics` will be empty "
                "until you train or evaluate the model."
            )


def is_hdf5_filepath(filepath):
    """Checks whether `filepath` points to an HDF5-format save file.

    Note that in this legacy saving code a `.keras` extension is also
    treated as HDF5.
    """
    return filepath.endswith((".h5", ".keras", ".hdf5"))

