o
    "j                     @   s   d dl Z d dlZd dlmZ d dlmZ d dlmZ d dlm	Z	m
Z
 ejjZG dd dZG dd	 d	eZG d
d deZG dd deZG dd deZdS )    N)unique_name)wait_server_ready)core)default_main_programdefault_startup_programc                   @   sd   e Zd ZdZdd Zdd Zdd Zdd	 Z	
dddZdd Z	dd Z
dd Zdd Zdd ZdS )
Collective c                 C   sN   || _ d | _d | _d | _d | _d | _d | _d | _tj	}|
 | _| | _d S N)nrings	endpointscurrent_endpointother_endpointsnranksrankstartup_programmain_programr   op_proto_and_checker_makerZkOpRoleAttrNameop_role_keyZkOpRoleVarAttrNameop_role_var_key)selfr
   Zop_maker r   i/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/paddle/distributed/transpiler/collective.py__init__   s   
zCollective.__init__c           	      C   s  t |tr
|d}|| _|d u rt | _|| _|d u r t | _t|| _| jdkr8| j	dkr8| j	dkr8t
d|dk r@t
d|| _||vrOt
d|t||| _|| _|rit|}|d d  }|| || _|| _| j | j_|   | j | j_|   d S )	N,   single_process_multi_threadboxz the number of endpoints must > 1r   zrank must >= 0z current endpoint %s is not in %s)
isinstancestrsplitr   r   r   r   lenr   mode
ValueErrorr   r   r   remover   	wait_portcloneZ_origin_program_transpile_startup_program_transpile_main_program)	r   r   r   r   r   r   r$   r   r   r   r   r   	transpile,   sD   







zCollective.transpilec                 C   s   t d)Nz'call the inherited method of subclasses)NotImplementedErrorr   r   r   r   r'   e   s   z"Collective._transpile_main_programc              	   C   s:   t | jD ]}| | j| j| j| j|| j q|   d S r	   )	ranger
   _init_communicatorr   r   r   r   r$   _broadcast_params)r   ring_idr   r   r   r&   h   s   z%Collective._transpile_startup_programFc                 C   s  d |}t|}	|d d  }
|
| | }|dkr"|r"t|
 | }t r}|jt	ddtj
jjd}|jdi d|id|d	|d
|
| jtjid |sf|jdd|ii d|	d|d|| jtjid d S |jdd|ii d|	d|d|| jtjid d S t r|jt	ddtj
jjd}|jdi d|id|d	|d
|
| jtjid |jdd|ii d|	d|d|d|| jtjid d S tj jtj v r|jt	ddtj
jjd}|jdi d|id|d	|d
|
| jtjid |jdd|ii d|	d|d|d|| jtjid d S d S )Nr   r   Znccl_idT)namepersistabletypeZc_gen_nccl_idOutr   Zendpointr   r1   inputsoutputsattrsZc_comm_initXr   r.   Zc_comm_init_multitrainerZ	ntrainersZ
trainer_idZbkcl_idZc_gen_bkcl_idr   Zxccl_idZc_gen_xccl_id)joinr    r#   global_blockr   r   Zis_compiled_with_cuda
create_varr   generateVarDescVarTypeZRAW	append_opr   OpRoleForwardZis_compiled_with_xpupaddledistributedZParallelEnvZdevice_typeZdeviceZget_all_custom_device_type)r   programr   r   r   r.   r$   Zhas_multitrainerZendpoints_strr   r   blockZnccl_id_varZbkcl_id_varZxccl_id_varr   r   r   r,   t   s   







zCollective._init_communicatorc                 C   s   | j  }d}| D ]"}|jrq|d | j }|jdd|id|id|dd| jtjid	 qt	| jD ]}|jd
d|id|id|| jtjid	 q3d S )Nr   Zc_broadcastr7   r2   r.   rootr   r3   c_sync_comm_stream)
r   r9   iter_parametersis_distributedr
   r>   r   r?   r@   r+   )r   rD   r.   paramr   r   r   r-      s.   
