#  Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import string
import warnings
from io import StringIO

from paddle import _C_ops, _legacy_C_ops

from ..base.data_feeder import check_variable_and_dtype
from ..base.proto import framework_pb2
from ..common_ops_import import Variable
from ..framework import (
    LayerHelper,
    OpProtoHolder,
    convert_np_dtype_to_dtype_,
    core,
    in_dynamic_mode,
    in_dynamic_or_pir_mode,
)

# Intentionally empty: this module exports no names via star-import.
__all__ = []


def _convert_(name):
    """
    Convert a CamelCase proto name into snake_case.

    Args:
       name: The name/alias to convert, e.g. ``"InputGrad"`` or ``"Out"``.

    Returns:
        str: The lower-cased, underscore-separated name,
        e.g. ``"input_grad"`` or ``"out"``.
    """
    # Pass 1: put an underscore before every capitalised word that follows
    # some other character ("MaxSeq" -> "Max_Seq").
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Pass 2: split any remaining lowercase-letter/digit -> capital boundary
    # that pass 1 could not reach because of non-overlapping matches.
    snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return snake.lower()


def _type_to_str_(tp):
    """Return the enum member name of the protobuf ``AttrType`` value ``tp``."""
    attr_type_enum = framework_pb2.AttrType
    return attr_type_enum.Name(tp)


# Matches a ``$$...$$`` span whose body contains no ``$``.
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
# The two patterns below are not used by escape_math; they are kept so that
# any code that still references them keeps working.
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")


def escape_math(text):
    """
    Convert ``$$...$$`` math spans into reStructuredText inline math roles.

    Every non-empty ``$$expr$$`` span (where ``expr`` contains no ``$``) is
    rewritten as the role :math:`expr`. Single-dollar spans and ``!!...!!``
    spans are left untouched.

    Args:
        text (str): The text to convert, e.g. an operator or argument comment.

    Returns:
        str: ``text`` with every ``$$...$$`` span replaced.
    """
    return _two_dollar_pattern_.sub(r':math:`\1`', text)


def _generate_doc_string_(
    op_proto, additional_args_lines=None, skip_attrs_set=None
):
    """
    Generate docstring by OpProto

    Args:
        op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
        additional_args_lines (list[str], optional): Extra lines appended to
            the ``Args`` section after the operator's inputs and attributes.
            Each line is stripped and re-indented. Default: None.
        skip_attrs_set (Iterable[str], optional): Extra attribute names to
            hide from the ``Args`` section. They are hidden in addition to
            the framework-generated attributes and ``use_mkldnn``, ``is_test``
            and ``use_cudnn``. Default: None.

    Returns:
        str: the document string

    Raises:
        TypeError: If ``op_proto`` is not a ``framework_pb2.OpProto``.
    """

    if not isinstance(op_proto, framework_pb2.OpProto):
        raise TypeError("OpProto should be `framework_pb2.OpProto`")

    buf = StringIO()
    buf.write(escape_math(op_proto.comment))
    buf.write('\nArgs:\n')
    for each_input in op_proto.inputs:
        buf.write(f'    {_convert_(each_input.name)}')
        buf.write(" (Tensor): ")
        buf.write(escape_math(each_input.comment))
        if each_input.duplicable:
            buf.write("  Duplicatable.")
        if each_input.dispensable:
            buf.write("  Optional.")
        buf.write('\n')

    # Work on a private copy so the set handed out by OpProtoHolder is never
    # mutated, even if it were ever cached and shared between callers.
    skip_attrs = set(OpProtoHolder.generated_op_attr_names())
    # attr use_mkldnn, is_test and use_cudnn also should not be visible to users.
    skip_attrs.update(("use_mkldnn", "is_test", "use_cudnn"))
    if skip_attrs_set:
        skip_attrs.update(skip_attrs_set)

    for each_attr in op_proto.attrs:
        if each_attr.name in skip_attrs:
            continue
        buf.write(f'    {each_attr.name} (')
        buf.write(_type_to_str_(each_attr.type))
        buf.write('): ')
        buf.write(escape_math(each_attr.comment))
        buf.write('\n')

    if additional_args_lines is not None:
        for line in additional_args_lines:
            buf.write(f'    {line.strip()}\n')

    if len(op_proto.outputs) != 0:
        buf.write('\nReturns:\n')
        buf.write('    ')
        # Document the first non-intermediate output. If every output is
        # intermediate, the loop variable is left on the last output.
        for each_opt in op_proto.outputs:
            if not each_opt.intermediate:
                break
        buf.write(_convert_(each_opt.name))
        buf.write(' (Tensor): ')
        buf.write(escape_math(each_opt.comment))

    return buf.getvalue()


