o
    "j<                     @   s@  d dl Z d dlmZ d dlZd dlZd dlmZ d dlmZm	Z	 d dl
mZmZmZmZ d dlmZ d dlmZmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZ ddlmZ ddlm Z m!Z! ddl"m#Z#m$Z$m%Z% g Z&G dd dZ'G dd dZ(G dd dZ)G dd dZ*G dd dZ+dddZ,e	d ddZ-dS )!    N)deepcopy)_legacy_C_ops)_in_amp_guard_in_pure_fp16_guard)backwardcore	frameworkprogram_guard)BuildStrategy)
check_typeconvert_dtype)switch_to_static_graph)_apply_pass	get_flags)switch)LRScheduler   )logging_utils)SubGraphRolepir_exporter)RETURN_NO_VALUE_MAGIC_NUMbackend_guardconstruct_grad_namesc                   @   sN   e Zd ZdZdddZdd Zdd Zd	d
 Zdd Ze	dd Z
dd ZdS )NestSequencezf
    A wrapper class that easily to flatten and restore the nest structure of
    given sequence.
    Fc                 C   s(   || _ |  | _|  | _| | d S N)_NestSequence__raw_inputtolist_NestSequence__input_list_get_var_ids_NestSequence__var_ids_check_non_variable)self	raw_input
need_check r$   e/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/paddle/jit/dy2static/partial_program.py__init__0   s   

zNestSequence.__init__c                 C   s   t j| jS )zA
        Flattens the nested sequences into single list.
        )paddleutilsflattenr   r!   r$   r$   r%   r   6   s   zNestSequence.tolistc                 C   s&   t | jt |ksJ tj| j|S )z?
        Restores the nested sequence from value list.
        )lenr   r'   r(   Zpack_sequence_asr   )r!   Z
value_listr$   r$   r%   restore<   s   zNestSequence.restorec                 C   s:   g }t | jD ]\}}t|tjtjjfr|| q|S r   )		enumerater   
isinstancer   Variabler   eagerTensorappend)r!   var_idsidxvarr$   r$   r%   r   C   s   
zNestSequence._get_var_idsc                 C   s\   |r*t  }| jD ]}t|tjtjjfs|t	| q|r,t
dt| dS dS dS )z^
        Raises warning if output of traced function contains non-tensor type values.
        zOutput of traced function contains non-tensor type values: {}. Currently, We don't support to update them while training and will return what we first saw. Please try to return them as tensor.N)setr   r.   r   r/   r   r0   r1   addtyper   warnformatlist)r!   r#   Zwarning_typesr5   r$   r$   r%   r    K   s   
z NestSequence._check_non_variablec                 C   s   | j S r   )r   r*   r$   r$   r%   r3   ]   s   zNestSequence.var_idsc                 C   s
   | j | S r   )r   )r!   itemr$   r$   r%   __getitem__a      
zNestSequence.__getitem__NF)__name__
__module____qualname____doc__r&   r   r,   r   r    propertyr3   r=   r$   r$   r$   r%   r   *   s    

r   c                   @       e Zd ZdZdd Zdd ZdS )LazyInitializedzB
    Descriptor to implement lazy initialization of property.
    c                 C   
   || _ d S r   )function)r!   rH   r$   r$   r%   r&   j   r>   zLazyInitialized.__init__c                 C   s   |  |}t|| j j| |S r   )rH   setattrr@   )r!   instanceclsvalr$   r$   r%   __get__m   s   
zLazyInitialized.__get__N)r@   rA   rB   rC   r&   rM   r$   r$   r$   r%   rF   e   s    rF   c                   @   rE   )ProgramInfoz7
    A helper class to recoder Program information
    c                 C   s   dddd| _ i | _d| _d S )Nfp32ampfp16Zinfer)op_sizeprogramsmoder*   r$   r$   r%   r&   x   s   
zProgramInfo.__init__c                 C   sT   |dv sJ || j vr |dd}|| j |< |jd | j|< | j | | j| fS )z4
        Recoder infer program and op size.
        rP   T)is_infer_moder   )rU   descblockrT   )r!   keyZprog_creatorZ
