# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in optimizer classes.

For more examples see the base class `tf.keras.optimizers.Optimizer`.
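
An optimizer is typically constructed directly and passed to `Model.compile`,
for example:

>>> opt = tf.keras.optimizers.SGD(learning_rate=0.01)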
"""

# Imports needed for deserialization.

import platform
import warnings

import tensorflow.compat.v2 as tf
from absl import logging

from tf_keras.src import backend
from tf_keras.src.optimizers import adadelta
from tf_keras.src.optimizers import adafactor
from tf_keras.src.optimizers import adagrad
from tf_keras.src.optimizers import adam
from tf_keras.src.optimizers import adamax
from tf_keras.src.optimizers import adamw
from tf_keras.src.optimizers import ftrl
from tf_keras.src.optimizers import lion
from tf_keras.src.optimizers import nadam
from tf_keras.src.optimizers import optimizer as base_optimizer
from tf_keras.src.optimizers import rmsprop
from tf_keras.src.optimizers import sgd
from tf_keras.src.optimizers.legacy import adadelta as adadelta_legacy
from tf_keras.src.optimizers.legacy import adagrad as adagrad_legacy
from tf_keras.src.optimizers.legacy import adam as adam_legacy
from tf_keras.src.optimizers.legacy import adamax as adamax_legacy
from tf_keras.src.optimizers.legacy import ftrl as ftrl_legacy
from tf_keras.src.optimizers.legacy import (
    gradient_descent as gradient_descent_legacy,
)
from tf_keras.src.optimizers.legacy import nadam as nadam_legacy
from tf_keras.src.optimizers.legacy import optimizer_v2 as base_optimizer_legacy
from tf_keras.src.optimizers.legacy import rmsprop as rmsprop_legacy
from tf_keras.src.optimizers.legacy.adadelta import Adadelta
from tf_keras.src.optimizers.legacy.adagrad import Adagrad
from tf_keras.src.optimizers.legacy.adam import Adam
from tf_keras.src.optimizers.legacy.adamax import Adamax
from tf_keras.src.optimizers.legacy.ftrl import Ftrl

# Symbols to be accessed under keras.optimizers. To be replaced with the
# v2022 optimizers once they graduate out of experimental.
from tf_keras.src.optimizers.legacy.gradient_descent import SGD
from tf_keras.src.optimizers.legacy.nadam import Nadam
from tf_keras.src.optimizers.legacy.rmsprop import RMSprop
from tf_keras.src.optimizers.optimizer_v1 import Optimizer
from tf_keras.src.optimizers.optimizer_v1 import TFOptimizer
from tf_keras.src.optimizers.schedules import learning_rate_schedule
from tf_keras.src.saving.legacy import serialization as legacy_serialization
from tf_keras.src.saving.serialization_lib import deserialize_keras_object
from tf_keras.src.saving.serialization_lib import serialize_keras_object

# isort: off
from tensorflow.python.util.tf_export import keras_export

# pylint: disable=line-too-long


@keras_export("keras.optimizers.serialize")
def serialize(optimizer, use_legacy_format=False):
    """Serialize the optimizer configuration to a JSON-compatible Python dict.

    The configuration can be used for persistence and to reconstruct the
    `Optimizer` instance later.

    >>> tf.keras.optimizers.serialize(tf.keras.optimizers.legacy.SGD())
    {'module': 'keras.optimizers.legacy', 'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01, 'decay': 0.0, 'momentum': 0.0, 'nesterov': False}, 'registered_name': None}

    Args:
        optimizer: An `Optimizer` instance to serialize.
        use_legacy_format: Boolean, whether to serialize with the legacy
            TF-Keras serialization format. Defaults to `False`.

    Returns:
        Python dict which contains the configuration of the input optimizer.
    """  # noqa: E501
    if optimizer is None:
        return None
    if not isinstance(
        optimizer,
        (
            base_optimizer.Optimizer,
            Optimizer,
            base_optimizer_legacy.OptimizerV2,
        ),
    ):
        warnings.warn(
            "The `keras.optimizers.serialize()` API should only be used for "
            "objects of type `keras.optimizers.Optimizer`. Found an instance "
            f"of type {type(optimizer)}, which may lead to improper "
            "serialization."
        )
    if use_legacy_format:
        return legacy_serialization.serialize_keras_object(optimizer)
    return serialize_keras_object(optimizer)


def is_arm_mac():
    return platform.system() == "Darwin" and platform.processor() == "arm"


@keras_export("keras.optimizers.deserialize")
def deserialize(config, custom_objects=None, use_legacy_format=False, **kwargs):
    """Inverse of the `serialize` function.

    Args:
        config: Optimizer configuration dictionary.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.
        use_legacy_format: Boolean, whether `config` uses the legacy TF-Keras
            serialization format (e.g. as produced by
            `serialize(..., use_legacy_format=True)`). Defaults to `False`.

    Returns:
        A TF-Keras Optimizer instance.
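
    Example (a round-trip sketch using the legacy `SGD` optimizer):

    >>> config = tf.keras.optimizers.serialize(
    ...     tf.keras.optimizers.legacy.SGD())
    >>> opt = tf.keras.optimizers.deserialize(config)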
    """
    # loss_scale_optimizer depends directly on this optimizer module, so
    # import it here rather than at the top to avoid a cyclic dependency.
    from tf_keras.src.mixed_precision import (
        loss_scale_optimizer,
    )

    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if len(config["config"]) > 0:
        # If the optimizer config is not empty, then we use the value of
        # `is_legacy_optimizer` to override `use_legacy_optimizer`. If
        # `is_legacy_optimizer` is not present in the config, it means we are
        # using the legacy optimizer.
        use_legacy_optimizer = config["config"].get("is_legacy_optimizer", True)
    if (
        tf.__internal__.tf2.enabled()
        and tf.executing_eagerly()
        and not is_arm_mac()
        and not use_legacy_optimizer
    ):
        # We observed a slowdown of the new optimizer on M1 Macs, so we fall
        # back to the legacy optimizer for M1 users for now (hence the
        # `is_arm_mac()` check above); see b/263339144 for more context.
        all_classes = {
            "adadelta": adadelta.Adadelta,
            "adagrad": adagrad.Adagrad,
            "adam": adam.Adam,
            "adamw": adamw.AdamW,
            "adamax": adamax.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam.Nadam,
            "rmsprop": rmsprop.RMSprop,
            "sgd": sgd.SGD,
            "ftrl": ftrl.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizerV3,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }
    else:
        all_classes = {
            "adadelta": adadelta_legacy.Adadelta,
            "adagrad": adagrad_legacy.Adagrad,
            "adam": adam_legacy.Adam,
            "adamax": adamax_legacy.Adamax,
            "experimentaladadelta": adadelta.Adadelta,
            "experimentaladagrad": adagrad.Adagrad,
            "experimentaladam": adam.Adam,
            "experimentalsgd": sgd.SGD,
            "nadam": nadam_legacy.Nadam,
            "rmsprop": rmsprop_legacy.RMSprop,
            "sgd": gradient_descent_legacy.SGD,
            "ftrl": ftrl_legacy.Ftrl,
            "lossscaleoptimizer": loss_scale_optimizer.LossScaleOptimizer,
            "lossscaleoptimizerv3": loss_scale_optimizer.LossScaleOptimizerV3,
            # LossScaleOptimizerV1 was an old version of LSO that was removed.
            # Deserializing it turns it into a LossScaleOptimizer
            "lossscaleoptimizerv1": loss_scale_optimizer.LossScaleOptimizer,
        }

    # Make deserialization case-insensitive for built-in optimizers.
    if config["class_name"].lower() in all_classes:
        config["class_name"] = config["class_name"].lower()

    if use_legacy_format:
        return legacy_serialization.deserialize_keras_object(
            config,
            module_objects=all_classes,
            custom_objects=custom_objects,
            printable_module_name="optimizer",
        )

    return deserialize_keras_object(
        config,
        module_objects=all_classes,
        custom_objects=custom_objects,
        printable_module_name="optimizer",
    )


@keras_export(
    "keras.__internal__.optimizers.convert_to_legacy_optimizer", v1=[]
)
def convert_to_legacy_optimizer(optimizer):
    """Convert an experimental optimizer to a legacy optimizer.

    This function takes in a `keras.optimizers.Optimizer`
    instance and converts it to the corresponding
    `keras.optimizers.legacy.Optimizer` instance.
    For example, `keras.optimizers.Adam(...)` to
    `keras.optimizers.legacy.Adam(...)`.
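
    Example (a minimal sketch):

    >>> new_adam = tf.keras.optimizers.Adam(learning_rate=0.001)
    >>> legacy_adam = (
    ...     tf.keras.__internal__.optimizers.convert_to_legacy_optimizer(
    ...         new_adam))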

    Args:
        optimizer: An instance of `keras.optimizers.Optimizer`.

    Returns:
        An instance of the corresponding legacy optimizer.
    """
    # loss_scale_optimizer depends directly on this optimizer module, so
    # import it here rather than at the top to avoid a cyclic dependency.
    from tf_keras.src.mixed_precision import (
        loss_scale_optimizer,
    )

    if not isinstance(optimizer, base_optimizer.Optimizer):
        raise ValueError(
            "`convert_to_legacy_optimizer` should only be called "
            "on instances of `tf.keras.optimizers.Optimizer`, but "
            f"received {optimizer} of type {type(optimizer)}."
        )
    optimizer_name = optimizer.__class__.__name__.lower()
    config = optimizer.get_config()
    # Remove fields that only exist in the experimental optimizer.
    keys_to_remove = [
        "weight_decay",
        "use_ema",
        "ema_momentum",
        "ema_overwrite_frequency",
        "jit_compile",
        "is_legacy_optimizer",
    ]
    for key in keys_to_remove:
        config.pop(key, None)

    if isinstance(optimizer, loss_scale_optimizer.LossScaleOptimizerV3):
        # For LossScaleOptimizers, recursively convert the inner optimizer
        config["inner_optimizer"] = convert_to_legacy_optimizer(
            optimizer.inner_optimizer
        )
        if optimizer_name == "lossscaleoptimizerv3":
            optimizer_name = "lossscaleoptimizer"

    # The learning rate can be a custom LearningRateSchedule, which is stored
    # as a dict in the config and cannot be deserialized there, so we pass the
    # schedule object through directly.
    if hasattr(optimizer, "_learning_rate") and isinstance(
        optimizer._learning_rate, learning_rate_schedule.LearningRateSchedule
    ):
        config["learning_rate"] = optimizer._learning_rate
    legacy_optimizer_config = {
        "class_name": optimizer_name,
        "config": config,
    }
    return deserialize(
        legacy_optimizer_config,
        use_legacy_optimizer=True,
        use_legacy_format=True,
    )


@keras_export("keras.optimizers.get")
def get(identifier, **kwargs):
    """Retrieves a TF-Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - String: name of an optimizer.
            - Dictionary: configuration dictionary.
            - TF-Keras Optimizer instance (it will be returned unchanged).
            - TensorFlow Optimizer instance (it will be wrapped as a TF-Keras
              Optimizer).

    Returns:
        A TF-Keras Optimizer instance.

    Raises:
        ValueError: If `identifier` cannot be interpreted.
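
    Example (illustrative; a string or config dict may resolve to either the
    new or the legacy optimizer implementation, depending on the TF version
    and platform):

    >>> opt = tf.keras.optimizers.get("adam")
    >>> opt = tf.keras.optimizers.get(
    ...     {"class_name": "SGD", "config": {"learning_rate": 0.1}})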
    """
    use_legacy_optimizer = kwargs.pop("use_legacy_optimizer", False)
    if kwargs:
        raise TypeError(f"Invalid keyword arguments: {kwargs}")
    if isinstance(
        identifier,
        (
            Optimizer,
            base_optimizer_legacy.OptimizerV2,
        ),
    ):
        return identifier
    elif isinstance(identifier, base_optimizer.Optimizer):
        if tf.__internal__.tf2.enabled():
            return identifier
        else:
            # If TF2 is disabled, we convert to the legacy
            # optimizer.
            return convert_to_legacy_optimizer(identifier)

    # Wrap legacy TF optimizer instances
    elif isinstance(identifier, tf.compat.v1.train.Optimizer):
        opt = TFOptimizer(identifier)
        backend.track_tf_optimizer(opt)
        return opt
    elif isinstance(identifier, dict):
        use_legacy_format = "module" not in identifier
        return deserialize(
            identifier,
            use_legacy_optimizer=use_legacy_optimizer,
            use_legacy_format=use_legacy_format,
        )
    elif isinstance(identifier, str):
        config = {"class_name": str(identifier), "config": {}}
        return get(
            config,
            use_legacy_optimizer=use_legacy_optimizer,
        )
    else:
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )

