o
    "j                     @   s   d Z ddlZddlZddlZddlZddlmZmZ ddlZ	ddl
Z
ddlmZ ddlmZ ddlmZ g ZG dd	 d	ZG d
d dZG dd dZG dd deZG dd deZdS )zDefinition of Role Makers.    N)ManagerProcess)core)wait_server_ready   )getenv_or_backupc                   @   s    e Zd ZdZdZdZdZdZdS )Role      r         N)__name__
__module____qualname__WORKERSERVERHETER_WORKERALLCOORDINATOR r   r   i/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/paddle/distributed/fleet/base/role_maker.pyr   "   s    r   c                   @   s   e Zd ZdZG dd dZdd Z		d dd	Zd
d Zdd Zdd Z	dd Z
dd Zdd Zdd Zdd Zd!ddZd"ddZdS )#GloozL
    Gloo is a universal class for barrier and collective communication
    c                   @   s   e Zd ZdZdZdZdS )zGloo.RENDEZVOUSr	   r
   r   N)r   r   r   HDFSFILEHTTPr   r   r   r   
RENDEZVOUS/   s    r   c                 C   sv   d | _ d | _d | _g d| _d| _d| _d| j | _d| _d| _d| _	d | _
d | _d | _d| _d| _d| _d| _d S )	N)workerserverallz?gloo is not initialized, will not communicator with other nodesz.gloo initialized error, please check argumentsz#argument error, comm_world must in Fi  i )_worker_comm_server_comm_nodes_comm_comm_world	_err_init	_err_type
_err_world_is_initialized_init_timeout_seconds_run_timeout_seconds_rendezvous_role_iface_role_id_worker_num_server_num_need_init_allselfr   r   r   __init__4   s&   


zGloo.__init__FNc                 C   sN  || _ || _|| _|| _|| _|| _d| _|dd| _d }| j t	j
jkrL|dd}	|dd}
|dd}|	r=|
r=|sBt| j| |	|
|| j nS| j t	j
jkrh|dd}|s`t| j| || j n7| j t	j
jkr|dd}|dd}|dd	}|d
}|r|st| j| ||| j||}nt| jd| _|| _d S )N store.prefixdfs.namedfs.ugidfs.path	http.host	http.portstart_http_serverFhttp_server_dT)r*   r+   r-   r.   r/   r0   r,   get_prefixr   r   r   
ValueErrorr%   	_init_dfsr   _init_fsr   
_init_httpr'   _http_server)r2   
rendezvousrolerole_id
worker_num
server_numneed_init_allkwargshttp_serverdfs_namedfs_ugidfs_pathfs_pathipportr;   r<   r   r   r   initO   sD   






z	Gloo.initc                    s    fdd}j tjkr tj\}}|||d}|_ntj\}}|||d}|_jrGtj\}}|||d}|_	d S d S )Nc                    sf   t  }||  || | |j |jj	 |
tj |dd |  |S )Nr4   r   r   set_rankset_size
set_prefix	set_ifacer,   set_timeout_secondsr(   r)   Zset_hdfs_storeospathjoinrR   ranknodesrE   gloorO   prefixr2   r   r   rR         


zGloo._init_fs.<locals>.initr   r   r   
r+   r   r   _get_rank_nodesr    r   r!   r0   r   r"   )r2   rO   ra   rR   r]   r^   r_   r   r`   r   rA      s   
zGloo._init_fsc           	         s    fdd}j tjkr"tj\}}|||d}|_ntj\}}|||d}|_jrItj\}}|||d}|_	d S d S )Nc                    sf   t  }||  || | |j |jj	 |
tj|  |  |S NrS   r\   rL   rN   rM   ra   r2   r   r   rR      rb   zGloo._init_dfs.<locals>.initr   r   r   rc   )	r2   rL   rM   rN   ra   rR   r]   r^   r_   r   rf   r   r@      s   
zGloo._init_dfsc                    s   fdd  fdd}fdd}t |r%td ||}jtjkr<tj\}	}
||	|
d}|_|rHd	|d
< |  d S d S )Nc                    st   t d  d|  ddlm} | |}|  d}| dds%| s4t| | dds%| r%|  d S )Nzstart http_server: z, r   )KVServerr   runningF)	printZ*paddle.distributed.fleet.utils.http_serverrg   startr=   Zshould_stoptimesleepstop)r<   size_drg   rK   Zwait_seconds)rQ   r   r   Z__start_kv_server   s   