infer_progr$   r$   r%   __call__   s   


zProgramInfo.__call__N)r@   rA   rB   rC   r&   r[   r$   r$   r$   r%   rN   s   s    	rN   c                   @   s$   e Zd Zdd Zdd Zdd ZdS )PartialProgramLayerHookc                 C      d S r   r$   )r!   forward_programr$   r$   r%   before_append_backward      z.PartialProgramLayerHook.before_append_backwardc                 C   r]   r   r$   )r!   whole_programZbackward_start_idxr$   r$   r%   after_append_backward   r`   z-PartialProgramLayerHook.after_append_backwardc                 C   r]   r   r$   r!   infer_programr$   r$   r%   after_infer   r`   z#PartialProgramLayerHook.after_inferN)r@   rA   rB   r_   rb   re   r$   r$   r$   r%   r\      s    r\   c                       s  e Zd ZdZ	dz fdd	Zdd Zdd Zd	d
 Zdd Zdd Z	d{ddZ
ed|ddZed|ddZed|ddZedd Zedd Zedd Zedd Zed d! Zed"d# Zed$d% Zed&d' Zed(d) Zed*d+ Zed,d- Zed.d/ Zed0d1 Zed2d3 Zed4d5 Zed6d7 Zed8d9 Z ed:d; Z!ed<d= Z"d>d? Z#e$d@dA Z%e$dBdC Z&e$dDdE Z'e$dFdG Z(e$dHdI Z)e$dJdK Z*dLdM Z+dNdO Z,edPdQ Z-dRdS Z.dTdU Z/d|dVdWZ0edXdY Z1edZd[ Z2d\d] Z3ed^d_ Z4dzd`daZ5dbdc Z6ddde Z7d{dfdgZ8dhdi Z9djdk Z:dldm Z;edndo Z<dpdq Z=drds Z>dtdu Z?dvdw Z@dxdy ZA  ZBS )}PartialProgramLayeraa  
    PartialProgramLayer wraps all the ops from layers decorated by `@to_static`
    and execute them as a static subgraph.

    .. note::
        **1. This is a very low level API. Users should not use this API
             directly. Please use `partial_program_from(concrete_program)`
             to create it.
        **2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains ops need to be executed.
        inputs(list[Variable]): The input list of the decorated function by `@to_static`.
        outputs(list[Variable]): The output list of the decorated function by `@to_static`.
        parameters(list[Tensor]|None): All trainable parameters included in the program. Default None.

    Returns:
        Layer: A Layer object that run all ops internally in static graph mode.
    Nc                    s  t    t| _t|dd _|d ur|ng  _| _|dt  _	t
 j	ts-J  | _tjjtjj     _W d    n1 sNw   Y  d _d _d _t  _i  _d\}}}	t }
|
rw|
 \}}	|
j}|d ur|dv rtjjj j!||	|d _"i  _#i  _$d  _%|d	d  _&i  _'g  _( jD ]}t
|tj)r j(*|j+,  q fd
d jj-D  _.d  _/d S )NT)r#   build_strategy r   )NNN)float16Zbfloat16)custom_white_listcustom_black_listdtypebackendc                    s   g | ]} j | jqS r$   )_outputsrX   ).0var_idr*   r$   r%   
<listcomp>   s    z0PartialProgramLayer.__init__.<locals>.<listcomp>)0superr&   r   _inputsrn   _params_name_generatorgetr
   _build_strategyr.   _verify_program_origin_main_programr'   baser   Z_dygraph_guardZdygraphZTracer_create_cuda_graph_vec_cuda_graph_vec_cuda_graph_capture_mode_cuda_graph_pool_idtrainingrN   _infer_info_forward_end_index_mapZ_dygraph_tracerZ_get_amp_op_listZ
_amp_dtypestaticrR   Z
fp16_listsZAutoMixedPrecisionLists	_amp_list_pir_scope_cache_legacy_scope_cache_hooker_backend_grad_var_names_in_var_namesr/   r2   rX   namer3   _out_var_descsZ_debug_name)r!   main_programinputsoutputsname_generator
parameterskwargsZ	amp_dtyperj   rk   Ztracerr5   	__class__r*   r%   r&      sT   
	


	


