#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial, reduce

import paddle
from paddle import _C_ops
from paddle.base import core
from paddle.base.backward import _infer_var_data_type_shape_
from paddle.base.framework import (
    Operator,
    Program,
    Variable,
    in_pir_mode,
    static_only,
)
from paddle.base.libpaddle.pir import build_if_op, build_while_op, cf_yield
from paddle.common_ops_import import (
    LayerHelper,
    check_type,
    check_variable_and_dtype,
    convert_dtype,
    in_dygraph_mode,
)
from paddle.utils import (
    assert_same_structure,
    copy_mutable_vars,
    flatten,
    hold_mutable_vars,
    is_sequence,
    map_structure,
    pack_sequence_as,
    to_sequence,
)


def Assert(cond, data=None, summarize=20, name=None):
    '''
    This API creates an op that asserts the given condition is true. If the
    condition is false, prints the tensors in data. ``summarize`` specifies the
    number of the elements in the tensors to print.

    The arguments are first validated with ``check_variable_and_dtype`` and
    ``check_type``; an ``assert`` op is then appended through a
    ``LayerHelper`` and returned.

    Args:
        cond (Tensor): The boolean condition tensor whose numel should be 1.
        data (list|tuple, optional): list or tuple of tensors to print when
            condition is not true. If it's ``None``, no tensor will be printed.
            The default value is ``None``.
        summarize (int, optional): Number of elements in the tensor to be
            printed. If its value is -1, then all elements in the tensor will
            be printed. The default value is 20.
        name (str, optional): The default value is ``None`` . Normally users
            don't have to set this parameter. For more information, please
            refer to :ref:`api_guide_Name` . When it is empty or ``None``, the
            layer name ``'assert_' + cond.name`` is used instead.

    Returns:
        Operator: the created operation.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.static.nn.control_flow import Assert

            >>> paddle.enable_static()
            >>> x = paddle.full([2, 3], 2.0, 'float32')
            >>> condition = paddle.max(x) < 1.0 # False
            >>> Assert(condition, [x], 10, "example_assert_layer")

            >>> exe = paddle.static.Executor()
            >>> try:
            ...     exe.run(paddle.static.default_main_program())
            ...     # Print x and throws ValueError
            ...     # Example printed message for x:
            ...     #
            ...     # Variable: fill_constant_0.tmp_0
            ...     #   - lod: {}
            ...     #   - place: CPUPlace()
            ...     #   - shape: [2, 3]
            ...     #   - layout: NCHW
            ...     #   - dtype: float
            ...     #   - data: [2 2 2 2 2 2]
            ... except ValueError as e:
            ...     print("Assert Exception Example")

    '''
    # Validate every argument up front; the last positional string is the API
    # name handed to the checkers.
    check_variable_and_dtype(
        cond, "cond", ["bool"], "static.nn.control_flow.Assert"
    )
    check_type(
        data, "data", (list, tuple, type(None)), "static.nn.control_flow.Assert"
    )
    check_type(summarize, "summarize", int, "static.nn.control_flow.Assert")
    check_type(name, "name", (str, type(None)), "static.nn.control_flow.Assert")

    # Fall back to a name derived from the condition variable when no
    # (non-empty) name is supplied.
    layer_name = name if name else ('assert_' + cond.name)
    # NOTE: ``**locals()`` forwards every local bound at this point (cond,
    # data, summarize, name, layer_name) to LayerHelper as keyword arguments.
    # Adding or renaming a local above this line changes what is forwarded.
    helper = LayerHelper(layer_name, **locals())

    # ``data=None`` is passed to the op as an empty list; any other sequence is
    # copied into a new list.
    op = helper.append_op(
        type="assert",
        inputs={"Cond": cond, "Data": [] if data is None else list(data)},
        attrs={"summarize": summarize},
    )

    return op


class BlockGuard:
    """
    Context manager that opens a sub-block in a program.

    Entering the ``with`` statement calls ``main_program._create_block()``;
    leaving it calls ``main_program._rollback()``. An exception raised inside
    the ``with`` body is never suppressed.
    """

    def __init__(self, main_program):
        program_ok = isinstance(main_program, Program)
        if not program_ok:
            raise TypeError("BlockGuard takes a program")
        self.main_program = main_program

    def __enter__(self):
        program = self.main_program
        program._create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always leave the sub-block first, whether or not the body raised.
        self.main_program._rollback()
        # True on a clean exit; False tells Python to re-raise the in-flight
        # exception.
        return exc_type is None


class WhileGuard(BlockGuard):
    """
    BlockGuard specialised for a ``While`` op.

    It keeps ``while_op.status`` in step with the ``with`` statement and calls
    ``while_op._complete()`` when the body finishes without an exception.
    """

    def __init__(self, while_op):
        if not isinstance(while_op, While):
            raise TypeError("WhileGuard takes a while op")
        program = while_op.helper.main_program
        super().__init__(program)
        self.while_op = while_op

    def __enter__(self):
        # Mark the op as being inside its body before the sub-block is opened.
        self.while_op.status = While.IN_WHILE_BLOCK
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            loop_op = self.while_op
            loop_op.status = While.AFTER_WHILE_BLOCK
            loop_op._complete()
            return super().__exit__(exc_type, exc_val, exc_tb)
        # The body raised: return straight away (without completing the op or
        # delegating to BlockGuard.__exit__) so the exception propagates.
        return False