z*Gloo._init_http.<locals>.__start_kv_serverc                    sT   d d }|j i}td| d|  d| d< t | |fd}d|_|  |S )N_r   zworker_key:z, size: Trh   )targetargs)r.   ri   r   daemonrj   )r<   Z
worker_keyrn   rC   )_Gloo__start_kv_serverra   r2   r   r   init_kv_server   s   z'Gloo._init_http.<locals>.init_kv_serverc                    sx   t  }||  || | |j |jj	 |
 d d tg}t|g |  |S )Nr   :)r   r   rT   rU   rV   rW   r,   rX   r(   r)   Zset_http_storer[   strr   rR   )r]   r^   rE   r_   ep)rP   rQ   ra   r2   r   r   rR      s   



zGloo._init_http.<locals>.initzto start http_serverr   Frh   )intri   r+   r   r   rd   r    r[   )r2   rP   rQ   ra   r;   r<   rt   rR   rK   r]   r^   r_   r   )rs   rP   rQ   ra   r2   r   rB      s   zGloo._init_httpc                 C   s   d}d}|t jkr| j}| j}||fS |t jkr"| j}| j}||fS |t jkrD| j| j }| jt jkr:| j}||fS | j| j }||fS t| j	 ||fS )Nr   r   )
r   r   r.   r-   r   r/   r   r+   r?   r%   )r2   rE   r^   r]   r   r   r   rd      s&   



zGloo._get_rank_nodesc                 C   s    |   }|  }|dkr|S |S )0
        get default physical interface
        lo)%_Gloo__get_default_iface_from_gateway(_Gloo__get_default_iface_from_interfaces)r2   Zdefault1Zdefault2r   r   r   Z__get_default_iface  s   zGloo.__get_default_ifacec                 C   s   t d  d}d}d}|D ]C}| }d|v r+d|v r+|d}|d}q|durU|durUd}t||kr?|| }|rU|dkrU|dkrUt||krU||   S qdS )	ry   zroute -A inet
NZGatewayZIface*z0.0.0.0rz   )rY   popenreadstripsplitindexlen)r2   resZgateway_idxZ	iface_idxitemZgatewayr   r   r   Z __get_default_iface_from_gateway  s(   
z%Gloo.__get_default_iface_from_gatewayc                 C   sD   t d  d}|D ]}d|v r|dd    S qdS )ry   zip -f inet addr | awk NR%3==1r}   Z	BROADCASTru   r	   rz   )rY   r   r   r   r   )r2   r   r   r   r   r   Z#__get_default_iface_from_interfaces8  s   z(Gloo.__get_default_iface_from_interfacesc                 C   sd   | j st| j dS || jvrt| j|dkr | j  dS |dkr+| j	  dS | j
  dS )z+
        dummy barrier, do nothing
        Nr   r   )r'   warningswarnr$   r#   r?   r&   r    barrierr!   r"   r2   
comm_worldr   r   r   r   D  s   

zGloo.barriersumr   c                 C   s   | j st| j |S || jvrt| jt|}|j	}|
d }| | |dkr5| j||}n|dkrA| j||}n| j||}t|
|}|S )Nr   r   r   )r'   r   r   r$   r#   r?   r&   nparrayshapeZreshapetolistr   r    
all_reducer!   r"   )r2   inputmoder   Zinput_shapeZ
input_listZansoutputr   r   r   r   V  s    



zGloo.all_reducec                 C   sj   | j st| j |S || jvrt| j|dkr!| j|}|S |dkr-| j	|}|S | j
|}|S )zg
        dummy all gather, do nothing
        Args:
            obj(any): obj to do all gather
        r   r   )r'   r   r   r$   r#   r?   r&   r    
all_gatherr!   r"   )r2   r   r   r   r   r   r   r   n  s   