zPartialProgramLayer.__init__c           
      C   s   t | j\}}| |\}}|  }| | tjdddk}| j| d}|	d|g | 
  tj| || | j| || j| jdd| jg|R   | |}	| |	}	t || |	S )zQ
        Execute static graph by Interpreter and Return dynamic Tensors.
        Z
DY2ST_TESTNTrueforce_not_use_ptx_namesT
program_iduse_scope_cache)r   ru   _prepare_inputs_prepare_outputs_cast_fp16_if_pure_fp16osenvironrv   _prepare_attributesextend_sync_lr_value_with_schedulerr   run_program_valid_varsrt   _create_scope_vecr   r|   _restore_out_remove_no_value)
r!   r   old_generatorold_para_name_checkerin_varsin_var_namesout_varsZis_dy2st_testattrsrestored_nest_outr$   r$   r%   r[      s.   




zPartialProgramLayer.__call__c           	      C   s   t | j\}}| |\}}|  }| | | jdd}|d|g |   tj	| 
|| 
| j| 
|| j| jdd| jg|R   | |}| |}t || |S )z
        Same as __call__, but set force_not_use_pt to False.
        Currently _sot_call will cause CUDA 700 error, so we disable it temporarily.
        Fr   r   Tr   )r   ru   r   r   r   r   r   r   r   r   r   rt   r   r   r|   r   r   )	r!   r   r   r   r   r   r   r   r   r$   r$   r%   sot_call  s,   




zPartialProgramLayer.sot_callc                 C   s   t | j\}}|  }| | |  }|d| jg |   tj	| 
|| 
| j| 
|| j| jdd| jg|R   t || |S )zz
        In sot, inputs and outputs of partial program only contain tensors, so we can skip some step to speed up
        r   Tr   )r   ru   r   r   r   r   r   r   r   r   r   rt   r   r   r|   )r!   r   r   r   r   r   r$   r$   r%   	_sot_call+  s&   


zPartialProgramLayer._sot_callc                 C   sr   | j }t|dr5t|dr7|j}|j}t|tsJ d| j j}| }t|t	|j
}|| dS dS dS )z4Update lr_var value with calculated by lr_scheduler.lr_schedulerlr_varzmust be LRSchedulerN)ry   hasattrr   r   r.   r   nparrayastyper   rl   	set_value)r!   r   r   r   Zlr_valuedatar$   r$   r%   r   E  s   z1PartialProgramLayer._sync_lr_value_with_schedulerc                 C   rG   r   )r   )r!   Zhookerr$   r$   r%   
set_hookerT  r>   zPartialProgramLayer.set_hookerFc                 C   sx   t dd st dd r| j}n| j}|st S ||vr!g ||< || }|D ]	}|jr0|  S q't }|| |S )NFLAGS_enable_pir_in_executor!FLAGS_enable_pir_with_pt_in_dy2st)r   r   r   r   ScopeZ_can_reusedr2   )r!   r   r   Z_scope_cacheZcached_scopesscoper$   r$   r%   
_get_scopeW  s,   
zPartialProgramLayer._get_scopec                 C   sF   |r| j j|d}| jr| j|}|S | | j }| | j| |S )NZfor_test)ry   cloner   re   _append_backward_desc_set_grad_typert   )r!   rW   rd   train_programr$   r$   r%   _create_programp  s   z#PartialProgramLayer._create_programc                 C   s   | j j|d}t| tjjjj|| jddd W d    n1 s#w   Y  |r5| j	r3| j	
|}|S | |}| | j| |S )Nr   FZO1)use_fp16_guardlevelry   r   r	   r'   r   rR   Z
fp16_utilsZcast_model_to_fp16r   r   re   r   r   rt   )r!   rW   Zamp_programZtrain_amp_programr$   r$   r%   _create_amp_program  s   



z'PartialProgramLayer._create_amp_programc                 C   s   | j j|d}t| tjjjj|| jdd W d    n1 s"w   Y  |r4| j	r2| j	
|}|S | |}| | j| |S )Nr   F)r   r   )r!   rW   Zpure_fp16_programZtrain_pure_fp16_programr$   r$   r%   _create_pure_fp16_program  s"   