zCollective._broadcast_paramsc                 C   s>   | j |jvrdS t| | j  }|ttj@ o|ttj@ S )NF)r   
attr_namesint	all_attrsr?   BackwardZLoss)r   opZop_roler   r   r   _is_loss_grad_op  s   zCollective._is_loss_grad_opc                 C   (   | j |jv ot| | j  ttj@ S r	   )r   rK   rL   rM   r?   rN   r   rO   r   r   r   _is_backward_op  
   zCollective._is_backward_opc                 C   s   d|j v od|j v od|j v S )NParamGradLearningRate)Zinput_namesrR   r   r   r   _is_update_op  s
   
zCollective._is_update_opc                 C   rQ   r	   )r   rK   rL   rM   r?   OptimizerR   r   r   r   _is_optimizer_op  rT   zCollective._is_optimizer_opN)F)__name__
__module____qualname____doc__r   r(   r'   r&   r,   r-   rP   rS   rX   rZ   r   r   r   r   r      s    9
|r   c                   @   2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )GradAllReducer      c                 C   s   t | | d| _d S )NZgrad_allreduce)r   r   r!   r   r
   r   r   r   r   &  s   
zGradAllReduce.__init__c                 C   s   |    |   d S r	   )_insert_scale_loss_grad_ops_insert_allreduce_opsr*   r   r   r   r'   *  s   z%GradAllReduce._transpile_main_programc              
   C   sv   | j  }ttt|jD ]*\}}| |r8|j|jd  }|j	|d dd|id|idd| j
 | jtjid qdS )	
        In order to keep the learning rate consistent in different numbers of
        training workers, we scale the loss grad by the number of workers
        r   r   scaler7   r2         ?r3   N)r   r9   reversedlist	enumerateopsrP   varsoutput_arg_names
_insert_opr   r   r?   rN   )r   rD   idxrO   loss_grad_varr   r   r   rc   .  s   


z)GradAllReduce._insert_scale_loss_grad_opsc           
      C   s  | j  }d}d }ttt|jD ]\}}| |r| j|jv r|	 | j }t
|dkr/qt
|d dks9J |}tdt
|dD ]O}|j||  }	|j||d   }|	jrYqC||krw|d7 }|j|dd|id|i| jtjid |d7 }|d | j }|j|d	d|id|id
|| jtjid qCq|d u rd S t|jD ]*\}}| |rt| jD ]}|j|| dd|id|id
|| jtjid q d S qd S )NrE   r   ra   r   c_sync_calc_streamr7   r2   r3   c_allreduce_sumr.   rG   )r   r9   rh   ri   rj   rk   rS   r   rK   rM   r    r+   rl   rI   rn   r   r?   rN   r
   rZ   )
r   rD   r.   gradro   rO   op_role_varoffsetirJ   r   r   r   rd   B  sn   