class ConditionalBlock:
    '''
    **ConditionalBlock**

    ConditionalBlock is an operator that binds a block to a specific condition;
    if the condition matches, the corresponding block will be executed.

    Use ``block()`` in a ``with`` statement to build the body of the block.
    ``complete()`` appends the ``conditional_block`` op for that body to the
    parent block.

    Args:
        inputs (list[Variable]): bool conditions. Every element is checked to
            be a ``Variable``.
        is_scalar_condition (bool): whether the branch is controlled by a scalar.
        name(str): name of this ConditionalBlock.

    Note:
        The example below uses ``false_cond`` without defining it, so it is
        illustrative only.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.static.nn.control_flow import ConditionalBlock

            >>> label = paddle.rand([1])
            >>> limit = paddle.ones([1]) * 0.5
            >>> cond = paddle.less_than(x=label, y=limit)
            >>> image = paddle.ones([1])

            >>> true_image = image[cond]
            >>> true_cond = ConditionalBlock([true_image])

            >>> with true_cond.block():
            ...     pass
            >>> with false_cond.block():
            ...     pass
    '''

    def __init__(self, inputs, is_scalar_condition=False, name=None):
        # Each condition must be a Variable; check_type reports the offender.
        for each_input in inputs:
            check_type(each_input, "input", Variable, "ConditionalBlock")
        self.inputs = inputs
        self.is_scalar_condition = is_scalar_condition
        self.helper = LayerHelper('conditional_block', name=name)

    def block(self):
        """Return a ``ConditionalBlockGuard`` bound to this ConditionalBlock."""
        return ConditionalBlockGuard(self)

    def complete(self):
        """
        Append the ``conditional_block`` op for the current block to its parent.

        The current block of the main program is treated as the body. Its
        input and output variable names are collected with
        ``get_inputs_outputs_in_block``; the op is appended to the parent
        block, and a ``conditional_block_grad`` op is appended as well when
        ``need_append_conditional_block_grad`` says so.
        """
        # The body is the program's current block; the op goes into its parent.
        inside_block = self.helper.main_program.current_block()
        parent_block = self.helper.main_program.block(inside_block.parent_idx)

        # Both sets start empty and are filled in by the helper below.
        intermediate = set()
        params = set()
        params, intermediate = get_inputs_outputs_in_block(
            inside_block, params, intermediate, helper=self.helper
        )

        # Todo(liym27) Here assume that all params are in recursive parent block
        # but when minimize() called in control flow, some params may be in
        # conditional grad block
        param_list = [
            parent_block._var_recursive(each_name) for each_name in params
        ]

        # Only names that resolve from the parent block become outputs of the
        # op; names that do not resolve are left out.
        out_list = []
        for inner_out_name in intermediate:
            inner_var = parent_block._find_var_recursive(inner_out_name)
            if inner_var:
                out_list.append(inner_var)

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        conditional_block_op = parent_block.append_op(
            type='conditional_block',
            inputs={
                'Cond': self.inputs,
                'Input': param_list,
            },
            outputs={'Out': out_list, 'Scope': [step_scope]},
            attrs={
                'sub_block': inside_block,
                'is_scalar_condition': self.is_scalar_condition,
            },
        )

        if self.need_append_conditional_block_grad(inside_block):
            self.append_conditional_block_grad(
                parent_block, inside_block, conditional_block_op
            )

    def need_append_conditional_block_grad(self, inside_block):
        """
        Tell whether ``inside_block`` has a backward block other than itself.

        Args:
            inside_block (Block): The sub block of the conditional_block op.

        Returns:
            bool: True when ``inside_block.backward_block_idx`` is neither -1
            nor the index of ``inside_block`` itself.
        """
        grad_sub_block_idx = inside_block.backward_block_idx
        inside_block_idx = inside_block.idx

        # if inside_block have grad_block and grad_block is not itself,
        # we will append conditional block grad.
        return (
            grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
        )

    def append_conditional_block_grad(
        self, parent_block, inside_block, conditional_block_op
    ):
        '''
        Append op `conditional_block_grad` manually.
        When `optimizer.minimize/append_backward` is called in Paddle control flow,
        grad ops will be appended before appending op `conditional_block` so that
        op `conditional_block_grad` can't be appended when calling
        `optimizer.minimize/append_backward`. After appending op `conditional_block`,
        `conditional_block_grad` is appended manually.

        Args:
            parent_block (Block): The block that `conditional_block_op` belongs to.
            inside_block (Block): The sub block of `conditional_block_op`.
            conditional_block_op (Operator): The forward op conditional_block.
        '''

        grad_sub_block_idx = inside_block.backward_block_idx
        grad_sub_block = self.helper.main_program.block(grad_sub_block_idx)

        intermediate = set()
        params = set()

        # Scan the grad block in op order: a name read before any op in the
        # block has written it counts as an input (param) of the grad block.
        for each_op in grad_sub_block.ops:
            assert isinstance(each_op, Operator)
            for iname in each_op.input_names:
                for in_var_name in each_op.input(iname):
                    if in_var_name not in intermediate:
                        params.add(in_var_name)

            for oname in each_op.output_names:
                for out_var_name in each_op.output(oname):
                    intermediate.add(out_var_name)

        # Keep only the inputs that resolve from the parent block; the grad op
        # desc is given variable names, not Variable objects.
        param_list = []
        for inner_input_name in params:
            inner_var = parent_block._find_var_recursive(inner_input_name)
            if inner_var:
                param_list.append(inner_var.name)

        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            conditional_block_op.desc, set(), [grad_sub_block.desc]
        )

        # append op_desc in grad_op_descs to target_block
        op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
        backward = core.op_proto_and_checker_maker.OpRole.Backward
        new_op_desc = parent_block.desc.append_op()
        # Only the first returned grad op desc is used.
        new_op_desc.copy_from(grad_op_desc[0])
        new_op_desc._set_attr(op_role_attr_name, backward)
        # set input and output manually
        new_op_desc.set_input('Input', param_list)
        new_op_desc.set_output(
            'Input@GRAD', [param + "@GRAD" for param in param_list]
        )

        # Create, in the grad block's desc, every output var of the new op that
        # does not exist there yet, and remember which ones were created.
        new_vars = set()
        for grad_var_name in new_op_desc.output_arg_names():
            if (
                grad_sub_block.desc.has_var_recursive(grad_var_name.encode())
                or grad_var_name == core.empty_var_name()
            ):
                continue
            grad_sub_block.desc.var(grad_var_name.encode())
            new_vars.add(grad_var_name)
            # NOTE(review): this check is the last statement of the loop body,
            # so the ``continue`` below has no effect.
            if grad_var_name not in op_grad_to_var:
                continue

        # infer_shape and infer_type
        new_op_desc.infer_var_type(grad_sub_block.desc)
        new_op_desc.infer_shape(grad_sub_block.desc)

        # Only the vars created above get their data type/shape inferred here.
        for arg in new_op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_shape_(arg, grad_sub_block)

        self.helper.main_program._sync_with_cpp()


class ConditionalBlockGuard(BlockGuard):
    """
    BlockGuard that owns a ConditionalBlock.

    It lets the body of a ConditionalBlock be written inside a Python ``with``
    statement: the sub-block is opened on entry, and on exit the
    ConditionalBlock is completed before the sub-block is closed.
    ConditionalBlockGuard is generally an internal component of IfElse; users
    should not use it directly.
    """

    def __init__(self, block):
        check_type(block, "block", ConditionalBlock, "ConditionalBlockGuard")
        program = block.helper.main_program
        super().__init__(program)
        self.block = block

    def __enter__(self):
        entered = super().__enter__()
        return entered

    def __exit__(self, exc_type, exc_val, exc_tb):
        # complete() runs on every exit, including when the body raised; the
        # base class then rolls back and decides whether to re-raise.
        conditional_block = self.block
        conditional_block.complete()
        return super().__exit__(exc_type, exc_val, exc_tb)


def get_inputs_outputs_in_block(
    current_block, inner_inputs, inner_outputs, helper
):
    """
    Collect the variable names read and written by the ops of a control flow
    block.

    ``inner_inputs`` and ``inner_outputs`` are updated in place while the ops
    are scanned in order: a name read by an op is recorded as an input unless
    it is already a known output. Afterwards, LOD_TENSOR_ARRAY variables that
    exist in ``current_block`` but cannot be found from its parent block are
    dropped from the inputs; the returned input set is a new set object.

    :param current_block: Current control flow block.
    :param inner_inputs: Set of input var names of ops in current block.
    :param inner_outputs: Set of output var names of ops in current block.
    :param helper: Helper whose ``main_program`` is used to look up the parent
        block of ``current_block``.
    :return: inner_inputs, inner_outputs
    """

    # NOTE(dev): There are some persistable var created in some non-standard API
    # such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
    # Input and Output. This var shall not be considered as a loop_var in
    # control_flow.
    ignored_by_op_type = {"shuffle_batch": ["shuffle_batch_seed"]}

    def is_ignore_vars(op, var_name):
        # A var is ignored when one of the fragments registered for the op's
        # type occurs inside its name.
        fragments = ignored_by_op_type.get(op.type, ())
        return any(fragment in var_name for fragment in fragments)

    # Step1: update inner_inputs and inner_outputs
    # NOTE: Here assumes that all variables are input or output of Ops,
    # but some variables are created without appending a real op.
    # For example, in `arr = create_array(dtype)`, `arr` is not a output of a op.
    for op in current_block.ops:
        assert isinstance(op, Operator)
        for slot in op.input_names:
            for var_name in op.input(slot):
                if var_name in inner_outputs or is_ignore_vars(op, var_name):
                    continue
                inner_inputs.add(var_name)

        for slot in op.output_names:
            inner_outputs.update(op.output(slot))

    # Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
    parent_block = helper.main_program.block(current_block.parent_idx)
    local_arrays = set()

    for var_name in inner_inputs:
        seen_from_parent = parent_block._find_var_recursive(var_name)
        local_var = (
            current_block.var(var_name)
            if current_block.has_var(var_name)
            else None
        )
        if seen_from_parent or not local_var:
            continue
        if local_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
            local_arrays.add(var_name)

    return inner_inputs - local_arrays, inner_outputs