z-PartialProgramLayer._create_pure_fp16_programc                 C   (   | j }| |}|dksJ | ||S Nr   )_train_programget_forward_end_op_idx"_get_forward_backward_program_formr!   ra   forward_end_op_indexr$   r$   r%   &_create_forward_backward_train_program     
z:PartialProgramLayer._create_forward_backward_train_programc                 C   r   r   )_train_amp_programr   r   r   r$   r$   r%   *_create_forward_backward_train_amp_program  r   z>PartialProgramLayer._create_forward_backward_train_amp_programc                 C   r   r   )_train_pure_fp16_programr   r   r   r$   r$   r%   0_create_forward_backward_train_pure_fp16_program  r   zDPartialProgramLayer._create_forward_backward_train_pure_fp16_programc                 C      |   S r   )r   r*   r$   r$   r%   r        z"PartialProgramLayer._train_programc                 C      |  d| j\}}| ||S )NrQ   )r   r   _build_infer_programr!   programrT   r$   r$   r%   _infer_program     z"PartialProgramLayer._infer_programc                 C   r   r   )r   r*   r$   r$   r%   r     r   z&PartialProgramLayer._train_amp_programc                 C   r   )NrR   )r   r   r   r   r$   r$   r%   _infer_amp_program  r   z&PartialProgramLayer._infer_amp_programc                 C   r   r   )r   r*   r$   r$   r%   r     r   z,PartialProgramLayer._train_pure_fp16_programc                 C   r   )NrS   )r   r   r   r   r$   r$   r%   _infer_pure_fp16_program  s   z,PartialProgramLayer._infer_pure_fp16_programc                 C      |   }|S r   )r   r!   r   r$   r$   r%   _train_forward_backward_program     z3PartialProgramLayer._train_forward_backward_programc                 C   r   r   )r   r   r$   r$   r%   #_train_amp_forward_backward_program  r   z7PartialProgramLayer._train_amp_forward_backward_programc                 C   s
   t j S r   )r'   r   Programr*   r$   r$   r%    _empty_backward_program_for_eval  s   
z4PartialProgramLayer._empty_backward_program_for_evalc                 C   r   r   )r   r   r$   r$   r%   )_train_pure_fp16_forward_backward_program  r   z=PartialProgramLayer._train_pure_fp16_forward_backward_programc                 C   "   t j| j| }t|| j |S r   )r'   r(   _hash_with_idr   r   #_set_cached_executor_build_strategyrw   r!   r   r$   r$   r%   _train_program_id  
   z%PartialProgramLayer._train_program_idc                 C      t j| j| S r   )r'   r(   r   r   r*   r$   r$   r%   _infer_program_id     z%PartialProgramLayer._infer_program_idc                 C   r   r   )r'   r(   r   r   r   r   rw   r   r$   r$   r%   _train_amp_program_id   r   z)PartialProgramLayer._train_amp_program_idc                 C   r   r   )r'   r(   r   r   r*   r$   r$   r%   _infer_amp_program_id  r   z)PartialProgramLayer._infer_amp_program_idc                 C   r   r   )r'   r(   r   r   r   r   rw   r   r$   r$   r%   _train_pure_fp16_program_id  s   z/PartialProgramLayer._train_pure_fp16_program_idc                 C   r   r   )r'   r(   r   r   r*   r$   r$   r%   _infer_pure_fp16_program_id  r   z/PartialProgramLayer._infer_pure_fp16_program_idc                 C   s   | j tj||  S r   )r   r'   r(   r   r   r$   r$   r%   r     s   z*PartialProgramLayer.get_forward_end_op_idxc                 C   s   | j r| jS | jS )z7
        Return current train or eval program.
        )r   r   rd   r*   r$   r$   r%   r     s   zPartialProgramLayer.programc                 C   sB   | j rt r	| jS t r| jS | jS t r| jS t r| jS | jS )z?
        Return current train or eval program hash id.
        )	r   r   r   r   r   r   r   r   r   r*   r$   r$   r%   r   )  s   zPartialProgramLayer.program_idc                 C   s   t  r| jS t r| jS | jS r   )r   r   r   r   r   r*   r$   r$   r%   r   =  s
   z!PartialProgramLayer.train_programc                 C   s4   t  r| j}n