zGloo.all_gather)FNr   r   r   )r   r   r   __doc__r   r3   rR   rA   r@   rB   rd   Z_Gloo__get_default_ifacer{   r|   r   r   r   r   r   r   r   r   *   s"    "
5D
r   c                   @   s   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zd%ddZd&d d!Zd"d# Zd$S )'RoleMakerBasez
    RoleMakerBase is a base class for assigning a role to current process
    in distributed training.
    A paddle developer can implement RoleMakerBase to design a role maker
    for worker or pserver assignment.
    c                 C   s(   g | _ g | _d| _d| _d | _d| _d S )Nr4   Fr   )_worker_endpoints_server_endpoints_cur_endpoint_role_is_generatedr+   _current_idr1   r   r   r   r3     s   
zRoleMakerBase.__init__c                 C      t d)z7
        return is_worker() of current process
        +Please implement this method in child classNotImplementedErrorr1   r   r   r   
_is_worker     zRoleMakerBase._is_workerc                 C   r   )z7
        return is_server() of current process
        r   r   r1   r   r   r   
_is_server  r   zRoleMakerBase._is_serverc                 C   r   )z
        Check whether the node is the first instance of worker.
        Returns:
            bool: True if this is the first node of worker,
                  False if not.
        r   r   r1   r   r   r   _is_first_worker     zRoleMakerBase._is_first_workerc                 C   r   )zc
        Get current total worker number.

        Returns:
            int: worker number
        r   r   r1   r   r   r   r.     r   zRoleMakerBase._worker_numc                 C   r   )zc
        Get current total server number.

        Returns:
            int: server number
        r   r   r1   r   r   r   r/     r   zRoleMakerBase._server_numc                 C   r   )zS
        Get current worker id.

        Returns:
            int: node id
        r   r   r1   r   r   r   _worker_index  r   zRoleMakerBase._worker_indexc                 C   r   )zS
        Get current server id.

        Returns:
            int: node id
        r   r   r1   r   r   r   _server_index  r   zRoleMakerBase._server_indexc                 C   r   )zL
        Get current id.

        Returns:
            int: node id
        r   r   r1   r   r   r   r-     r   zRoleMakerBase._role_idc                 C   r   )zY
        Get the training node number
        Returns:
            int: node num
        r   r   r1   r   r   r   	_node_num  s   zRoleMakerBase._node_numc                 C      | j S )z*
        return trainer endpoints
        )r   r1   r   r   r   _get_trainer_endpoints     z$RoleMakerBase._get_trainer_endpointsc                 C   r   )z*
        return pserver endpoints
        )r   r1   r   r   r   _get_pserver_endpoints  r   z$RoleMakerBase._get_pserver_endpointsc                 C   s   d | j| j| j| jS )NzDrole: {}, current_id: {}, worker_endpoints: {}, server_endpoints: {})formatr+   r   r   r   r1   r   r   r   	to_string  s   zRoleMakerBase.to_stringr   c                 C   s   t d d S )Nz7warning: RoleMakerBase does not have all gather worker.ri   r2   r   r   r   r   r   _all_gather  s   zRoleMakerBase._all_gatherr   c                 C      t d dS )z
        Args:
            input(list/numpy.array): array of one dim
            output(list/numpy.array): array of one dim
            mode(str): "sum" or "min" or "max"
        z7warning: RoleMakerBase does not have all reduce worker.Nr   r2   r   r   r   r   r   r   _all_reduce  s   zRoleMakerBase._all_reducec                 C   r   )zE
        barrier between trainers if current role is TRAINER
        z4warning: RoleMakerBase does not have barrier worker.Nr   r   r   r   r   _barrier  s   zRoleMakerBase._barrierNr   r   )r   r   r   r   r3   r   r   r   r.   r/   r   r   r-   r   r   r   r   r   r   r   r   r   r   r   r     s$    						

	r   c                       sN  e Zd ZdZdQ fdd	Zdd ZdRdd	ZdSddZdd Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd  Zd!d" Zd#d$ Zd%d& Zd'd( Zd)d* Zd+d, Zd-d. Zd/d0 Zd1d2 Zd3d4 Zd5d6 Zd7d8 Zd9d: Zd;d< Zd=d> Z d?d@ Z!dAdB Z"dCdD Z#dEdF Z$dGdH Z%dIdJ Z&dKdL Z'dMdN Z(dOdP Z)  Z*S )TPaddleCloudRoleMakera  
    PaddleCloudRoleMaker is an interface for distributed configuration initialization based on obtaining distributed related information from environment variables.

    Examples:
        .. code-block:: python

            >>> import os
            >>> import paddle.distributed.fleet as fleet

            >>> os.environ["PADDLE_PSERVER_NUMS"] = "2"
            >>> os.environ["PADDLE_TRAINERS_NUM"] = "2"

            >>> os.environ["POD_IP"] = "127.0.0.1"
            >>> os.environ["PADDLE_PORT"] = "36001"
            >>> os.environ["TRAINING_ROLE"] = "PSERVER"
            >>> os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"

            >>> os.environ["PADDLE_TRAINER_ID"] = "0"

            >>> fleet.PaddleCloudRoleMaker(is_collective=False)

    Fc                    s|   t    || _d| _|| _d| _d| _d| _g | _g | _	g | _