z#GradAllReduce._insert_allreduce_opsNra   )r[   r\   r]   r^   r   r'   rc   rd   r   r   r   r   r`   #  s    
r`   c                   @   r_   )LocalSGDr   ra   c                 C   s   t | | d| _d| _d S )Nz	@SNAPSHOTZ	local_sgd)r   r   snapshot_keyr!   rb   r   r   r   r     s   
zLocalSGD.__init__c                 C   s   t |  | j }g }| D ]
}|js|| q|D ]#}|j| |j	|j
ddd}|jdd|gid|gi| jtjid qd S )NT)r/   shaper0   stop_gradientassignr7   r2   r3   )r   r&   r   r9   rH   rI   appendr:   snapshot_namer/   rz   r>   r   r?   r@   )r   rD   Znon_dist_paramsrJ   snapshotr   r   r   r&     s*   




z#LocalSGD._transpile_startup_programc                 C   s
   || j  S r	   )ry   )r   
param_namer   r   r   r~     s   
zLocalSGD.snapshot_namec           	   
   C   s  | j  }g }d}ttt|jD ]y\}}| |r|j|dd  }|j	r)q|j
| |j|jdd|jd}|j|d d|g|gdd	|gi| jtjid
 |j|d dd|id	|i| jtjid
 |d | j }|j|d dd|gid	|gid|| jtjid
 |||f qt| jD ]}|jdd|id	|id|| jtjid
 qt|D ]J}|d }|d }|jdd|gid	|gidd| j | jtjid
 |jd|g|gdd	|gi| jtjid
 |jdd|gid	|gi| jtjid
 qd S )NrE   rU   r   T)r/   rz   r0   r{   dtyper   Zelementwise_sub)r7   Yr2   r3   ra   rq   r7      rr   r.   rG   rf   rg   r|   )r   r9   rh   ri   rj   rk   rX   rl   inputrI   r:   r~   r/   rz   r   rn   r   r?   rY   r
   r}   r+   r>   r   )	r   rD   Zordered_param_snapshotr.   ro   rO   rJ   r   Zparam_snapshotr   r   r   r'     s   





	

z LocalSGD._transpile_main_programNrw   )r[   r\   r]   r^   r   r&   r~   r'   r   r   r   r   rx     s    
rx   c                   @   s@   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dS )SingleProcessMultiThreadz*
    single process multi thread mode
    c                 C   sR   t | d d| _ttdd| _ttdd| _ttdd	d| _
d S )	Nr   r   ZPADDLE_FUSE_ALLREDUCE1ZPADDLE_LOSS_SCALEFLAGS_selected_gpusz0,1,2,3,4,5,6,7r   )r`   r   r!   rL   osgetenvfuse_allreduce
loss_scaler    r   gpu_numsr*   r   r   r   r     s   
z!SingleProcessMultiThread.__init__c              
   C   s   d}t | jdkrt dd | jD }|dkrN|| _td td| j td| j td| j| jf  t| jD ]}| | j	| j| j| j|| j
d	 q9d S d| _td
 | j	 }|jdddid d S )Nr   r   c                 S   s   h | ]	}| d d qS ):r   )r   ).0xr   r   r   	<setcomp>  s    zFSingleProcessMultiThread._transpile_startup_program.<locals>.<setcomp>2begin to _transpile_startup_program for multi-nodecurrent_endpoint: total endpoints: rank: %d, ring_id: %dT3begin to _transpile_startup_program for single-nodec_comm_init_allr.   r1   r6   )r    r   r   printr   r   r
   r+   r,   r   r$   r9   r>   )r   Z	nodes_numr.   rD   r   r   r   r&     s0   
z3SingleProcessMultiThread._transpile_startup_programc                 C   sh   |   }| jdkr|dkrd S | jr| | |dkrd S | jdkr.td|  |   d S |   d S )Nr   z*begin used fuse_allreduce param count = %s)_get_update_param_countr   rc   r   r   _insert_fuse_allreduce_opsrd   )r   	param_cntr   r   r   r'     s   

z0SingleProcessMultiThread._transpile_main_programc                 C   s   d}| j  }ttt|jD ]C\}}| |sq| j|jvr!q|	 | j }t
|dkr/qt
|d dks9J tdt
|dD ]}|j||  }|jrNqA|d }qAq|S )z-
        get need update param count
        r   ra   r   )r   r9   rh   ri   rj   rk   rS   r   rK   rM   r    r+   rl   rI   )r   Zparam_countrD   ro   rO   rt   rv   rJ   r   r   r   r   2  s$   


z0SingleProcessMultiThread._get_update_param_countc              
   C   s   |dkrd| j  | j }nd| j }td|  | j }ttt|jD ](\}}| 	|s0q&|j
|jd  }|j|d dd|id|id|| jtjid q&d	S )
re   r   rg   z,begin _insert_scale_loss_grad_ops scale = %sr   rf   r7   r2   r3   N)r   r   r   r   r9   rh   ri   rj   rk   rP   rl   rm   rn   r   r?   rN   )r   r   rf   rD   ro   rO   rp   r   r   r   rc   J  s"   