t r| j}n| j}t| |tj |S r   )r   r   r   r   r   r   r   ZInferrc   r$   r$   r%   rd   F  s   z!PartialProgramLayer.infer_programc                 C   sN   d\}}| j r"t r| j}|d S t r| j}|d S | j}|d S | j}|S )NNNr   )r   r   r   r   r   r   rd   )r!   r^   Zroleprogsr$   r$   r%   r^   S  s   z#PartialProgramLayer.forward_programc                 C   sD   | j rt r| j}|d S t r| j}|d S | j}|d S 	 | jS Nr   )r   r   r   r   r   r   r   )r!   r   r$   r$   r%   backward_programb  s   	z$PartialProgramLayer.backward_programc                 C   s   |  | | | |S )z
        Verify that the program parameter is initialized, prune some unused params,
        and remove redundant op callstack.
        )_check_params_all_inited_prune_unused_paramsr!   r   r$   r$   r%   rx   x  s   

z#PartialProgramLayer._verify_programc           	         sT    fdd}fdd}t t|| j }|D ]}| |j}||| qdS )a  
        Why we need add gradient aggregation operation ?
        In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
        def forward(self, in):
            x = 2 * in  # <---- x is a non-leaf node in program.
            y = x + 3
            return x, y

        loss = forward(in)[0].sum()
        loss.backward()  # <----- x@grad will be overwrited by elementwise_add_grad Op
        c                    st   t | tjr| jtjjjtjjjfvrdS | j	t
jt
jfvr dS  djD ]}|jD ]}|| jkr6  dS q+q&dS )zM
            if exist a op whose inputs is var, then return True
            Fr   T)r.   r   r/   r8   r   VarDescVarTypeZ
LOD_TENSORZSELECTED_ROWSrl   r'   Zfloat32Zfloat64rY   opsinput_arg_namesr   )r5   opZin_arg)r   r$   r%   _need_aggregation  s   

zKPartialProgramLayer.prepare_gradient_aggregation.<locals>._need_aggregationc                    s   d}|j  |j| d }tt fddt| dj}t|dkr&d S | dj||j	|j
|jd |D ]\}}| | | | q7| dj|d d d d	d
 |gid id d S )Nz
@dy2staticz@GRADc                    s(   | d  kot fdd| d jD S )Nr   c                 3   s    | ]}| kV  qd S r   r$   )ro   Zout_argvar_grad_namer$   r%   	<genexpr>  s
    
z~PartialProgramLayer.prepare_gradient_aggregation.<locals>._insert_aggregation_ops_for_var.<locals>.<lambda>.<locals>.<genexpr>r   )anyoutput_arg_names)x)	start_idxr  r$   r%   <lambda>  s    zkPartialProgramLayer.prepare_gradient_aggregation.<locals>._insert_aggregation_ops_for_var.<locals>.<lambda>r   )r   r8   rl   shaperO   r   sumXZOut)r8   r   r   )	grad_namer   r;   filterr-   rY   r   r+   Z
create_varr8   rl   r  Z_rename_inputZ_rename_outputZ
_insert_op)target_programr5   suffixZnew_grad_nameZ
finded_opsr4   r  )r	  r  r%   _insert_aggregation_ops_for_var  s6   