d| _d| _d| _g | _g | _g | _d | _d| _t | _d S )NFr	   cpu)superr3   _is_collective_non_distributed_kwargsr   	_stage_id
_stage_num_next_heter_trainer_endpoints!_previous_heter_trainer_endpoints_heter_trainer_endpoints_heter_trainer_device_heter_trainer_device_type_is_heter_parameter_server_mode_stage_trainersr   r   _coordinator_endpoints_with_coordinatorr   _gloo)r2   is_collectiverJ   	__class__r   r   r3   <  s&   
zPaddleCloudRoleMaker.__init__c                 C   s   | j | d S re   )r   r   r   r   r   r   r   V     zPaddleCloudRoleMaker._barrierr   c                 C   s   | j ||S re   )r   r   r   r   r   r   r   Y  s   z PaddleCloudRoleMaker._all_gatherr   c                 C   s   | j |||S re   )r   r   r   r   r   r   r   \  r   z PaddleCloudRoleMaker._all_reducec                 C      | j s|   | jS )zL
        return the heter device that current heter worker is using
        )r   _generate_roler   r1   r   r   r   _heter_device_     z"PaddleCloudRoleMaker._heter_devicec                 C   r   )zQ
        return the heter device type that current heter worker is using
        )r   r   r   r1   r   r   r   _heter_device_typeg  r   z'PaddleCloudRoleMaker._heter_device_typec                 C   r   )z9
        return stage id of current heter worker
        )r   r   r   r1   r   r   r   _get_stage_ido  r   z"PaddleCloudRoleMaker._get_stage_idc                 C   r   )z2
        return trainer num of all stages
        )r   r   r   r1   r   r   r   _get_stage_trainersw  r   z(PaddleCloudRoleMaker._get_stage_trainersc                 C   r   )z"
        return stage num
        )r   r   r   r1   r   r   r   _get_num_stage  r   z#PaddleCloudRoleMaker._get_num_stagec                 C      | j s|   | jtjkS )z3
        whether current process is worker
        )r   r   r+   r   r   r1   r   r   r   r        zPaddleCloudRoleMaker._is_workerc                 C   r   )z3
        whether current process is server
        )r   r   r+   r   r   r1   r   r   r   r     r   zPaddleCloudRoleMaker._is_serverc                 C   r   re   )r   r   r+   r   r   r1   r   r   r   _is_coordinator  s   z$PaddleCloudRoleMaker._is_coordinatorc                 C   s$   | j s|   | jtjko| jdkS )z=
        whether current process is worker of rank 0
        r   )r   r   r+   r   r   r   r1   r   r   r   r     s   z%PaddleCloudRoleMaker._is_first_workerc                 C   r   )z-
        get index of current worker
        r   r   r   r1   r   r   r   r     r   z"PaddleCloudRoleMaker._worker_indexc                 C   r   )z-
        get index of current server
        r   r1   r   r   r   r     r   z"PaddleCloudRoleMaker._server_indexc                 C   r   )z+
        get index of current node
        r   r1   r   r   r   r-     r   zPaddleCloudRoleMaker._role_idc                 C   r   )z5
        retrun the current number of worker
        )r   r   _trainers_numr1   r   r   r   r.     r   z PaddleCloudRoleMaker._worker_numc                 C   s*   | j s|   |  durt|  S dS )z5
        return the current number of server
        Nr   )r   r   r   r   r1   r   r   r   r/     s   
z PaddleCloudRoleMaker._server_numc                 C   r   z1
        return the training node number
        r   r   