class While:
    """
    :api_attr: Static Graph

    while loop control flow. Repeat while body until cond is False.

    Note:
        A new OP :ref:`api_paddle_static_nn_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
        OP :ref:`api_paddle_static_nn_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .

    Notice:
        Local variables created in ``While`` are similar to that created in while of C++, and cannot be referenced externally.
        As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access the variable
        out of ``while`` , PaddlePaddle provides ``assign`` API to assign local variables to external. Please refer to example
        code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.

    Args:
        cond(Variable): A Tensor whose data type is bool controlling whether to continue looping. The product of its shape must be 1.
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): The default value is None.  Normally there is no need for user to set this property.  For more information, please refer to :ref:`api_guide_Name` .

    Raises:
        TypeError: If the product of the dimensions of ``cond.shape`` is not 1.

    Examples:
        .. code-block:: python
            :name: example-1

            >>> import paddle
            >>> import numpy as np

            >>> paddle.enable_static()

            >>> i = paddle.full(shape=[1], dtype='int64', fill_value=0)           # loop counter

            >>> loop_len = paddle.full(shape=[1],dtype='int64', fill_value=10)    # loop length

            >>> cond = paddle.less_than(x=i, y=loop_len)
            >>> while_op = paddle.static.nn.control_flow.While(cond=cond)
            >>> with while_op.block():
            ...     i = paddle.increment(x=i, value=1)
            ...     paddle.assign(paddle.less_than(x=i, y=loop_len), output=cond)

            >>> exe = paddle.static.Executor(paddle.CPUPlace())
            >>> exe.run(paddle.static.default_startup_program())

            >>> res = exe.run(paddle.static.default_main_program(), feed={}, fetch_list=[i])
            >>> print(res)
            [array([10], dtype=int64)]

        .. code-block:: python
            :name: example-2

            >>> import paddle
            >>> import numpy as np

            >>> paddle.enable_static()

            >>> i = paddle.full(shape=[1], dtype='int64', fill_value=0)
            >>> loop_len = paddle.full(shape=[1], dtype='int64', fill_value=10)
            >>> one = paddle.full(shape=[1], dtype='float32', fill_value=1)
            >>> data = paddle.static.data(name='data', shape=[1], dtype='float32')
            >>> sums = paddle.full(shape=[1], dtype='float32', fill_value=0)  # Define the variable to be obtained outside of While, whose name should be different from the variable inside the While to be obtained

            >>> cond = paddle.less_than(x=i, y=loop_len)
            >>> while_op = paddle.static.nn.control_flow.While(cond=cond)
            >>> with while_op.block():
            ...     sums_tensor = paddle.add(x=data, y=data)
            ...     paddle.assign(sums_tensor, sums)  # Update the value of sums_tensor defined in While to the sums which defined outside of While through layers.assign
            ...     i = paddle.increment(x=i, value=1)
            ...     data = paddle.add(x=data, y=one)
            ...     paddle.assign(paddle.less_than(x=i, y=loop_len), output=cond)

            >>> feed_data = np.ones(1).astype('float32')
            >>> exe = paddle.static.Executor(paddle.CPUPlace())
            >>> exe.run(paddle.static.default_startup_program())
            >>> res = exe.run(paddle.static.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
            >>> print(res[0]) # Because the data in While does not update the value outside the While, the value of sums is [2.] after the loop
            [2.]
    """

    # Life-cycle markers stored in ``self.status``.
    BEFORE_WHILE_BLOCK = 0
    IN_WHILE_BLOCK = 1
    AFTER_WHILE_BLOCK = 2

    def __init__(self, cond, is_test=False, name=None):
        self.helper = LayerHelper("while", name=name)
        self.status = While.BEFORE_WHILE_BLOCK
        check_variable_and_dtype(cond, 'cond', ['bool'], 'static.nn.While')
        # The condition must hold exactly one element, whatever its rank.
        element_count = reduce(lambda acc, dim: acc * dim, cond.shape, 1)
        if element_count != 1:
            raise TypeError(
                f"condition expected shape as [1], but given shape as {list(cond.shape)}."
            )
        self.cond_var = cond
        self.is_test = is_test

    def block(self):
        """Return a ``WhileGuard`` to be used in a ``with`` statement."""
        return WhileGuard(self)

    def _complete(self):
        """
        Append the ``while`` op for the current block to its parent block.

        The program's current block is taken as the loop body. Its input and
        output names are gathered with ``get_inputs_outputs_in_block``; outputs
        that resolve from the parent block become both outputs and inputs of
        the op, and the condition variable is passed only as ``Condition``.
        """
        program = self.helper.main_program
        body_block = program.current_block()
        outer_block = program.block(body_block.parent_idx)
        cond_name = self.cond_var.name

        # The condition is seeded as an output so that reads of it inside the
        # body are not recorded as inputs.
        loop_inputs, written_names = get_inputs_outputs_in_block(
            body_block, set(), {cond_name}, self.helper
        )

        # Keep only the written names that resolve from the parent block.
        resolved = (
            outer_block._find_var_recursive(written)
            for written in written_names
        )
        out_vars = [var for var in resolved if var]

        loop_inputs.update(var.name for var in out_vars)
        # NOTE(dev): cond_var has been contained in Input('Condition'), so
        # we remove it from Input('X')
        loop_inputs.discard(cond_name)

        step_scope = outer_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        x_vars = [outer_block._var_recursive(x_name) for x_name in loop_inputs]

        outer_block.append_op(
            type='while',
            inputs={'X': x_vars, 'Condition': [self.cond_var]},
            outputs={'Out': out_vars, 'StepScopes': [step_scope]},
            attrs={'sub_block': body_block, "is_test": self.is_test},
        )


# Python builtin scalar types that the control-flow helpers in this module test
# for with ``isinstance`` when a loop variable or a body result is not a
# ``Variable``.
support_ret_buildin_type = (bool, float, int)


def assign_skip_lod_tensor_array(input, output):
    """
    Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.

    What happens depends on ``input``:

    * Neither a ``Variable`` nor a ``core.eager.Tensor``: ``paddle.assign`` is
      called only when ``output`` is a ``Variable`` and ``input`` is one of
      ``support_ret_buildin_type`` (bool, float, int). In every other case
      nothing is done.
    * ``input.type`` is ``LOD_TENSOR_ARRAY``: the array is assigned only when
      the parent of the program's current block is truthy and cannot find a
      variable named ``input.name``.
    * Anything else: ``input`` is assigned to ``output``. A warning is emitted
      first when both are ``Variable`` and their shapes differ in rank or in a
      dimension where neither side is -1.

    Args:
        input: The value to assign from.
        output: The target to assign into.

    Returns:
        None. The function works purely through ``paddle.assign``.
    """

    def has_shape_diff(x_var, y_var):
        # A different rank is always a difference; otherwise a -1 on either
        # side of a dimension makes that dimension compatible.
        if len(x_var.shape) != len(y_var.shape):
            return True
        return any(
            x_dim != y_dim and -1 not in [x_dim, y_dim]
            for x_dim, y_dim in zip(x_var.shape, y_var.shape)
        )

    if not isinstance(input, (Variable, core.eager.Tensor)):
        if isinstance(output, Variable) and isinstance(
            input, support_ret_buildin_type
        ):
            paddle.assign(input, output)
        # Otherwise there is nothing to do: rebinding the local name
        # ``output`` here would not be visible to the caller.
        return

    if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        main_program = input.block.program
        parent_block = main_program.block(
            main_program.current_block().parent_idx
        )
        # Copy the array only when the parent block cannot resolve its name.
        if parent_block and not parent_block._find_var_recursive(input.name):
            paddle.assign(input, output)
        return

    if (
        isinstance(output, Variable)
        and isinstance(input, Variable)
        and has_shape_diff(input, output)
    ):
        warnings.warn(
            "In dy2static mode, we attemp to assign a variable with shape {} into a variable with shape{}, which is not always right.".format(
                input.shape, output.shape
            )
        )
    # NOTE(dev): Avoid assign if input is output in Variable level which means
    # input is not generated in While sub block and modified by in-place and only
    # belong to inplace ops in constructing program process, because in-place pass
    # is only available in Graph level.
    with paddle.base.framework._stride_in_no_check_dy2st_diff():
        paddle.assign(input, output)