zYPartialProgramLayer.prepare_gradient_aggregation.<locals>._insert_aggregation_ops_for_varN)r;   r  rn   r   global_blockr5   r   )	r!   r	  r   r  r  r  Zto_processed_varsZ_varZ
target_varr$   )r   r	  r%   prepare_gradient_aggregation  s   (z0PartialProgramLayer.prepare_gradient_aggregationc           	         s~  |j dd | jr| j  g }| j D ]}t|tjr)| 	 
|j qt djt| j  }|rt djt| j  }t| j@ t|dtjttfd tj|g d} fdd| jD } fd	d| jD } fd
d| jD }t||||| _W d    n1 sw   Y  | jr| j |\ }| |d |  |t| j  | jtj | <  S )NFr   r   targetszpaddle.static.gradients)r  r   c                    *   g | ]}t |tjr d |jqS r   r.   r   r/   rY   r5   r   ro   r5   r   r$   r%   rq         
z=PartialProgramLayer._append_backward_desc.<locals>.<listcomp>c                    s   g | ]}  d |jqS r  )rY   r5   r   )ro   paramr  r$   r%   rq     s    c                    r  r  r  r  r  r$   r%   rq     r  r   ) r   r   r_   rn   r   r.   r   r/   r2   r  r5   r   r+   rY   r   r   r   r   r;   tupler   Zcalc_gradient_helperrs   rt   r   r   rb   r  r   r'   r(   r   )	r!   r   r  outr	  Zgrad_info_mapZx_varsZ
param_varsr   r$   r  r%   r     s\   




z)PartialProgramLayer._append_backward_descc                 C   sh   g }| j D ])}d}|jD ]!}|jD ]}|j|jv s|j|jv r(|| d} nq|r- nqq|| _ dS )a'  
        Prune the parameters not used anywhere in the program.
        The `@to_static` may only decorated a sub function which
        contains some unused parameters created in `__init__`.
        So prune these parameters to avoid unnecessary operations in
        `run_program_op`.
        FTN)rt   blocksr   r   r   r  r2   )r!   r   Zrequired_paramsr  Zfound_paramrY   r  r$   r$   r%   r     s    




z(PartialProgramLayer._prune_unused_paramsc                 C   sf   t  r/t|D ])\}}|j}| j |r.| j |jtj	kr.|
d||< ||| _qd S d S )Nri   )r   r-   r   r   r  Zhas_varr5   rl   r'   ri   r   )r!   r   ir5   r   r$   r$   r%   r   $  s   
z+PartialProgramLayer._cast_fp16_if_pure_fp16c                 C   s   d| j jdd| jjdd| j d| jg}| jr4|d| jdg d| jd	g d
| jdg f | j	rB|d| j	d| j
f d}t|| }t pQt }| jdk}| jj}|sa|sa|rcd}|rgd}|d|g |S )NZforward_global_blockr   Zbackward_global_blockZis_testr   Zparam_grad_namesr  Zout_grad_namesr  Zx_grad_namesr  Zcuda_graph_capture_modeZcuda_graph_pool_idr   ZCINNFin_pir_pt_mode)r^   rX   rY   r   r   r   r   r   rv   r}   r~   r   r   Z_is_fwd_prim_enabledZ_is_bwd_prim_enabledr   rw   Zbuild_cinn_pass)r!   r   r   Zpir_dy2st_flagr!  Zis_prim_enabledZin_cinn_backendZis_cinn_enabledr$   r$   r%   r   0  sN   
	
z'PartialProgramLayer._prepare_attributesc                 C   s,   |  |}t|d|| j|}| |d  |S r   )_parse_skip_gc_varsadd_build_strategy_forrw   _apply_inplace_pass)r!   rd   r   forward_skip_varsZbuilded_infer_programr$   r$   r%   r   c  s   
z(PartialProgramLayer._build_infer_programc           	      C   s   |t | jj }|jd }| || jdg  }t	|||| j
|}| ||}t	|d|| j
|}| || t| |tjt t| t| |tjt|t| ||gS )Nr   r  )r+   rn   r3   rX   rY   rT   r"  r   rv   r#  rw   r$  r   r   ZForwardr6   ZBackward)	r!   ra   r   Zbackward_start_op_indexZbackward_end_op_indexZbackward_skip_varsZbackward_builded_programr%  Zforward_builded_programr$   r$   r%   r   p  sX   z6PartialProgramLayer._get_forward_backward_program_formc           
      C   s   dddd}t j }t rdnd}| ||}| |}tdd p)tdd }|r<||dd}	|s<t||d|	| |rP||dd}	|sRt||d|	| d S d S d S )	Nboolz	list[str])use_cudaZmem_opt_skip_varsZfor_partial_blockTFr   r   Zbuffer_shared_inplace_pass)r'   r   r   r   Zis_compiled_with_cudar"  r   r   )