_nodes_numr1   r   r   r   r     r   zPaddleCloudRoleMaker._node_numc                 C   r   r   r   r1   r   r   r   _get_node_num  r   z"PaddleCloudRoleMaker._get_node_numc                 C   r   re   )r   r   _local_rankr1   r   r   r   _get_local_rank     z$PaddleCloudRoleMaker._get_local_rankc                 C   r   re   )r   r   _local_device_idsr1   r   r   r   _get_local_device_ids  r   z*PaddleCloudRoleMaker._get_local_device_idsc                 C   r   re   )r   r   _world_device_idsr1   r   r   r   _get_world_device_ids  r   z*PaddleCloudRoleMaker._get_world_device_idsc                 C   r   )z.
        get endpoint of all trainers
        )r   r   r   r1   r   r   r   r     r   z+PaddleCloudRoleMaker._get_trainer_endpointsc                 C   (   | j s|   | jtjksJ d| jS )Nz0get_trainer_endpoint should be called by trainer)r   r   r+   r   r   r   r1   r   r   r   _get_trainer_endpoint  s   z*PaddleCloudRoleMaker._get_trainer_endpointc                 C   s&   | j s|   | jg ksJ d| jS )zK
        Returns:
            string: all heter_trainers'endpoints
        z&Heter Worker Endpoints Not initialized)r   r   r   r1   r   r   r   _get_heter_worker_endpoints  s   z0PaddleCloudRoleMaker._get_heter_worker_endpointsc                 C   r   )zR
        Returns:
            int: corresponding heter_trainer's endpoint
        z<_get_heter_worker_endpoint should be invoked by heter worker)r   r   r+   r   r   r   r1   r   r   r   _get_heter_worker_endpoint  s   z/PaddleCloudRoleMaker._get_heter_worker_endpointc                 C   r   )z.
        get endpoint of all pservers
        )r   r   r   r1   r   r   r   r     r   z+PaddleCloudRoleMaker._get_pserver_endpointsc                 C   r   re   )r   r   r   r1   r   r   r   _get_coordinator_endpoints  r   z/PaddleCloudRoleMaker._get_coordinator_endpointsc                 C   .   | j s|   | jtjtjfv sJ d| jS ))
        invoked by heter worker
        zC_get_previous_trainers should be invoked by trainer or heter worker)r   r   r+   r   r   r   r   r1   r   r   r   _get_previous_trainers$     z+PaddleCloudRoleMaker._get_previous_trainersc                 C   r   )r   z?_get_next_trainers should be invoked by trainer or heter worker)r   r   r+   r   r   r   r   r1   r   r   r   _get_next_trainers0  r   z'PaddleCloudRoleMaker._get_next_trainersc                 C   r   )z
        Return True if indispensable environment for fleetrun is not found
        (use python-run to launch fleet-code directly)
        )r   r   r   r1   r   r   r   _is_non_distributed<  s   z(PaddleCloudRoleMaker._is_non_distributedc                 C   r   )z'
        get heter worker nums
        )r   r   _heter_trainers_numr1   r   r   r   _heter_worker_numE  r   z&PaddleCloudRoleMaker._heter_worker_numc                 C   r   )z9
        whether current process is heter worker
        )r   r   r+   r   r   r1   r   r   r   _is_heter_workerM  r   z%PaddleCloudRoleMaker._is_heter_workerc                 C   sn  t dd | _| jd u r'd| _d| _tj| _d| _d| _d| _	d | _