z4SingleProcessMultiThread._insert_scale_loss_grad_opsc              	   C   s  | j  }d}d}g }d}ttt|jD ]X\}}| |rn| j|jv rn|	 | j }t
|dkr3qt
|d dks=J |}	tdt
|dD ]&}
|j||
  }|j||
d   }|jr]qG|	|krm|| t||	d }qGq|du rudS | jdkr|j|dd|d id|d i| jtjid	 |d7 }|d | j }|j|d
d|id|id|| jtjid	 |d7 }|j|dd|d id|d id|| jtjid	 |d7 }dS |}|jddgdtjjjdd}ddtjjjd}|j|dd|i||d|d	 |d7 }|j|dd|id|i| jtjid	 |d7 }|d | j }|j|dd|id|id|| jtjid	 |d7 }|j|dd|id|id|| jtjid	 |d7 }dS );
        insert coalesce_tensor and all reduce ops
        rE   Nr   ra   r   rq   r7   r2   r3   Zc_allreduce_xsumr.   rG   fused_outputFTr/   rz   r0   r   r{   )	copy_dataZset_constantr   coalesce_tensorInputOutputZFusedOutputrr   )r   r9   rh   ri   rj   rk   rS   r   rK   rM   r    r+   rl   rI   r}   maxr   rn   r   r?   rN   r
   r:   r   r<   r=   FP32)r   rD   r.   rs   Zinput_gradsZglobal_offsetro   rO   rt   ru   rv   rJ   Zoutput_gradsr   Zcoalesce_tensor_attrsr   r   r   r   a  s   







	
z3SingleProcessMultiThread._insert_fuse_allreduce_opsN)
r[   r\   r]   r^   r   r&   r'   r   rc   r   r   r   r   r   r     s    	r   c                   @   sB   e Zd ZdZdddZdd Zdd	 Zd
d Zdd Zdd Z	dS )MultiThreadr   r   
all_reducec                 C   s>   t | | d| _|| _d| _tddd}t|| _	d S )Nr      r   z0,1,2,3,4,5,6,7,8r   )
r`   r   r!   
trans_modefuse_grad_size_in_numr   r   r   r    gpu_num)r   r
   r   r   r   r   r   r     s   zMultiThread.__init__c              
   C   s   t | jdkr;td td| j td| j td| j| jf  t| jD ]}| | j| j| j| j|| j	d q&d S d| j
v r`td | j }|jd	ttttd
dddd d S td | j }|jd	ddid d S )Nr   r   r   r   r   TZxpuz:begin to _transpile_startup_program for single-node in XPUr   r   r   r   )Zdevicesr.   r   r   r.   )r    r   r   r   r   r
   r+   r,   r   r$   r   r9   r>   ri   maprL   r   r   r   )r   r.   rD   r   r   r   r&     sD   



z&MultiThread._transpile_startup_programc                 C   s   |    | jdkrtd | j| j | _|   |   d S | jdkr-td |   d S | jdkrDt	t
dddkrDtd	 d S td
 |   d S )NZ
all_gatherz%begin to transpile in all-gather modeZfuse_all_reducez*begin to transpile in fuse all-reduce modeZall_reduce_xpur   r   r   zHskip transpile in all-reduce-xpu mode when number of devices is only onez%begin to transpile in all-reduce mode)rc   r   r   r   r   allgather_ranks_insert_allgather_ops_update_adam_opsr   r    r   r   r   rd   r*   r   r   r   r'     s    


z#MultiThread._transpile_main_programc                 C   s  | j  }d}d}ttt|jD ]\}}| |r| j|jv r|	 | j }t
|dkr/qt
|d dks9J |}tdt
|dD ]j}|j||  }	|j|| d | jgt|	j dtjjjdd}
|j||d	   }|	jrqqC||kr|d	7 }|j|d
d|id|i| jtjid |d	7 }|d	 | j }|j|dd|id|
id| jd|| jtjid qCq|du rdS t|jD ]*\}}| |rt| jD ]}|j|| dd|id|id|| jtjid q dS qdS )z9
        insert allgather op to the main_program
        rE   Nr   ra   