def while_loop(cond, body, loop_vars, is_test=False, name=None):
    """
    :api_attr: Static Graph

    while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.

    The function has three execution paths, selected in this order: PIR mode
    (``in_pir_mode()``), dygraph mode (``in_dygraph_mode()``), and otherwise
    the static-graph path built on ``While``.

    Notice:
        Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` , variables should
        be defined outside ``body`` and placed in ``loop_vars`` for looping, then these variables can be fetched by ``fetch_list`` .

    Args:
        cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
            as many arguments as ``loop_vars`` .
        body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
            (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
        loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
        is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
        name(str, optional): Normally there is no need for users to set this property. For more information, please
            refer to :ref:`api_guide_Name`. Default is None.

    Returns:
        A list or tuple of Tensors or LoDTensorArrays which are returned by ``body`` .

    Raises:
        TypeError: If ``cond`` or ``body`` is not callable.
        ValueError: If ``loop_vars`` is empty.
        TypeError: Outside PIR mode, if the product of the shape of the
            variable returned by ``cond`` is not 1.
        ValueError: Outside PIR mode, if ``body`` does not return the same
            arity (length and structure) as ``loop_vars``.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.enable_static()

            >>> def cond(i, ten):
            ...     return i < ten

            >>> def body(i, ten):
            ...     i = i + 1
            ...     return [i, ten]

            >>> main_program = paddle.static.default_main_program()
            >>> startup_program = paddle.static.default_startup_program()
            >>> with paddle.static.program_guard(main_program, startup_program):
            ...     i = paddle.full(shape=[1], fill_value=0, dtype='int64')     # loop counter
            ...     ten = paddle.full(shape=[1], fill_value=10, dtype='int64')  # loop length
            ...     i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])

            ...     exe = paddle.static.Executor(paddle.CPUPlace())
            ...     res = exe.run(main_program, feed={}, fetch_list=[i])
            ...     print(res)
            [array([10], dtype=int64)]
    """
    # Argument validation shared by all three execution paths.
    if not callable(cond):
        raise TypeError("cond in while_loop should be callable")
    if not callable(body):
        raise TypeError("body in while_loop should be callable")
    check_type(loop_vars, 'loop_vars', (list, tuple), 'static.nn.while_loop')
    if len(loop_vars) == 0:
        raise ValueError("loop_vars in while_loop should not be empty")

    # Evaluate the condition once, before the loop, on the initial loop_vars.
    pre_cond = cond(*loop_vars)

    # PIR path: the body is traced inside the op's body block and the next
    # condition is yielded together with the next loop variables.
    # NOTE(review): this path returns before the dtype/shape validation of
    # ``pre_cond`` below, and passes the result of ``body`` straight to
    # ``cond(*next_var)`` without wrapping a single return value in a list.
    if in_pir_mode():
        while_op = build_while_op(pre_cond, flatten(loop_vars))
        with while_op.body() as cur_block:
            args = cur_block.args()
            next_var = body(*args)
            next_cond = cond(*next_var)
            cf_yield([next_cond, *next_var])
        return while_op.as_operation().results()

    check_variable_and_dtype(
        pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop'
    )
    # The condition must hold exactly one element, whatever its rank.
    if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
        raise TypeError(
            "the shape of the variable returned by cond should be [1],"
            f"but given shape as {list(pre_cond.shape)}."
        )

    # Dygraph path: run an ordinary Python loop, writing each iteration's
    # results back into ``loop_vars``.
    if in_dygraph_mode():
        now_cond = pre_cond.item()
        while now_cond:
            output_vars = body(*loop_vars)
            if not isinstance(output_vars, (list, tuple)):
                output_vars = [output_vars]
            if len(output_vars) != len(loop_vars):
                raise ValueError(
                    "body in while_loop should return the same arity "
                    "(length and structure) and types as loop_vars"
                )
            now_cond = cond(*output_vars).item()
            map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
        return loop_vars

    # Static-graph path: trace ``body`` and ``cond`` once inside a While block.
    while_loop_block = While(pre_cond, is_test, name)
    has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
    with while_loop_block.block():
        # If a variable with mutable type is included in loop_vars, like `dict/list`,
        # modifying it in the body function will cause origin variable to be modified
        # synchronously. This will raise an assignment error out of while block.
        # Here we make a copy of the mutable vars to avoid this problem.
        if has_mutable_vars_in_loop:
            new_loop_vars = copy_mutable_vars(loop_vars)
            output_vars = body(*new_loop_vars)
        else:
            output_vars = body(*loop_vars)
        if not isinstance(output_vars, (list, tuple)):
            output_vars = [output_vars]
        try:
            # ``loop_vars`` is rebound to the list returned by
            # _deal_with_undefined_var, which is also what this path returns.
            loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
            assert_same_structure(output_vars, loop_vars, check_types=False)
        except ValueError as e:
            raise ValueError(
                "body in while_loop should return the same arity "
                f"(length and structure) as loop_vars: {e}"
            )
        now_cond = cond(*output_vars)
        # Write the body's results back into the loop variables, then update
        # the condition variable that the While op was created with.
        map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
        paddle.assign(now_cond, pre_cond)
    return loop_vars


def _deal_with_undefined_var(output_vars, loop_vars):
    """Deal with undefined var cases, We create undefined variable based on the results of body().
    In Dy2Static, we use undefined var to represent the var created in control flow. This function
    expand the loop_vars and replace original loop_vars.
    1. UndefinedVar = Variable      # create a variable
    2. UndefinedVar = None          # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
    3. UndefinedVar = List(int)     # create a list of variable
    4. UndefinedVar = value         # create a variable

    Every entry of ``loop_vars`` that is an ``UndefinedVar`` or ``None`` is
    replaced by a placeholder shaped after the matching entry of
    ``output_vars``; all other entries are kept as they are. A new list is
    returned, and a ``ValueError`` is raised when the two sequences differ in
    length.
    """
    from paddle.jit.dy2static.utils import (
        UndefinedVar,
        create_undefined_variable,
    )

    def create_var_like(o_var):
        scalar_like = (Variable, *support_ret_buildin_type)
        if o_var is None or isinstance(o_var, scalar_like):
            return create_undefined_variable()
        if is_sequence(o_var):
            # Create a complex container class inside the body of while,
            # including Python list and python Dict
            return map_structure(lambda x: create_undefined_variable(), o_var)
        # Any other kind of value falls through and yields None.

    if len(output_vars) != len(loop_vars):
        raise ValueError("The length of loop_vars should be the same.")

    return [
        create_var_like(produced)
        if isinstance(current, UndefinedVar) or current is None
        else current
        for produced, current in zip(output_vars, loop_vars)
    ]


def _error_message(what, arg_name, op_name, right_value, error_value):
    error_message = (
        f"{what} of '{arg_name}' in {op_name} must be "
        f"{right_value}, but received: {error_value}."
    )

    return error_message


def case(pred_fn_pairs, default=None, name=None):
    '''
    :api_attr: Static Graph

    This operator works like an if-elif-elif-else chain.

    Args:
        pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor whose numel should be 1 (shape [] or shape [1]), ``fn`` is a callable. All callables return the same structure of Tensors.
        default(callable, optional): Callable that returns a structure of Tensors.
        name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor|list(Tensor): Tensors returned by the callable from the first pair whose pred is True,
        or Tensors returned by ``default`` if no pred in ``pred_fn_pairs`` is True and ``default`` is not None,
        or Tensors returned by the last callable in ``pred_fn_pairs``  if no pred in ``pred_fn_pairs`` is True and ``default`` is None.

    Raises:
        TypeError: If the type of ``pred_fn_pairs`` is not list or tuple.
        TypeError: If the type of elements in ``pred_fn_pairs`` is not tuple.
        TypeError: If the size of tuples in ``pred_fn_pairs`` is not 2.
        TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not a Tensor.
        TypeError: If the second element of 2-tuple in ``pred_fn_pairs`` is not callable.
        TypeError: If ``default`` is not None but it is not callable.
        IndexError: If ``pred_fn_pairs`` is empty and ``default`` is None.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.enable_static()

            >>> def fn_1():
            ...     return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

            >>> def fn_2():
            ...     return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

            >>> def fn_3():
            ...     return paddle.full(shape=[3], dtype='int32', fill_value=3)

            >>> startup_program = paddle.static.default_startup_program()
            >>> main_program = paddle.static.default_main_program()

            >>> with paddle.static.program_guard(main_program, startup_program):
            ...     x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
            ...     y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            ...     z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)
            ...
            ...     pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3
            ...     pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
            ...     pred_3 = paddle.equal(x, y)      # false: 0.3 == 0.1
            ...
            ...     # Call fn_1 because pred_1 is True
            ...     out_1 = paddle.static.nn.case(
            ...         pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
            ...
            ...     # Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
            ...     # because fn_3 is the last callable in pred_fn_pairs.
            ...     out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
            ...
            ...     exe = paddle.static.Executor(paddle.CPUPlace())
            ...     res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
            ...     print(res_1, res_2)
            [[1. 1.]] [3 3 3]
    '''
    # Kept although the helper is not used afterwards: it is constructed with
    # the call's arguments exactly as before, so any side effect of creating
    # it is preserved.
    helper = LayerHelper('case', **locals())

    def _case_check_args(pred_fn_pairs, default):
        '''
        Check arguments pred_fn_pairs and default. Return canonical pred_fn_pairs and default.
        '''
        check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case')

        for pred_fn in pred_fn_pairs:
            if not isinstance(pred_fn, tuple):
                raise TypeError(
                    _error_message(
                        "The elements' type",
                        "pred_fn_pairs",
                        "case",
                        tuple,
                        type(pred_fn),
                    )
                )
            if len(pred_fn) != 2:
                raise TypeError(
                    _error_message(
                        "The tuple's size",
                        "pred_fn_pairs",
                        "case",
                        "2",
                        f"{len(pred_fn)}-tuple",
                    )
                )
            pred, fn = pred_fn

            if not isinstance(pred, Variable):
                raise TypeError(
                    _error_message(
                        "The pred's type",
                        "pred_fn_pairs",
                        "case",
                        "boolean Variable",
                        type(pred),
                    )
                )

            if not callable(fn):
                raise TypeError(
                    f"The fn for {pred.name} of pred_fn_pairs in Op(case) must"
                    " be callable."
                )

        if default is None:
            # No explicit default: the last pair's callable becomes the
            # fallback and that pair is dropped from the chain.
            default = pred_fn_pairs[-1][1]
            pred_fn_pairs = pred_fn_pairs[:-1]
        elif not callable(default):
            raise TypeError("The default in Op(case) must be callable.")

        return pred_fn_pairs, default

    pred_fn_pairs, default = _case_check_args(pred_fn_pairs, default)

    # Build the chain from the innermost branch outwards so that the first
    # pair ends up as the outermost cond: each earlier cond uses the cond
    # built for the later pairs as its false branch.
    false_fn = default
    for pred, true_fn in reversed(pred_fn_pairs):
        false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)

    return false_fn()