d| _d S | jd| _tdd | _| jd urA| jd| _ng | _t dd| _| jdkrUtd	 n
d| _| jd| _t d
d }|d u rmtdt|}t dd }|d u rtd|dvrtd|t dd}t dd}t dd}|dkr|d| _
d| _t| j
| _	|dkr|dv sJ dnz|d| _W n   td|dkr|dv sJ dnz|d| _W n   tdd| _d| _	|dkrtj}t dd }|d u rtdt|}| jr\t dd | _| jd u rtdt| j| _t dd | _| jd u r6td t| j| _t d!d | _| jd u rMtd"td#d$ td%| jD | _t d&d }|d u rktd't d(d }	|	d u rztd)d*|	|g}
|
| _ n|d+krtd, tj!}tt dd-}n|d.krtj"}t d&d }|d u rtd't d(d }	|	d u rtd)d*|	|g}
|
| _ | j#| j }n|d/krtj$}t dd | _| jd u rtdt| j| _t dd | _| jd u rtd t| j| _t d!d | _| jd u rtd"td0d$ td%| jD | _t d1d | _%| j%d u r;td2| j%d3v sEJ d4| j%d5krZt d6d-}d*| j%|f| _&| j&d7krot d8d-}d*| j%|f| _&t d&d }|d u r~td't d(d }	|	d u rtd)d*|	|g}
|
| _ |d#|
| }|| _|| _|| _td9d: | jD | _d S );NZPADDLE_PSERVERS_IP_PORT_LISTr4   r	   r   T,PADDLE_TRAINER_ENDPOINTSZPADDLE_COORDINATOR_ENDPOINTSz$fl-ps > coordinator address is null!PADDLE_TRAINERS_NUMz@Can not find PADDLE_TRAINERS_NUM, please check your environment.ZTRAINING_ROLEz:Can not find TRAINING_ROLE, please check your environment.)TRAINERPSERVERHETER_TRAINERr   ztTRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER or COORDINATOR, but get {}, please check your environment.Z&PADDLE_NEXT_HETER_TRAINER_IP_PORT_LISTZ*PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LISTZ%PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST)r   r   z*training_role should be trainer or pserverzCan not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' .)r   r   z0training_role should be heter trainer or pserverz{Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' .Fr   PADDLE_TRAINER_IDz>Can not find PADDLE_TRAINER_ID, please check your environment.ZSTAGE_IDz5Can not find STAGE_ID, please check your environment.Z	STAGE_NUMz6Can not find STAGE_NUM, please check your environment.ZPADDLE_STAGE_TRAINERS_NUMzFCan not find PADDLE_STAGE_TRAINERS_NUM, please check your environment.c                 S      g | ]}t |qS r   rx   .0xr   r   r   
<listcomp>      z0PaddleCloudRoleMaker._ps_env.<locals>.<listcomp>z\d+ZPADDLE_PORTz8Can not find PADDLE_PORT, please check your environment.ZPOD_IPz3Can not find POD_IP, please check your environment.ru   r   z>>> curr node is coordinator!0r   r   c                 S   r   r   r   r   r   r   r   r     r  ZHETER_DEVICE_TYPEz>Can not find HETER_DEVICE_TYPE, please check your environment.)r   gpuxpuz*HETER_DEVICE_TYPE should be cpu,gpu or xpur  ZFLAGS_selected_gpusr  ZFLAGS_selected_xpusc                 S      h | ]	}| d d qS ru   r   r   r   r   r   r   	<setcomp>E      z/PaddleCloudRoleMaker._ps_env.<locals>.<setcomp>)'rY   getenvr   r   r   r   r+   r   r   r   r   r   r   r   r   r   ri   r   r?   rx   r   r   r   r   r   r   r   r   tuplerefindallr[   r   r   r   r   r   r   r   )r2   trainers_numZtraining_roleZnext_heter_trainer_eplistZprevious_heter_trainer_eplistZall_heter_trainer_eplistrE   
current_idZcur_portZcur_ipZcurr_endpointZheter_device_idr   r   r   _ps_envU  s  

















zPaddleCloudRoleMaker._ps_envc                 C   s   t tdd| _tdd| _| jdksJ tj| _td| _	td| _
| j	d u r5d| _	| j	| _
d| _| j	d	| _	t| j	| _td
d }|d urWtdd }t || _tdd | j	D | _td| _td| _td| _d S )Nr   r  ZPADDLE_TRAINING_ROLEr   r   ZPADDLE_CURRENT_ENDPOINTz127.0.0.1:6170Tr   ZPADDLE_AUTO_PARALLEL_CONFIGr   c                 S   r  r  r  r   r   r   r   r  Y  r	  z7PaddleCloudRoleMaker._collective_env.<locals>.<setcomp>ZPADDLE_RANK_IN_NODEZPADDLE_LOCAL_DEVICE_IDSZPADDLE_WORLD_DEVICE_IDS)rx   rY   r
  r   _training_roler   r   r+   r   r   r   r   r   r   r   r   r   r   r   )r2   Z
