# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Autograph specifc overrides for tf.data.ops."""
import functools

import numpy as np

from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import cond
from tensorflow.python.util import nest


# TODO(mdan): These checks should be easier. Fix the nest API.
def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the dtypes
  are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Python scalars, strings and NumPy arrays are converted to tensors before
  their dtype is compared. Any other value is compared through its `dtype`
  attribute; a value without one is reported as having "no dtype".

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify.
    spec: TensorSpec that `input_` must be compatible with.

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the dtype of `input_` differs from `spec.dtype`.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  # Check the argument `input_`, not the `input` builtin: the builtin is never
  # None, which would make this guard unreachable.
  if input_ is None:
    # TODO(mdan): raise from None when switching to Py3.
    raise ValueError("{} cannot be None".format(input_name))

  # TODO(mdan): Use TensorCompatible when ready.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = tensor_conversion.convert_to_tensor_v2(input_)

  # Objects that are not tensor-like have no dtype and can never match.
  input_dtype = getattr(input_, "dtype", None)

  if input_dtype != spec.dtype:
    input_dtype_str = "no dtype" if input_dtype is None else str(input_dtype)

    raise TypeError(
        "{} must have the same dtype as {}. Expected {}, got {}".format(
            input_name, spec_name, spec.dtype, input_dtype_str
        )
    )


def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Verifies that a possibly-structured symbol is compatible with another.

  See _verify_spec_compatible for a more concrete meaning of "compatible".
  Unlike _verify_spec_compatible, which handles singular Tensor-spec objects,
  _verify_structure_compatible can process structures recognized by tf.nest.

  The check runs in two steps: first both arguments must have the same nested
  structure (composites expanded); then every pair of corresponding leaves is
  handed to _verify_spec_compatible.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
      need to, be a structure.

  Raises:
    TypeError: if `input_` and `spec` do not have the same structure. The
      underlying nest error is chained as the cause. Leaf-level
      incompatibilities are reported by _verify_spec_compatible.
  """
  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as structure_error:
    # Normalize both failure kinds to a TypeError that names the arguments.
    message = "{} must have the same element structure as {}.\n\n{}".format(
        input_name, spec_name, str(structure_error)
    )
    raise TypeError(message) from structure_error

  # Structures match, so the leaves can be checked pairwise.
  check_leaf = functools.partial(
      _verify_spec_compatible, input_name, spec_name
  )
  nest.map_structure(check_leaf, input_, spec)


def _next_tf_iterator(iterator, default=py_builtins.UNSPECIFIED):
  """Overload of `next` for tf.data iterators, with optional default.

  Args:
    iterator: The iterator to advance. It must support `next()`, and, when a
      default is given, `get_next_as_optional()` and `element_spec`.
    default: Value to return when the iterator has no further element. When
      left as `py_builtins.UNSPECIFIED`, plain `next(iterator)` is used.

  Returns:
    The result of `next(iterator)` when no default is given; otherwise the
    result of a `cond` that selects the next element if one exists and
    `default` if not.
  """
  no_default = default is py_builtins.UNSPECIFIED
  if no_default:
    # Without a default, fall back to the "normal" behavior which raises
    # a runtime exception.
    return next(iterator)

  maybe_element = iterator.get_next_as_optional()
  # Both branches of the cond must agree, so the default has to match the
  # iterator's element spec.
  _verify_structure_compatible(
      "the default argument", "the iterate", default, iterator.element_spec
  )

  def _use_default():
    return default

  return cond.cond(
      maybe_element.has_value(), maybe_element.get_value, _use_default
  )


def register_overrides():
  """Registers the `next` and `for`-loop overrides for `OwnedIterator`."""
  iterator_type = iterator_ops.OwnedIterator

  # Dispatch the `next` builtin on tf.data iterators to the override above.
  py_builtins.next_registry.register(iterator_type, _next_tf_iterator)

  # Dispatch `for` loops over tf.data iterators to the existing implementation
  # in control_flow.
  for_stmt_impl = control_flow._tf_iterator_for_stmt  # pylint: disable=protected-access
  control_flow.for_loop_registry.register(iterator_type, for_stmt_impl)