def switch_case(branch_index, branch_fns, default=None, name=None):
    '''
    :api_attr: Static Graph

    This operator is like a C++ switch/case statement.

    Args:
        branch_index(Tensor): A Tensor whose numel should be 1 (shape [] or shape [1]) to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
        branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors.
        default(callable, optional): Callable that returns a structure of Tensors.
        name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
        or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``,
        or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``.

    Raises:
        TypeError: If the type of ``branch_index`` is not Tensor.
        TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``.
        TypeError: If the type of ``branch_fns`` is not dict, list or tuple.
        TypeError: If the elements of ``branch_fns`` is not 2-tuple.
        TypeError: If the first element of 2-tuple in ``branch_fns`` is not integer.
        ValueError: If the first element of 2-tuple in ``branch_fns`` is not unique.
        TypeError: If the second element of 2-tuple in ``branch_fns`` is not callable.
        TypeError: If ``default`` is not None but it is not callable.
        IndexError: If ``branch_fns`` is empty and ``default`` is None.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.enable_static()

            >>> def fn_1():
            ...    return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)

            >>> def fn_2():
            ...    return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)

            >>> def fn_3():
            ...    return paddle.full(shape=[3], dtype='int32', fill_value=3)

            >>> startup_program = paddle.static.default_startup_program()
            >>> main_program = paddle.static.default_main_program()
            >>> with paddle.static.program_guard(main_program, startup_program):
            ...    index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
            ...    index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)
            ...
            ...    out_1 = paddle.static.nn.switch_case(
            ...        branch_index=index_1,
            ...        branch_fns={1: fn_1, 2: fn_2},
            ...        default=fn_3)
            ...
            ...    out_2 = paddle.static.nn.switch_case(
            ...        branch_index=index_2,
            ...        branch_fns=[(1, fn_1), (2, fn_2)],
            ...        default=fn_3)
            ...
            ...    # Argument default is None and no index matches. fn_3 will be called because of the max index 7.
            ...    out_3 = paddle.static.nn.switch_case(
            ...        branch_index=index_2,
            ...        branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
            ...
            ...    exe = paddle.static.Executor(paddle.CPUPlace())
            ...    res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])

            >>> print(res_1)
            [[1. 1.]]

            >>> print(res_2)
            [[2 2]
             [2 2]]

            >>> print(res_3)
            [3 3 3]
    '''
    # Kept although the helper is not used afterwards: it is constructed with
    # the call's arguments exactly as before, so any side effect of creating
    # it is preserved.
    helper = LayerHelper('switch_case', **locals())

    def _check_args(branch_index, branch_fns, default):
        '''
        Validate the arguments and turn them into (pred, fn) pairs plus a default callable.
        '''
        check_variable_and_dtype(
            branch_index,
            'branch_index',
            ['uint8', 'int32', 'int64'],
            'static.nn.switch_case',
        )

        # Every key below is compared against the index as int64.
        if convert_dtype(branch_index.dtype) != "int64":
            branch_index = paddle.cast(branch_index, "int64")

        check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')

        # Normalize to an iterable of (key, fn) pairs.
        branch_fns = (
            branch_fns.items() if isinstance(branch_fns, dict) else branch_fns
        )

        # A plain sequence of callables is keyed by position.
        branch_fns = (
            list(enumerate(branch_fns))
            if all(callable(fn) for fn in branch_fns)
            else branch_fns
        )

        seen_keys = set()
        for index_fn_pair in branch_fns:
            if not isinstance(index_fn_pair, tuple):
                # Report the type of the offending element, not of the
                # container it came from.
                raise TypeError(
                    _error_message(
                        "The elements' type",
                        "branch_fns",
                        "switch_case",
                        tuple,
                        type(index_fn_pair),
                    )
                )

            if len(index_fn_pair) != 2:
                raise TypeError(
                    _error_message(
                        "The tuple's size",
                        "branch_fns",
                        "switch_case",
                        "2",
                        f"{len(index_fn_pair)}-tuple",
                    )
                )

            key, fn = index_fn_pair

            if not isinstance(key, int):
                raise TypeError(
                    _error_message(
                        "The key's type",
                        "branch_fns",
                        "switch_case",
                        int,
                        type(key),
                    )
                )

            if key in seen_keys:
                raise ValueError(
                    f"The key in 'branch_fns' must be unique, but '{key}' appears more than once."
                )
            seen_keys.add(key)

            if not callable(fn):
                raise TypeError(
                    _error_message(
                        f"The type of function for key {key}",
                        "branch_fns",
                        "switch_case",
                        "callable",
                        type(fn),
                    )
                )

        if default is None:
            # No explicit default: the branch with the largest key becomes
            # the fallback and is dropped from the explicit branches. Keys
            # are unique, so ordering by key alone is unambiguous.
            ordered_fns = sorted(branch_fns, key=lambda pair: pair[0])
            default = ordered_fns[-1][1]
            branch_fns = ordered_fns[:-1]
        elif not callable(default):
            raise TypeError("The default in Op(switch_case) must be callable.")

        pred_fn_pairs = []
        for index, fn in branch_fns:
            new_index = paddle.full(shape=[1], dtype="int64", fill_value=index)
            pred = paddle.equal(branch_index, new_index)
            pred_fn_pairs.append((pred, fn))

        return pred_fn_pairs, default

    pred_fn_pairs, default = _check_args(branch_index, branch_fns, default)

    # Chain the branches into nested conds. The keys are unique, so at most
    # one predicate can match and the nesting order does not affect which
    # branch is selected.
    false_fn = default
    for pred, true_fn in pred_fn_pairs:
        false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)

    return false_fn()