def generate_layer_fn(op_type):
    """Register the Python layer for an Operator.

    Args:
       op_type: The name of the operator to be created.

    This function takes in the operator type (sigmoid, mean , average etc) and
    creates the operator functionality.

    Returns:
        function: A layer function named ``op_type`` with a docstring
        generated from the operator proto. Operator inputs may be passed by
        their snake_case keyword or positionally in proto order. All keyword
        arguments are also given to ``LayerHelper``; those not consumed as
        inputs or as the output variable are forwarded as operator
        attributes.

    Raises:
        ValueError: If the operator does not have exactly one
            non-intermediate output, if that output is duplicable, or if any
            intermediate output is duplicable.
    """
    op_proto = OpProtoHolder.instance().get_op_proto(op_type)
    not_intermediate_outputs = [
        output for output in op_proto.outputs if not output.intermediate
    ]
    intermediate_outputs = [
        output for output in op_proto.outputs if output.intermediate
    ]

    # Each message is built as ONE string; passing the pieces as separate
    # positional arguments would make ValueError render them as a tuple.
    if len(not_intermediate_outputs) != 1:
        raise ValueError(
            "Only one non intermediate output operator can be "
            f"automatically generated. {op_type}"
        )

    if not_intermediate_outputs[0].duplicable:
        raise ValueError(
            "Only non duplicable op can be automatically generated."
        )

    for output in intermediate_outputs:
        if output.duplicable:
            raise ValueError(
                "The op can be automatically generated only when "
                "all intermediate ops are not duplicable."
            )

    o_name = not_intermediate_outputs[0].name
    intermediate_output_names = [output.name for output in intermediate_outputs]

    def infer_and_check_dtype(op_proto, *args, **kwargs):
        """
        Infer the dtype shared by all given inputs and check that every
        input is a ``Variable``.

        Each input is looked up by its snake_case keyword first, otherwise it
        is consumed from ``args`` in proto order. ``kwargs`` is this call's
        own copy, so popping from it does not affect the caller.

        Returns the common input dtype. If no input was given, it returns the
        ``dtype`` keyword (converted from a numpy dtype if needed), or FP32.

        Raises:
            ValueError: If an input is not a ``Variable`` or the inputs
                disagree on dtype.
        """
        dtype = None
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, list) and not isinstance(val, tuple):
                val = [val]
            if len(val) == 0:
                if len(args) == 0:
                    continue
                val = [args[0]]
                args = args[1:]

            for each in val:
                if not isinstance(each, Variable):
                    raise ValueError(f"input of {op_type} must be variable")

                if dtype is None:
                    dtype = each.dtype
                elif dtype != each.dtype:
                    raise ValueError(
                        f"operator {op_type} must input same dtype. "
                        f"{dtype} vs {each.dtype}"
                    )

        if dtype is None:
            arg_dtype = kwargs.get("dtype")
            if arg_dtype:
                if not isinstance(arg_dtype, core.VarDesc.VarType):
                    dtype = convert_np_dtype_to_dtype_(arg_dtype)
                else:
                    dtype = arg_dtype
            else:
                dtype = core.VarDesc.VarType.FP32
        return dtype

    def func(*args, **kwargs):
        helper = LayerHelper(op_type, **kwargs)

        dtype = infer_and_check_dtype(op_proto, *args, **kwargs)

        # Collect inputs: keyword (snake_case name) first, else next positional.
        inputs = {}
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, list) and not isinstance(val, tuple):
                val = [val]
            if len(val) == 0 and len(args) != 0:
                # NOTE(review): unlike infer_and_check_dtype, the positional
                # value is not wrapped in a list here; append_op presumably
                # accepts both forms - confirm before changing.
                val = args[0]
                args = args[1:]
            inputs[ipt.name] = val

        # Reuse a caller-supplied output variable if one was passed by keyword.
        outputs = {}
        out = kwargs.pop(_convert_(o_name), [])
        if out:
            out_var = out[0] if isinstance(out, (list, tuple)) else out
        else:
            out_var = helper.create_variable_for_type_inference(dtype=dtype)
        outputs[o_name] = [out_var]
        for name in intermediate_output_names:
            outputs[name] = [
                helper.create_variable_for_type_inference(dtype=dtype)
            ]
        # Whatever is left in kwargs is forwarded as operator attributes.
        helper.append_op(
            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs
        )
        return helper.append_activation(out_var)

    func.__name__ = op_type
    func.__doc__ = _generate_doc_string_(op_proto)
    return func