auto_tunerr  r   r   r   _collective_envG  s(   


z$PaddleCloudRoleMaker._collective_envc              	   C   s  t tdd}|dvrd S t tdd}tdd}|tjjtjjtjjfvr/t| j	j
|dkr5dnd	}|tjjkrWtd
d}tdd}tdd}||||d}nO|tjjkrd	}	t }
|
 }d	|d< | jry| jd }|  rxd}	ntdd}|  r|  dkrd}	|d\}}||||	|d}ntdd}||d}|tjjkrd}n|tjjkrd}nd}td| d| d|  | j	j|| j|  |  |  ||d |tjjkrd	|d< d S d S )NZPADDLE_WITH_GLOOr  )r	   r
   ZPADDLE_GLOO_RENDEZVOUSZ
SYS_JOB_IDr4   r
   TFZPADDLE_GLOO_FS_NAMEZPADDLE_GLOO_FS_UGIZPADDLE_GLOO_FS_PATH)r6   r7   r8   r5   rh   r   ZPADDLE_GLOO_HTTP_ENDPOINTru   )r9   r:   r5   r;   r<   )r8   r5   r   r   r   zGloo init with z: need_init_all: z, args: )rD   rE   rF   rG   rH   rI   rJ   )rx   rY   r
  r   r   r   r   r   r?   r   r%   r   dictr   r   r   r   r   r   ri   rR   r+   r-   r.   r/   )r2   Zuse_glooZrendezvous_typera   rI   rL   rM   rN   rJ   r;   managerr<   Z	ep_rank_0rP   rQ   typer   r   r   
_gloo_init^  s   

zPaddleCloudRoleMaker._gloo_initc                 C   s@   | j s| js|   n|   d| _ t s|   dS dS dS z.
        generate role for role maker
        TN)r   r   r  r  paddleZin_dynamic_moder  r1   r   r   r   r     s   
z#PaddleCloudRoleMaker._generate_role)Fr   r   )+r   r   r   r   r3   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r-   r.   r/   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r   __classcell__r   r   r   r   r   #  sR    

	 sOr   c                       s:   e Zd ZdZd fdd	Zdd Zdd Zd	d
 Z  ZS )UserDefinedRoleMakeraB  
    UserDefinedRoleMaker is an interface for distributed configuration initialization based on obtaining distributed related information from user-defined parameters.

    Examples:
        .. code-block:: python

            >>> import paddle.distributed.fleet as fleet
            >>> from paddle.distributed.fleet.base.role_maker import Role

            >>> fleet.UserDefinedRoleMaker(
            ...     current_id=0,
            ...     role=Role.SERVER,
            ...     worker_num=2,
            ...     server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
    Fc                    s"   t  jd||d| || _d S )N)r   	init_gloor   )r   r3   Z
_init_gloo)r2   r   r  rJ   r   r   r   r3     s   
zUserDefinedRoleMaker.__init__c                 C   s   | j d| _| j dg | _| j dd| _| jdkr+t| jdks%J t| j| _| j d| _| j d| _| jtj	krOt| j| jkrO| j| j | _
n| jtjkr\| j| j | _
tdd | jD | _d S )	NZserver_endpointsworker_endpointsrG   r   rE   r  c                 S   r  r  r  r   r   r   r   r    r	  z<UserDefinedRoleMaker._user_defined_ps_env.<locals>.<setcomp>)r   r=   r   r   r   r   r+   r   r   r   r   r   r   r1   r   r   r   _user_defined_ps_env  s   
z)UserDefinedRoleMaker._user_defined_ps_envc                 C   sJ   | j d| _| j d| _t| j| _tj| _tdd | jD | _	d S )Nr  r  c                 S   r  r  r  r   r   r   r   r    r	  zDUserDefinedRoleMaker._user_defined_collective_env.<locals>.<setcomp>)
r   r=   r   r   r   r   r   r   r  r   r1   r   r   r   _user_defined_collective_env  s
   z1UserDefinedRoleMaker._user_defined_collective_envc                 C   s,   | j s| js|   n|   d| _ dS dS r  )r   r   r  r  r1   r   r   r   r     s   

z#UserDefinedRoleMaker._generate_role)FF)	r   r   r   r   r3   r  r  r   r  r   r   r   r   r    s    r  )r   rY   r  rk   r   multiprocessingr   r   numpyr   r  Zpaddle.baser   Z5paddle.distributed.fleet.base.private_helper_functionr   Z
backup_envr   __all__r   r   r   r   r  r   r   r   r   <module>   s2     ]      