r!   r^   r   Z
attr_typesZempty_startup_programr'  Zforward_mem_opt_skip_varsZbackward_mem_opt_skip_varsr!  r   r$   r$   r%   r$    sZ   

z'PartialProgramLayer._apply_inplace_passc                 C   s`   g }| j D ]}t|tjjjr||j  q| j	D ]}t|tjjjr-||j  q|S )zK
        Returns Variable Names from self._inputs and self.outputs
        )
rs   r.   r'   rz   r   r/   r2   rX   r   rn   )r!   Z	var_namesr5   r$   r$   r%   _inout_var_names  s   

z$PartialProgramLayer._inout_var_namesc                 C   sX   t | j}| j D ]\}}|jr|| q|r*t|j	dD ]}|| q"|S )z
        Parse variables that need to skip GC after execute it.
        If specify backward_program, it will keep the variables used in backward.
        T)
r   r(  r  varsitemsZis_datar2   r   Z#parse_safe_eager_deletion_skip_varsrX   )r!   r   r   	skip_varsvar_namer5   r$   r$   r%   r"    s   

z'PartialProgramLayer._parse_skip_gc_varsc           	      C   s   t |ttfs	J tj|}g }g }t }t|D ]L\}}t |t	j
r9d}tjj|| j| j d|dd}nt |tjjrV|jrS|j|sS||d}d|_n|}nq|| j| j  || q||fS )z1
        Prepare inputs, outputs, attrs.
        NFT)valuer   ZpersistableplaceZ	zero_copy)r.   r  r;   r'   r(   r)   r   _current_expected_placer-   r   Zndarrayr   r0   r1   rs   rX   r   stop_gradientr.  Z_equalsZ_copy_tor2   )	r!   r   Zflatten_inputsZ
input_varsZinput_var_namesZexpected_placer   r-  r5   r$   r$   r%   r     s4   z#PartialProgramLayer._prepare_inputsc                 C   s   t jj| jS r   )r'   r   r   Z#create_empty_tensors_with_var_descsr   r*   r$   r$   r%   r   !  s   z$PartialProgramLayer._prepare_outputsc                 C   s   | j ||d}|gS )Nr   )r   )r!   r   r   Zinner_scoper$   r$   r%   r   &  s   z%PartialProgramLayer._create_scope_vecc                 C   s*   t jt jjjg dt jjjd}d|_|S )NZ
cuda_graphT)r   r0   r1   r   r   ZFP32ZRAWr0  r!   r5   r$   r$   r%   r{   ,  s   z*PartialProgramLayer._create_cuda_graph_vecc                    s2    fdd}t  jj|D ]	\}}||| qd S )Nc                    s&    j |  }t|tjsJ |j|_d S r   )rn   r.   r   r/   r0  )rp   Zeager_tensorr5   r*   r$   r%   set_stop_gradient9  s   
zDPartialProgramLayer._update_stop_gradient.<locals>.set_stop_gradient)ziprn   r3   )r!   r   r2  r4   r5   r$   r*   r%   _update_stop_gradient7  s   z)PartialProgramLayer._update_stop_gradientc                 C   sX   | j  }t| j jD ]
\}}|| ||< q| j |}|dur*t|dkr*|d }|S )zZ
        Restores same nested outputs by only replacing the Variable with Tensor.
        Nr   r   )rn   r   r-   r3   r,   r+   )r!   r   Zflatten_outputsr   r4   Zoutsr$   r$   r%   r   A  s   
z PartialProgramLayer._restore_outc                 C   s   |j ddS )NTr   )r   r   r$   r$   r%   _clone_for_testO  s   z#PartialProgramLayer._clone_for_testc                 C   s2   t |tjjr|jdgkr| d tkrdS dS )Nr   r   TF)r.   r   r0   r1   r  numpyr   r1  r$   r$   r%   _is_no_valueS  s   z PartialProgramLayer._is_no_valuec                    s   t |tjjr |rdS |S t |ttfrQt |tr(t fdd|D }n	 fdd|D }t|t|k}t|dkrC|rCdS t|dkrO|rO|d S |S |S )zK
        Removes invalid value for various-length return statement
        Nc                 3   s    | ]
}  |s|V  qd S r   r7  r  r*   r$   r%   r  d  s    