def generate_activation_fn(op_type):
    """Register the Python layer for an Operator without Attribute.

    Args:
       op_type: The name of the operator to be created.

    This function takes in the operator type (sigmoid, exp , tanh etc) and
    creates the operator functionality.

    Returns:
        function: ``func(x, name=None)`` named ``op_type``. In dynamic or PIR
        mode it calls the ``_C_ops`` kernel, falling back to
        ``_legacy_C_ops``. Otherwise it checks ``x``'s dtype and appends an
        ``op_type`` op with input ``X`` and output ``Out``.
    """
    op_proto = OpProtoHolder.instance().get_op_proto(op_type)

    def func(x, name=None):
        # NOTE: the static-graph branch forwards ``locals()`` to LayerHelper
        # as keyword arguments. Do not add new local variables or new closure
        # references there, or they will be forwarded too.
        if in_dynamic_or_pir_mode():
            if hasattr(_C_ops, op_type):
                op = getattr(_C_ops, op_type)
            else:
                # TODO(dev): Because some ops' yaml has not been migrated.
                # Replace it with _C_ops while all yaml work is done.
                op = getattr(_legacy_C_ops, op_type)
            return op(x)
        else:
            if op_type not in ["abs", "exp", "square"]:
                check_variable_and_dtype(
                    x, 'x', ['float16', 'float32', 'float64'], op_type
                )
            else:
                # abs, exp and square additionally accept int32, int64,
                # complex64, complex128 and uint16 inputs.
                check_variable_and_dtype(
                    x,
                    'x',
                    [
                        'int32',
                        'int64',
                        'float16',
                        'float32',
                        'float64',
                        'complex64',
                        'complex128',
                        'uint16',
                    ],
                    op_type,
                )

            helper = LayerHelper(op_type, **locals())

            output = helper.create_variable_for_type_inference(dtype=x.dtype)
            helper.append_op(
                type=op_type, inputs={"X": x}, outputs={"Out": output}
            )
            return output

    func.__name__ = op_type
    if op_type == 'abs':
        func.__doc__ = r"""

Abs Operator.
Perform elementwise abs for input `X`.

.. math::

    out = |x|

Args:
    x (Tensor): The input tensor of abs op.
    name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
    Tensor: The output tensor of abs op.
"""
    else:
        func.__doc__ = _generate_doc_string_(
            op_proto,
            additional_args_lines=[
                "name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`."
            ],
        )
    return func


def generate_inplace_fn(inplace_op_type):
    """Register the Python layer for an Inplace Operator without Attribute.

    Args:
       inplace_op_type: The name of the inplace operator to be created.

    This function takes in the inplace operator type (exp_ , ceil_ etc) and
    creates the operator functionality.

    In dynamic mode the generated function calls the in-place kernel. In any
    other mode it does not run in place. Instead it emits a warning and
    applies the out-of-place operator (``inplace_op_type`` without its
    trailing ``_``), so callers always get a result rather than ``None``.
    """
    origin_op_type = inplace_op_type[:-1]

    def func(x, name=None):
        if in_dynamic_mode():
            if hasattr(_C_ops, inplace_op_type):
                op = getattr(_C_ops, inplace_op_type)
            else:
                op = getattr(_legacy_C_ops, inplace_op_type)
            return op(x)
        # Outside dynamic mode no in-place kernel is dispatched here; fall
        # back to the out-of-place op instead of silently returning None.
        warnings.warn(
            f"In static graph mode, {inplace_op_type}() is the same as "
            f"{origin_op_type}() and does not perform inplace operation."
        )
        return generate_activation_fn(origin_op_type)(x, name)

    func.__name__ = inplace_op_type
    func.__doc__ = f"""
Inplace version of ``{origin_op_type}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_{origin_op_type}`.
"""
    return func


def templatedoc(op_type=None):
    """
    Build a decorator that fills a layer function's docstring from an
    operator proto.

    The decorated function's ``__doc__`` is treated as a ``string.Template``
    and may use these placeholders:

    * ${comment}: The operator comment written in CPP.
    * ${{name}_comment}: The comment of ${name} written with AddAttr, AddOutput,
        and AddInput. The ${name} is Python snake style. i.e., xxx_xxx.
    * ${{name}_type}: The type of ${name}.

    Args:
        op_type (str, optional): The operator whose proto is used. Defaults to
            the decorated function's ``__name__``.

    Returns:
        Decorated function.
    """

    def trim_ending_dot(msg):
        return msg.rstrip('.')

    def __impl__(func):
        op_type_name = func.__name__ if op_type is None else op_type
        op_proto = OpProtoHolder.instance().get_op_proto(op_type_name)
        tmpl = string.Template(func.__doc__)

        # Join the non-blank comment lines with spaces. A blank line after
        # some text becomes an indented paragraph break.
        comment = ""
        for raw_line in op_proto.comment.split("\n"):
            stripped = raw_line.strip()
            if stripped:
                comment += escape_math(stripped) + " "
            elif comment:
                comment += "\n    \n    "

        args = {"comment": trim_ending_dot(comment)}

        def record(proto_name, proto_comment, type_name):
            # Register the <snake_name>_comment / <snake_name>_type keys.
            key = _convert_(proto_name)
            args[key + "_comment"] = trim_ending_dot(proto_comment)
            args[key + "_type"] = type_name

        for each_input in op_proto.inputs:
            record(each_input.name, each_input.comment, "Variable")
        for each_attr in op_proto.attrs:
            record(
                each_attr.name,
                each_attr.comment,
                _type_to_str_(each_attr.type),
            )
        for each_opt in op_proto.outputs:
            record(each_opt.name, each_opt.comment, "Variable")

        func.__doc__ = tmpl.substitute(args)
        return func

    return __impl__


def add_sample_code(func, sample_code):
    """
    Append sample code for dynamically generated functions.

    Args:
       func: The function to append the sample code to. Its ``__doc__`` is
           replaced in place.
       sample_code: sample code section in rst format.
    """
    # __doc__ is None for functions without a docstring (and under -OO);
    # treat that as an empty docstring instead of raising TypeError.
    func.__doc__ = (func.__doc__ or "") + sample_code