def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
    """
    This API returns ``true_fn()`` if the predicate ``pred`` is true else
    ``false_fn()`` . Users could also set ``true_fn`` or ``false_fn`` to
    ``None`` if do nothing and this API will treat the callable simply returns
    ``None`` in this case.

    ``true_fn`` and ``false_fn`` should return same nest structure of tensors
    or both return ``None`` if user doesn't like to return anything. A nest
    structure of tensors in PaddlePaddle is tensor(s), or tuple of tensors, or
    list of tensors.

    Note:
        1. The tuples or lists returned by ``true_fn`` and ``false_fn`` must have
        the same shape because of dataflow model of PaddlePaddle while the
        tensors in the tuples or the lists can have different shapes.

        2. This API could be used under both static graph mode or dygraph mode. If it
        is in dygraph mode, the API only runs one branch based on condition.

        3. If it is in static graph mode, any tensors or operations created outside
        or inside of ``true_fn`` and ``false_fn`` will be in net building
        regardless of which branch is selected at runtime. This has frequently
        surprised users who expected a lazy semantics.

        Examples:
            .. code-block:: python
                :name: code-example-1

                >>> import paddle

                >>> a = paddle.zeros((1, 1))
                >>> b = paddle.zeros((1, 1))
                >>> c = a * b
                >>> out = paddle.static.nn.cond(a < b, lambda: a + c, lambda: b * b)

        No matter whether ``a < b`` , ``c = a * b`` will be in net building and
        run. ``a + c`` and ``b * b`` will be in net building, but only one
        branch will be executed during runtime.

    Args:
        pred(Tensor): A boolean tensor whose numel should be 1 (shape []
            or shape [1]). The boolean value determines whether to return the
            result of ``true_fn`` or ``false_fn`` .
        true_fn(callable, optional): A callable to be performed if ``pred`` is
            true. The default value is ``None`` .
        false_fn(callable, optional): A callable to be performed if ``pred`` is
            false. The default value is ``None`` .
        name(str, optional): The default value is ``None`` . Normally users
             don't have to set this parameter. For more information, please
             refer to :ref:`api_guide_Name` .
        return_names(sequence of string, optional): The default value is ``None`` .
             Normally users don't have to set this parameter.  A sequence of strings
             that represent the names of the returned vars.  The structure of the
             sequence must be the same as the return values of true_fn and false_fn.

    Returns:
        Tensor|list(Tensor)|tuple(Tensor): returns ``true_fn()`` if the
        predicate ``pred`` is true else ``false_fn()`` .

    Raises:
        TypeError: If ``true_fn`` or ``false_fn`` is not ``None`` and not
            callable (in dygraph mode only the selected branch is checked).
        ValueError: In static graph mode, if only one of the two branches
            returns ``None``, or if the branches return a different number of
            vars or different nest structures.

    Examples:
        .. code-block:: python
            :name: code-example-2

            >>> import paddle

            >>> # pseudocode:
            >>> # if 0.1 < 0.23:
            >>> #     return 1, True
            >>> # else:
            >>> #     return 3, 2

            >>> def true_func():
            ...     return paddle.full(shape=[1, 2],
            ...                        dtype='int32',
            ...                        fill_value=1
            ...         ), paddle.full(shape=[2, 3],
            ...                        dtype='bool',
            ...                        fill_value=True
            ...         )


            >>> def false_func():
            ...     return paddle.full(shape=[3, 4],
            ...                        dtype='float32',
            ...                        fill_value=3
            ...         ), paddle.full(shape=[4, 5],
            ...                        dtype='int64',
            ...                        fill_value=2
            ...         )


            >>> x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            >>> y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            >>> pred = paddle.less_than(x=x, y=y, name=None)
            >>> a, b = paddle.static.nn.cond(pred, true_func, false_func)

            >>> print(a)
            Tensor(shape=[1, 2], dtype=int32, place=Place(cpu), stop_gradient=True,
                   [[1, 1]])
            >>> print(b)
            Tensor(shape=[2, 3], dtype=bool, place=Place(cpu), stop_gradient=True,
                   [[True, True, True],
                    [True, True, True]])
    """
    # Dygraph mode: the predicate has a concrete value, so read it and run
    # only the selected branch. Nothing is merged in this mode.
    if in_dygraph_mode():
        assert isinstance(pred, Variable), "The pred in cond must be Variable"
        assert pred.size == 1, "condition input's numel should be 1"
        pred = pred.item()
        if pred:
            if true_fn is not None:
                if not callable(true_fn):
                    raise TypeError(
                        "The true_fn in cond must be callable, but received {}".format(
                            type(true_fn).__name__
                        )
                    )
                return true_fn()
        else:
            if false_fn is not None:
                if not callable(false_fn):
                    raise TypeError(
                        "The false_fn in cond must be callable, but received {}".format(
                            type(false_fn).__name__
                        )
                    )
                return false_fn()
        # The selected branch has no callable.
        return None
    # Static graph modes: both branches are called while the net is built;
    # a branch whose fn is None leaves its output as None.
    true_output = None
    false_output = None
    if in_pir_mode():
        check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
        check_type(name, "name", (str, type(None)), "base.layers.cond")
        # PIR mode: a single if op whose true/false blocks receive the
        # operations created by the corresponding callable.
        if_op = build_if_op(pred)
        if true_fn is not None:
            if not callable(true_fn):
                raise TypeError(
                    "The true_fn in cond must be callable, but received {}".format(
                        type(true_fn).__name__
                    )
                )
            with if_op.true_block():
                true_output = true_fn()
        if false_fn is not None:
            if not callable(false_fn):
                raise TypeError(
                    "The false_fn in cond must be callable, but received {}".format(
                        type(false_fn).__name__
                    )
                )
            with if_op.false_block():
                false_output = false_fn()
    else:
        check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
        check_type(name, "name", (str, type(None)), "base.layers.cond")
        # NOTE: **locals() hands every local defined so far to LayerHelper
        # (the arguments plus true_output / false_output).
        helper = LayerHelper('cond', **locals())
        # Branch results are created inside the branch's block; copy each one
        # to the parent block so it can be used after the cond.
        copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
        if true_fn is not None:
            if not callable(true_fn):
                raise TypeError(
                    "The true_fn in cond must be callable, but received {}".format(
                        type(true_fn).__name__
                    )
                )
            true_cond_block = ConditionalBlock([pred], is_scalar_condition=True)
            with true_cond_block.block():
                origin_true_output = true_fn()
                if origin_true_output is not None:
                    true_output = map_structure(
                        copy_to_parent_func, origin_true_output
                    )
        if false_fn is not None:
            if not callable(false_fn):
                raise TypeError(
                    "The false_fn in cond must be callable, but received {}".format(
                        type(false_fn).__name__
                    )
                )
            # The false branch gets its own conditional block, guarded by
            # the negated predicate.
            false_cond_block = ConditionalBlock(
                [paddle.logical_not(pred)], is_scalar_condition=True
            )
            with false_cond_block.block():
                origin_false_output = false_fn()
                if origin_false_output is not None:
                    false_output = map_structure(
                        copy_to_parent_func, origin_false_output
                    )

    # Nothing to merge when neither branch produced a value.
    if true_output is None and false_output is None:
        return None

    # Exactly one branch returning None cannot be merged.
    if true_output is None:
        raise ValueError(
            "Incompatible return values of true_fn and false_fn in cond: "
            "true_fn returns None while false_fn returns non-None"
        )
    if false_output is None:
        raise ValueError(
            "Incompatible return values of true_fn and false_fn in cond: "
            "true_fn returns non-None while false_fn returns None"
        )

    # Merge true and false output if they are not None
    if return_names is None:
        is_dy2staic = False
        # Placeholder names: below they only appear in warning and error
        # messages.
        return_names = ["no name"] * len(_to_sequence_except_dict(true_output))
    else:
        """
        dy2static will set the return_names and expand the return values to UndefinedVar.
        """
        is_dy2staic = True

        # TODO:  expand_undefined_var will replace None to Undefinedvar(), to fix cases like:
        #       a = None
        #       if condition:
        #           a = 1
        # Because we can not use variable to express 'None'
        true_output, false_output = expand_undefined_var(
            true_output, false_output, return_names
        )

    # Both branches must return the same number of top-level vars ...
    if len(_to_sequence_except_dict(true_output)) != len(
        _to_sequence_except_dict(false_output)
    ):
        raise ValueError(
            "true fn returns {} vars, but false fn returns {} vars, which is not equals".format(
                len(_to_sequence_except_dict(true_output)),
                len(_to_sequence_except_dict(false_output)),
            )
        )
    # ... and each pair of vars must have the same nest structure.
    for true_out, false_out, return_name in zip(
        _to_sequence_except_dict(true_output),
        _to_sequence_except_dict(false_output),
        _to_sequence_except_dict(return_names),
    ):
        try:
            assert_same_structure(true_out, false_out, check_types=False)
        except ValueError as e:
            raise ValueError(
                "Incompatible return values of `{}` in true_fn and false_fn in cond: {}".format(
                    return_name, e
                )
            )

    def check_ret_none(seq_true, seq_false, seq_names):
        # Warn for every leaf that is None in exactly one of the two branches.
        for f_true, f_false, f_name in zip(seq_true, seq_false, seq_names):
            f_true = flatten(f_true)
            f_false = flatten(f_false)
            for idx in range(len(f_true)):
                if (
                    f_true[idx] is None
                    and f_false[idx] is not None
                    or f_false[idx] is None
                    and f_true[idx] is not None
                ):
                    warnings.warn(
                        "In cond : Var '{}' or part of it is set differently in ifelse branchs, "
                        "<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
                        "'None' in ifelse block might lead to error.".format(
                            f_name,
                            type(f_true[idx]),
                            f_true[idx],
                            type(f_false[idx]),
                            f_false[idx],
                        )
                    )

    check_ret_none(
        _to_sequence_except_dict(true_output),
        _to_sequence_except_dict(false_output),
        _to_sequence_except_dict(return_names),
    )

    if is_dy2staic:
        true_output, false_output = change_none_to_undefinedvar(
            true_output, false_output
        )

    if in_pir_mode():
        # Yield each branch's flattened outputs from its block, then repack
        # the if op's results in the nest structure of true_output.
        with if_op.true_block():
            cf_yield(flatten(true_output))
        with if_op.false_block():
            cf_yield(flatten(false_output))
        # NOTE(review): presumably refreshes the if op's results from the
        # yielded values -- confirm against build_if_op.
        if_op.update_output()
        return pack_sequence_as(true_output, flatten(if_op.results()))

    # Legacy static graph: merge the two branches leaf by leaf with
    # select_input, using the predicate cast to int32 as the mask. The inputs
    # are ordered [false_var, true_var].
    mask = paddle.cast(pred, dtype='int32')
    merge_func = (
        lambda name, false_var, true_var: select_input_with_buildin_type(
            [false_var, true_var], mask, name
        )
    )

    def merge_every_var_list(false_vars, true_vars, name):
        return map_structure(partial(merge_func, name), false_vars, true_vars)

    # merge_func returns zero-argument callables; they are only invoked in
    # the map_structure call below.
    merged_output_fns = list(
        map(
            merge_every_var_list,
            _to_sequence_except_dict(false_output),
            _to_sequence_except_dict(true_output),
            _to_sequence_except_dict(return_names),
        )
    )
    merged_output = map_structure(lambda fn: fn(), merged_output_fns)
    # Restore the nest structure the branches returned.
    merged_output = pack_sequence_as(false_output, flatten(merged_output))
    return merged_output