_allgatherFTr   r   rq   r7   r2   r3   Zc_allgatherr   r.   rG   )r   r9   rh   ri   rj   rk   rS   r   rK   rM   r    r+   rl   r:   r   rz   r   r<   r=   r   rI   rn   r   r?   rN   r
   rZ   )r   rD   r.   rs   ro   rO   rt   ru   rv   rJ   Znew_grad_varr   r   r   r     s~   




z!MultiThread._insert_allgather_opsc              
      s  | j   ttt jD ]\}| r|}jdkr$jdkr$qdd  j	dd   j	dd   j	dd   j	dd   j	dd   j	d	d  d
} j	
dd   j	
dd   j	
dd   j	
dd   j	
dd  d}dddddd} fddt| jD } j|dd j	dd d  id|i| jddd |d7 }t| jD ]}|| |d <  j|j|||d |d7 }qވ | qd!S )"zC
        remove the original adam op, and add new adam ops
        ZadamZlambrU   r   rW   Moment1Moment2Beta1PowBeta2Pow)rU   rW   r   r   r   r   ParamOut
Moment1Out
Moment2OutBeta1PowOutBeta2PowOut)r   r   r   r   r   epsilonbeta1beta2	lazy_modemin_row_size_to_use_multithread)r   r   r   r   r   c              	      sD   g | ]} j d  t|  jdd  jdtjjjddqS )_rU   r   FTr   )	r:   r   rl   r   rz   r   r<   r=   r   )r   rv   rD   rO   r   r   r   
<listcomp>  s    z0MultiThread._update_adam_ops.<locals>.<listcomp>r   r7   r   r2   )numZaxisr3   r   rV   N)r   r9   rh   ri   rj   rk   rZ   r1   r   rl   outputattrr+   r   rn   Z
_remove_op)r   ro   ru   r4   r5   r6   Z
split_varsrv   r   r   r   r   f  sj   

		

	

zMultiThread._update_adam_opsc                 C   s  | j  }d| j }d}g }t|jD ]O}| |rb| j|jv rb| | j }t	|dkr.qt	|d dks:J dt
dt	|dD ]}|| }||}	||d  }
||
}|	jr\qB|| qBq|du ridS g }d}|D ]'}t	|dkst	|d | jks|j|kr||g |j}qo|d | qog }t|jD ]F\}}| |r|D ]8}|jtd|d j |d jdd	d
}|| |j|dd|i||ddd	dd	d|d j| jtjid q nqt|jD ]9\}}| |r#|D ]*}|j|dd|id|id|dd| jtjid |j|dd|id|i| jtjid q nqt	|dkr1|  dS t|jD ]%\}}| |rZ|j|dd|d id|d id|| jtjid  nq6|  dS )r   r   Nra   zRvars need to be one param var followed by one grad var, but got odd number of varsr   rE   ZFusedOutput_FT)r/   r   r0   r{   r   r   r   r   Z	use_alignr   r3   rr   r7   r2   r.   Zuse_calc_streamrq   rG   )r   r9   r
   rh   rk   rS   r   rK   rM   r    r+   varrI   r}   r   r   rj   rZ   r:   r   r;   r/   rn   r   r?   rN   Z_sync_with_cpp)r   rD   r.   rs   Zparam_gradsrO   rt   rv   r   rJ   Z	grad_namesegmentsZ
last_dtyper   Z
fused_varsro   segmentZtmp_varZ	fused_varr   r   r   r     s   











z&MultiThread._insert_fuse_allreduce_opsN)r   r   )
r[   r\   r]   r^   r   r&   r'   r   r   r   r   r   r   r   r     s    

'KHr   )r   rA   Zpaddle.baser   Z5paddle.distributed.fleet.base.private_helper_functionr   Zpaddle.frameworkr   Zpaddle.staticr   r   r   r?   r   r`   rx   r   r   r   r   r   r   <module>   s     	ar ]