z7PartialProgramLayer._remove_no_value.<locals>.<genexpr>c                    s   g | ]	}  |s|qS r$   r8  r  r*   r$   r%   rq   i  s    z8PartialProgramLayer._remove_no_value.<locals>.<listcomp>r   r   )r.   r   r0   r1   r7  r  r;   r+   )r!   r   resZhas_removedr$   r*   r%   r   Z  s"   


z$PartialProgramLayer._remove_no_valuec                 C   sJ   |D ] }|j t  }|jd| }|d u rq||  qd S r   )	r   r   Zgrad_var_suffixrX   rY   Zfind_varencoder   r8   )r!   paramsr   r  r  Zgrad_varr$   r$   r%   r   v  s   z"PartialProgramLayer._set_grad_typec                 C   s   t | jttfstdt| j t }t| jD ]\}}t |tj	j
s.td|t|||j q|jD ]}|j D ]\}}t |tjrS||vrStd| q?q8dS )a  
        Check all params from main program are already initialized, see details as follows:
            1. all parameters in self._params should be type `framework.EagerParamBase` which are created in dygraph.
            2. all parameters from transformed program can be found in self._params.
               Because they share same data with EagerParamBase of original dygraph.
        zUType of self._params in PartialProgramLayer should be list or tuple, but received %s.zaType of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.aq  
	We don't support to define layer with parameters in the function decorated by `@to_static`.
	But we found parameter(%s) was created in the decorated function.

	Revise suggestion: 
		1. Please ensure all your sublayers are inheritted from nn.Layer.
		2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as ListN)r.   rt   r;   r  	TypeErrorr8   r6   r-   r   r0   r1   r:   r7   r   r  r)  r*  r   	Parameter
ValueError)r!   r   Zparam_and_buffer_names_setr   r5   rY   r   r$   r$   r%   r     s6   
z,PartialProgramLayer._check_params_all_initedc                 C   s   |r|S d S r   r$   )r!   r)  r$   r$   r%   r     s   zPartialProgramLayer._valid_varsr   )NFr?   )Cr@   rA   rB   rC   r&   r[   r   r   r   r   r   r   r   r   r   r   r   r   rF   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rD   r   r   r   rd   r^   r   rx   r  r   r   r   r   r   r   r$  r(  r"  r   r   r   r{   r4  r   r5  r7  r   r   r   r   __classcell__r$   r$   r   r%   rf      s    ;" 

	
	
	














	

	




P
6
3

72

'


&rf   Fc                 C   s<   | j }|r|r|dd  }t| j|| j| j| jfi | jS r   )r   rf   r   r   r   r   r   )Zconcrete_programZfrom_methodr   r$   r$   r%   partial_program_from  s   r@  c                 C   s   ||k r=t jjt| j|||d}|r|jdt| |t	 t
  t
|j}| }t|jdr<|jj|_nt j }| dj D ]}|d|d qJt| j|jD ]\}	}
|
j|	jj q]|S )N)rg   Zskip_gc_varsr   r   F)r'   r   ZCompiledProgramr   ZGraphrX   Z_graphr6   _compiler   r   r/  ZIrGraphZ
to_programr   Z_programr   r   rY   r)  valuesZ_clone_variabler3  r  Zset_parent_idxparent)r   Zstart_op_indexZend_op_indexrg   r+  Zcompiled_programZir_graphZbuilded_programr5   origincurrentr$   r$   r%   r#    s,   
r#  r?   r   ).r   copyr   r6  r   r'   r   Zpaddle.amp.auto_castr   r   Zpaddle.baser   r   r   r	   Zpaddle.base.compilerr
   Zpaddle.base.data_feederr   r   Zpaddle.base.dygraph.baser   Zpaddle.base.frameworkr   r   Zpaddle.base.unique_namer   Zpaddle.optimizer.lrr   rh   r   Zexport_subgraphr   r   r(   r   r   r   __all__r   rF   rN   r\   rf   r@  r#  r$   r$   r$   r%   <module>   sD   ;        