def copy_var_to_parent_block(var, layer_helper):
    """Make ``var`` available in the parent of the program's current block.

    Args:
        var: The value to expose. Anything that is not a ``Variable`` is
            returned unchanged.
        layer_helper: Helper whose ``main_program`` owns the current block.

    Returns:
        ``var`` itself when it is not a ``Variable``, or when it is a
        ``LOD_TENSOR_ARRAY`` that the parent block can already find by name;
        otherwise a new variable created in the parent block (same dtype,
        shape and type) that ``var`` has been assigned to.
    """
    if not isinstance(var, Variable):
        return var

    program = layer_helper.main_program
    parent_idx = program.current_block().parent_idx
    assert (
        parent_idx >= 0
    ), "Got wrong parent block index when assigning var to parent scope in control_flow"
    parent_block = program.block(parent_idx)

    is_tensor_array = var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
    if is_tensor_array and parent_block._find_var_recursive(var.name):
        # The array is already reachable from the parent block; no copy.
        return var

    copied_var = parent_block.create_var(
        dtype=var.dtype, shape=var.shape, type=var.type
    )
    paddle.assign(var, copied_var)
    return copied_var


def select_output(input, outputs, mask):
    """
    **select_output**

    Append a ``select_output`` op: it takes one input, several outputs and an
    integer mask, and copies the input to the output selected by the mask.
    It is useful in control flow.

    Args:
        input(Variable): The input variable
        outputs(tuple|list): The output variables
        mask(Variable): A tensor containing 1 integer number selecting which
            output to be copied with input

    Returns:
        tuple|list: The ``outputs`` argument, returned as passed in.
    """
    # Must stay the first statement: **locals() captures exactly the three
    # arguments.
    helper = LayerHelper('select_output', **locals())

    check_type(input, 'input', Variable, 'select_output')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
    check_type(outputs, 'outputs', (list, tuple), 'select_output')

    op_inputs = {'X': input, 'Mask': mask}
    op_outputs = {'Out': outputs}
    helper.append_op(
        type='select_output', inputs=op_inputs, outputs=op_outputs
    )
    return outputs


def _select_input_infer_shape(first_shape, second_shape):
    """Infer the output shape of ``select_input`` from its two input shapes.

    When both shapes have the same rank they are compared axis by axis: an
    axis keeps its size where the two shapes agree and becomes ``-1`` where
    they differ. When the ranks differ a warning is issued and
    ``second_shape`` is returned unchanged.

    Args:
        first_shape: Shape of the first input.
        second_shape: Shape of the second input.

    Returns:
        A new list with the merged shape, or ``second_shape`` itself when the
        ranks differ.
    """
    if len(first_shape) == len(second_shape):
        return [
            dim_a if dim_a == dim_b else -1
            for dim_a, dim_b in zip(first_shape, second_shape)
        ]

    warnings.warn(
        f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
    )
    return second_shape


def select_input(inputs, mask):
    """
    **select_input**

    Append a ``select_input`` op that uses an integer mask to pick one of
    several inputs as its output. It is useful in control flow.

    The output variable takes its dtype and type from ``inputs[1]``; its
    shape is inferred from the shapes of ``inputs[0]`` and ``inputs[1]``.

    Args:
        inputs(tuple|list): The input variables
        mask(Tensor): A tensor containing 1 integer number selecting which
            input to output

    Returns:
        Variable: The selected input variable
    """
    # Must stay the first statement: **locals() captures exactly the two
    # arguments.
    helper = LayerHelper('select_input', **locals())
    check_type(inputs, 'inputs', (list, tuple), 'select_input')
    check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

    # The output shape is widened: an axis on which the two inputs disagree
    # becomes -1. The dtypes of the two inputs are not compared here.
    merged_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
    reference = inputs[1]

    selected = helper.create_variable(
        dtype=reference.dtype, shape=merged_shape, type=reference.type
    )
    op_inputs = {'X': inputs, 'Mask': mask}
    op_outputs = {'Out': selected}
    helper.append_op(
        type='select_input', inputs=op_inputs, outputs=op_outputs
    )
    return selected


def select_input_with_buildin_type(inputs, mask, name):
    """Build a zero-argument callable that merges one pair of branch results.

    ``inputs`` is ``[false_var, true_var]``. Depending on the kinds of the two
    values, the returned callable does one of the following when invoked:

    * both ``UndefinedVar``: returns ``None``;
    * both ``Variable``: runs ``select_input`` on them as given;
    * both builtin values of compatible type: returns the value if they are
      equal, otherwise converts both to static variables and selects;
    * one builtin value and one ``Variable``: converts both, warns about the
      type mismatch and selects;
    * one ``UndefinedVar`` and one ``Variable`` or builtin value: converts
      both and selects.

    Args:
        inputs: Two-element sequence ``[false_var, true_var]``.
        mask: Mask forwarded to ``select_input``.
        name: Name of the merged var, used only in error messages.

    Returns:
        callable: A function without arguments producing the merged value.

    Raises:
        TypeError: If the pair matches none of the supported combinations.
        RuntimeError: Raised by the returned callable when ``select_input``
            fails.
    """
    from paddle.jit.dy2static.utils import UndefinedVar
    from paddle.jit.dy2static.variable_trans_func import to_static_variable

    false_var, true_var = inputs

    def start_select_input():
        # Reads ``inputs`` when called, so it sees the converted pair that the
        # branches below may have assigned.
        try:
            return select_input(inputs, mask)
        except Exception as e:
            raise RuntimeError(
                f"Exceptions throwed while doing select_input on {name}:\n{e}"
            )

    false_is_undefined = isinstance(false_var, UndefinedVar)
    true_is_undefined = isinstance(true_var, UndefinedVar)
    if false_is_undefined and true_is_undefined:
        """None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None."""
        return lambda: None

    false_is_tensor = isinstance(false_var, Variable)
    true_is_tensor = isinstance(true_var, Variable)
    if false_is_tensor and true_is_tensor:
        return start_select_input

    false_is_builtin = isinstance(false_var, support_ret_buildin_type)
    true_is_builtin = isinstance(true_var, support_ret_buildin_type)

    if false_is_builtin and isinstance(false_var, type(true_var)):
        if false_var == true_var:
            return lambda: false_var
        inputs = [
            to_static_variable(false_var),
            to_static_variable(true_var),
        ]
    elif (false_is_builtin and true_is_tensor) or (
        true_is_builtin and false_is_tensor
    ):
        # Deal with the situations like this: false_var is int and true_var is Variable
        inputs = [to_static_variable(false_var), to_static_variable(true_var)]
        warnings.warn(
            "Return results from different branches in cond are not same type: "
            "false_var returned by false_fn is '{}' and true_var of true_fn is "
            "'{}'".format(type(false_var), type(true_var))
        )
    elif (false_is_undefined and (true_is_tensor or true_is_builtin)) or (
        true_is_undefined and (false_is_tensor or false_is_builtin)
    ):
        # The true value is converted first, then the false value.
        converted_true = to_static_variable(true_var)
        converted_false = to_static_variable(false_var)
        inputs = [converted_false, converted_true]
    else:
        raise TypeError(
            "Unsupported return type of true_fn and false_fn in cond: false_var "
            "returned by false_fn is '{}' and true_var of true_fn is '{}'".format(
                type(false_var), type(true_var)
            )
        )
    return start_select_input


def _is_sequence_except_dict(x):
    """
    Same as ``is_sequence``, except that a dict is never treated as a sequence.
    """
    return not isinstance(x, dict) and is_sequence(x)


def _to_sequence_except_dict(x):
    """
    In this function, dict is not viewed as sequence.
    """
    if isinstance(x, dict):
        return [x]
    return to_sequence(x)


def expand_undefined_var(nest1, nest2, names):
    """Pad undefined / ``None`` entries of one nest to mirror the other nest.

    ``nest1``, ``nest2`` and ``names`` are walked in lockstep over their
    top-level elements only (a dict counts as one element). For each element
    whose name does not start with ``RETURN_VALUE_PREFIX`` and whose value is
    an ``UndefinedVar`` or ``None``, the value is replaced by a structure
    shaped like the matching element of the other nest, with every flattened
    entry set to ``UndefinedVar("padding")``. A warning is issued when the
    value is ``None`` while its counterpart is not.

    TODO: make this function recursive. Note the following case:
        nest1: Var1, (UndefinedVar, [1,2,3])
        nest2: Var2, ([1,2,3,4], UndefinedVar)
    In this case, we should not expand recursively.

    Returns:
        tuple: ``(nest1_out, nest2_out)``. Each output is a list, except that
        an input which is not a sequence (dicts included) is returned as the
        single resulting element.
    """
    from paddle.jit.dy2static.return_transformer import RETURN_VALUE_PREFIX
    from paddle.jit.dy2static.utils import UndefinedVar

    def _undefined_like(structure):
        # Same structure as `structure`, every flattened entry is a placeholder.
        placeholders = [UndefinedVar("padding") for _ in flatten(structure)]
        return pack_sequence_as(structure, placeholders)

    def _fill(current, other, name, current_is_true_branch):
        # Return-value slots are never padded.
        if name.startswith(RETURN_VALUE_PREFIX):
            return current
        # Only UndefinedVar / None entries need a replacement.
        if not (isinstance(current, UndefinedVar) or current is None):
            return current
        if current is None and other is not None:
            # The message always reports the true branch first.
            if current_is_true_branch:
                true_val, false_val = current, other
            else:
                true_val, false_val = other, current
            warnings.warn(
                "In cond : Var '{}' or part of it is set differently in ifelse branchs, "
                "<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
                "'None' in ifelse block might lead to error.".format(
                    name, type(true_val), true_val, type(false_val), false_val
                )
            )
        return _undefined_like(other)

    seq1 = _to_sequence_except_dict(nest1)
    seq2 = _to_sequence_except_dict(nest2)
    name_seq = _to_sequence_except_dict(names)

    # First pass pads nest1 against nest2, second pass pads nest2 against nest1.
    nest1_out = [
        _fill(item1, item2, name, True)
        for item1, item2, name in zip(seq1, seq2, name_seq)
    ]
    nest2_out = [
        _fill(item2, item1, name, False)
        for item1, item2, name in zip(seq1, seq2, name_seq)
    ]

    # Undo the list wrapping for inputs that were not sequences to begin with.
    if not _is_sequence_except_dict(nest1):
        nest1_out = nest1_out[0]
    if not _is_sequence_except_dict(nest2):
        nest2_out = nest2_out[0]
    return nest1_out, nest2_out


def change_none_to_undefinedvar(nest1, nest2):
    """Swap every ``None`` entry of both nests for ``UndefinedVar("padding")``.

    Each nest is flattened, ``None`` entries are replaced, and the result is
    packed back into the structure of the corresponding input nest.

    Returns:
        tuple: ``(nest1_out, nest2_out)``.
    """
    from paddle.jit.dy2static.utils import UndefinedVar

    def _replace_none_entries(nest):
        entries = [
            UndefinedVar("padding") if entry is None else entry
            for entry in flatten(nest)
        ]
        return pack_sequence_as(nest, entries)

    return _replace_none_entries(nest1), _replace_none_entries(nest2)


@static_only
def Print(
    input,
    first_n=-1,
    message=None,
    summarize=20,
    print_tensor_name=True,
    print_tensor_type=True,
    print_tensor_shape=True,
    print_tensor_layout=True,
    print_tensor_lod=True,
    print_phase='both',
):
    '''
    :api_attr: Static Graph

    **Print operator**

    This creates a print op that will print when a tensor is accessed.

    Wraps the tensor passed in so that whenever the tensor is accessed,
    the message `message` is printed, along with the current value of the
    tensor `input`.

    Args:
        input (Tensor): A Tensor to print.
        first_n (int, optional): Only log `first_n` number of times. Default: -1.
        message (str, optional): A string message to print as a prefix. ``None``
                is treated as an empty string. Default: None.
        summarize (int, optional): Number of elements in the tensor to be printed. If
                its value is -1, then all elements in the tensor will be printed.
                Default: 20.
        print_tensor_name (bool, optional): Print the tensor name. Default: True.
        print_tensor_type (bool, optional): Print the tensor type. Default: True.
        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
        print_phase (str, optional): Which phase to display, including 'forward',
                'backward' and 'both'. Default: 'both'. If set to 'backward', will
                only print the gradients of input tensor; If set to 'both', will
                both print the input tensor itself and the gradients of input tensor.
                The value is upper-cased before it is handed to the op.

    Returns:
        Tensor: Output tensor.

    NOTES:
        The input and output are two different Tensors, and in the
        following process, you should use the output Tensor but not the input,
        otherwise, the print layer doesn't have backward.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> paddle.enable_static()

            >>> x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
            >>> out = paddle.static.Print(x, message="The content of input layer:")

            >>> main_program = paddle.static.default_main_program()
            >>> exe = paddle.static.Executor(place=paddle.CPUPlace())
            >>> res = exe.run(main_program, fetch_list=[out])
            >>> # doctest: +SKIP('Unable to get output')
            Variable: fill_constant_1.tmp_0
              - message: The content of input layer:
              - lod: {}
              - place: Place(cpu)
              - shape: [2, 3]
              - layout: NCHW
              - dtype: int64
              - data: [3 3 3 3 3 3]
            >>> # doctest: -SKIP
            >>> res
            [array([[3, 3, 3],
                    [3, 3, 3]], dtype=int64)]
    '''
    # Validate `input` against the dtypes listed below; the last argument is
    # the API name passed along to the checker.
    check_variable_and_dtype(
        input,
        'input',
        ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
        'paddle.static.Print',
    )
    # Normalize a None (or otherwise falsy) message to the empty string.
    message = message or ""
    # NOTE: locals() forwards every local defined so far -- i.e. all parameters,
    # with `message` already normalized -- to LayerHelper as keyword arguments.
    # Introducing new locals above this line would change what is passed.
    helper = LayerHelper('print', **locals())

    # PIR mode: call the op directly through _C_ops and return its result.
    if in_pir_mode():
        return _C_ops.print(
            input,
            first_n,
            message,
            summarize,
            print_tensor_name,
            print_tensor_type,
            print_tensor_shape,
            print_tensor_layout,
            print_tensor_lod,
            print_phase.upper(),
            # NOTE(review): this trailing positional flag is not named here;
            # presumably the op's "is_forward" attribute -- confirm against
            # the print op definition.
            True,
        )

    # Otherwise: create an output variable with the input's dtype and append a
    # 'print' op that maps `input` to it.
    output = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
        type='print',
        inputs={'In': input},
        outputs={'Out': output},
        attrs={
            'first_n': first_n,
            'summarize': summarize,
            # `message` was already normalized above; the `or ""` here is
            # redundant but harmless.
            'message': message or "",
            'print_tensor_name': print_tensor_name,
            'print_tensor_type': print_tensor_type,
            'print_tensor_shape': print_tensor_shape,
            'print_tensor_layout': print_tensor_layout,
            'print_tensor_lod': print_tensor_lod,
            'print_phase': print_phase.upper(),
        },
    )
    return output


class Switch:
    """Context manager that builds a chain of mutually exclusive conditional blocks.

    Each ``case`` returns a guard around a ``ConditionalBlock``. The first case
    is guarded by its own condition; every later case is guarded by its
    condition AND-ed with the accumulated negation of all earlier conditions.
    ``default`` is guarded by that accumulated negation alone.
    """

    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        # True only between __enter__ and __exit__.
        self.inside_scope = False
        # Entry i holds "none of the conditions 0..i matched".
        self.pre_not_conditions = []

    def case(self, condition):
        """Return a guard for a block that runs when ``condition`` is the first match.

        Raises:
            ValueError: If called outside the ``with`` scope of this Switch.
        """
        if not self.inside_scope:
            raise ValueError("case should be called inside with")

        check_variable_and_dtype(
            condition,
            'condition',
            ['bool'],
            'the member function case of base.layers.Switch',
        )

        if not self.pre_not_conditions:
            # First case: guarded directly by its own condition.
            first_block = ConditionalBlock(
                [condition], is_scalar_condition=True
            )
            self.pre_not_conditions.append(paddle.logical_not(x=condition))
            return ConditionalBlockGuard(first_block)

        # Later cases: extend the "nothing matched so far" chain first, then
        # guard the block by "nothing matched before AND this condition".
        nothing_matched_before = self.pre_not_conditions[-1]
        negated = paddle.logical_not(x=condition)
        self.pre_not_conditions.append(
            paddle.logical_and(x=nothing_matched_before, y=negated)
        )
        guarded_condition = paddle.logical_and(
            x=nothing_matched_before, y=condition
        )
        later_block = ConditionalBlock(
            [guarded_condition],
            is_scalar_condition=True,
        )
        return ConditionalBlockGuard(later_block)

    def default(self):
        """Return a guard for the block that runs when no earlier case matched.

        Raises:
            ValueError: If no ``case`` has been registered yet.
        """
        if not self.pre_not_conditions:
            raise ValueError("there should be at least one condition")
        fallback_block = ConditionalBlock(
            [self.pre_not_conditions[-1]],
            is_scalar_condition=True,
        )
        return ConditionalBlockGuard(fallback_block)

    def __enter__(self):
        """
        set flag that now is inside switch.block {}
        :return: this Switch instance
        """
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the switch scope; a pending exception is re-raised, not swallowed."""
        self.inside_scope = False
        # False when an exception is in flight (re-raise), True otherwise.
        return exc_type is None
