o
    *j                     @   s
  d dl Z d dlZd dlZd dlZd dlmZ d dlm  mZ	 d dl
m  mZ d dlmZmZmZ G dd dejZdde j fddZd	d
 Zdd ZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdS )    N)DropPath	to_2tupletrunc_normal_c                       s4   e Zd ZdZddejdf fdd	Zdd Z  ZS )Mlpz Multilayer perceptron.N        c                    sN   t    |p|}|p|}t||| _| | _t||| _t|| _d S N)	super__init__nnLinearfc1actfc2Dropoutdrop)selfin_featureshidden_featuresZout_features	act_layerr   	__class__ c/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/modelscope/models/cv/vidt/backbone.pyr	      s   
zMlp.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S r   )r   r   r   r   )r   xr   r   r   forward!   s   




zMlp.forward)	__name__
__module____qualname____doc__r
   GELUr	   r   __classcell__r   r   r   r   r      s    r   i'     c              	   C   s  |d }| }|j dtjd}|j dtjd}d}||ddddddf |  | }||ddddddf |  | }tj|tj| jd}	|d|	d  |  }	|dddddddf |	 }
|dddddddf |	 }tj|
dddddddddf  |
dddddddddf  fd	d
d}
tj|dddddddddf  |dddddddddf  fd	d
d}tj	||
fdd
}|S )aD   Masked Sinusoidal Positional Encoding

    Args:
        x: [PATCH] tokens
        mask: the padding mask for [PATCH] tokens
        num_pos_feats: the size of channel dimension
        temperature: the temperature value
        scale: the normalization scale

    Returns:
        pos: Sinusoidal positional encodings
    r!      )dtypegư>N)r#   devicer      dim   )
ZcumsumtorchZfloat32aranger%   stacksincosflattencat)r   maskZnum_pos_featsZtemperaturescaleZnot_maskZy_embedZx_embedepsZdim_tZpos_xZpos_yposr   r   r   masked_sin_pos_encoding*   s2   &&  JJr5   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )z
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    r   r"   r)   r!   r&      r$   )shapeviewpermute
contiguous)r   window_sizeBHWCwindowsr   r   r   window_partitionW   s   	rA   c                 C   sb   t | jd || | |  }| ||| || ||d}|dddddd |||d}|S )z
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    r   r$   r"   r)   r!   r&   r6   )intr7   r8   r9   r:   )r@   r;   r=   r>   r<   r   r   r   r   window_reverseh   s   $rC   c                       s:   e Zd ZdZ				d
 fdd	Z			ddd	Z  ZS )ReconfiguredAttentionModulea   Window based multi-head self attention (W-MSA) module with relative position bias -> extended with RAM.
    It supports both of shifted and non-shifted window.

    !!!!!!!!!!! IMPORTANT !!!!!!!!!!!
    The original attention module in Swin is replaced with the reconfigured attention module in Section 3.
    All the Args are shared, so only the forward function is modified.
    See https://arxiv.org/pdf/2110.03921.pdf
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    TNr   c                    s  t    || _|| _|| _|| }|p|d | _tt	d|d  d d|d  d  || _
t| jd }	t| jd }
tt|	|
g}t|d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |d}| d| tj||d |d| _t|| _t||| _t|| _t| j
d	d
 tjdd| _d S )Ng      r!   r   r"   r$   relative_position_indexr)   Zbias{Gz?Zstdr'   )r   r	   r(   r;   	num_headsr2   r
   	Parameterr*   zerosrelative_position_bias_tabler+   r,   Zmeshgridr/   r9   r:   sumZregister_bufferr   qkvr   	attn_dropproj	proj_dropr   ZSoftmaxsoftmax)r   r(   r;   rI   qkv_biasqk_scalerO   rQ   Zhead_dimZcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsrE   r   r   r   r	      sX   
	"(,
z$ReconfiguredAttentionModule.__init__Fc           *   	   C   s  | j d | j d ksJ | j d }|| }|sS|j\}}	}
}|	|
 }||||}tj||gdd}| |}|ddd|ddf |dd|dddf }}nh|d j\}}	}
}|	|
 }|d j\}}}}|| }|d |||}|d |||}tj|||gdd}| |}|ddd|ddf |dd||| ddf |dd|| dddf }}}|||	|
d}t||}|jd }|||| d| j|| j }|	ddddd}|d |d |d }}}|| j
 }||d	d }| j| jd | j d | j d  | j d | j d  d}|	ddd }||d }|durY|jd }||| || j||}|dd} ||  }|d| j||}| |}| |}|| dd||||}!||dd| j|| j }|	ddddd}|d |d |d }"}#}$|r||||d| j|| j }|ddddddddddddf 	dddddd
 }%|%d|| j|| d}%|%d |%d }&}'tj|&|#gddtj|'|$gdd}#}$|"| j
 }"|"|#d	d }(|dur|(| }(| |(}(| |(}(|(|$ dd|d|})t|!||	|
}!tj|!||	|
 ||)gdd}| |}| |}|ddd|	|
 ddf ||	|
|}!|dd|	|
 dddf })|!|)fS )al   Forward function.
        RAM module receives [Patch] and [DET] tokens and returns their calibrated ones

        Args:
            x: [PATCH] tokens
            det: [DET] tokens
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None -> mask for shifted window attention

            "additional inputs for RAM"
            cross_attn: whether to use cross-attention [det x patch] (for selective cross-attention)
            cross_attn_mask: mask for cross-attention

        Returns:
            patch_x: the calibrated [PATCH] tokens
            det_x: the calibrated [DET] tokens
        r   r"   r'   Nr$   r)   r!   r&   r6   )r;   r7   r8   r*   r0   rN   rA   ZreshaperI   r9   r2   	transposerL   rE   r:   	unsqueezerR   rO   rC   rP   rQ   )*r   r   detr1   
cross_attncross_attn_maskr;   Zlocal_map_sizer<   r=   r>   r?   NZfull_qkvZ	patch_qkvZdet_qkv_Zori_HZori_WZori_N	shifted_xZcross_xZcross_patch_qkvZB_Z
_patch_qkvZpatch_qZpatch_kZpatch_vZ
patch_attnZrelative_position_biasZnWZtmp0Ztmp1Zpatch_xZdet_qZdet_kZdet_vZpatch_kvZcross_patch_kZcross_patch_vZdet_attnZdet_xr   r   r   r      s   

8
R












 

*z#ReconfiguredAttentionModule.forward)TNr   r   )NFNr   r   r   r   r	   r   r    r   r   r   r   rD   z   s    2rD   c                
       sB   e Zd ZdZddddddddejejf
 fdd		Zd
d Z  Z	S )SwinTransformerBlocka]   Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
       r         @TNr   c              	      s   t    || _|| _|| _|| _|| _d| j  kr#| jk s(J d J d||| _t|t	| j||||	|d| _
|
dkrDt|
nt | _||| _t|| }t||||d| _d | _d | _d S )Nr   z shift_size must in 0-window_size)r;   rI   rS   rT   rO   rQ   r   )r   r   r   r   )r   r	   r(   rI   r;   
shift_size	mlp_rationorm1rD   r   attnr   r
   ZIdentity	drop_pathnorm2rB   r   mlpr=   r>   )r   r(   rI   r;   rb   rc   rS   rT   r   rO   rf   r   
norm_layerZmlp_hidden_dimr   r   r   r	   W  sB   
(



zSwinTransformerBlock.__init__c              	   C   s"  |j \}}}| j| j}	}
||	|
 | j ksJ d|}| |}|ddd|	|
 ddf |dd|	|
 dddf }}|||	|
|}|}d }}| j|
| j  | j }| j|	| j  | j }t|dd||||f}|j \}}}}|\}}| 	|}| j
dkrtj|| j
 | j
 fdd}|}n|}d}|r|| }|| }||f}n|| }|}| j|||||d\}}| j
dkrtj|| j
| j
fdd}n|}|dks|dkr|ddd|	d|
ddf  }|||	|
 |}tj||gdd}|| | }|| | | | }|S )	a   Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C). i.e., binded [PATCH, DET] tokens
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.

            "additional inputs'
            pos: (patch_pos, det_pos)
            cross_attn: whether to use cross attn [det x [det + patch]]
            cross_attn_mask: attention mask for cross-attention

        Returns:
            x: calibrated & binded [PATCH, DET] tokens
        input feature has wrong sizeNr   )r"   r!   )Zshiftsdims)r1   rX   rY   rZ   r"   r'   )r7   r=   r>   det_token_numrd   r8   r;   Fpaddet_pos_linearrb   r*   Zrollre   r:   r0   rf   rh   rg   )r   r   Zmask_matrixr4   rY   rZ   r<   Lr?   r=   r>   ZshortcutrX   Zorig_xZpad_lZpad_tZpad_rZpad_br\   HpWp	patch_posdet_posr]   	attn_maskZcross_patchr   r   r   r     sb   
>




	
$zSwinTransformerBlock.forward)
r   r   r   r   r
   r   	LayerNormr	   r   r    r   r   r   r   r_   E  s    ,r_   c                       s0   e Zd ZdZejdf fdd	Zdd Z  ZS )PatchMergingz Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    Tc                    sd   t    || _|rd| nd}tjd| |dd| _|d| | _tj||dd| _||| _d S )Nr!      r&   FrF   )	r   r	   r(   r
   r   	reductionnormZ	expansionrg   )r   r(   ri   expandZ
expand_dimr   r   r   r	     s   
zPatchMerging.__init__c              
   C   s  |j \}}}||| | j ksJ d|ddd|| ddf |dd|| dddf }}|||||}|d dkpE|d dk}|rXt|ddd|d d|d f}|ddddddddddf }	|ddddddddddf }
|ddddddddddf }|ddddddddddf }t|	|
||gd}||dd| }|ddd}tj||gdd}| |}| 	|}|S )	aV   Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C), i.e., binded [PATCH, DET] tokens
            H, W: Spatial resolution of the input feature.

        Returns:
            x: merged [PATCH, DET] tokens;
            only [PATCH] tokens are reduced in spatial dim, while [DET] tokens is fix-scale
        rj   Nr!   r"   r   r$   r&   r'   )
r7   rl   r8   rm   rn   r*   r0   repeatrz   ry   )r   r   r=   r>   r<   rp   r?   rX   Z	pad_inputZx0x1Zx2Zx3r   r   r   r     s$   > $$$$

zPatchMerging.forward	r   r   r   r   r
   rv   r	   r   r    r   r   r   r   rw     s    rw   c                       sD   e Zd ZdZdddddddejdddf fdd		Zdd
dZ  ZS )
BasicLayera   A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    r`   ra   TNr   Fc                    s   t    	| _	d | _|| _| _|| _t 	f
ddt	|D | _
|d ur=|| d| _d S d | _d S )Nr!   c                    sP   g | ]$}t 	|d  dkrdn	d   ttr!| ndqS )r!   r   )r(   rI   r;   rb   rc   rS   rT   r   rO   rf   ri   )r_   
isinstancelist.0i
rO   r(   r   rf   rc   ri   rI   rT   rS   r;   r   r   
<listcomp>G  s$    
z'BasicLayer.__init__.<locals>.<listcomp>)r(   ri   r{   )r   r	   r;   rb   depthr(   use_checkpointr
   
ModuleListrangeblocks
downsample)r   r(   r   rI   r;   rc   rS   rT   r   rO   rf   ri   r   lastr   r   r   r   r	   /  s   

 

zBasicLayer.__init__c              	   C   s  |j d }tt|| j | j }tt|| j | j }	tjd||	df|jd}
td| j t| j | j	 t| j	 df}td| j t| j | j	 t| j	 df}d}|D ]}|D ]}||
dd||ddf< |d7 }q_q[t
|
| j}|d| j| j }|d|d }||dktd|dktd}|r|j dd \}}||kr||kstj|d  ||fd	tjd }t||| j}| }||dktd|dktd}|||| dd}tj|d| jfdd
}nd}d}||f}t| jD ]8\}}|||_|_|rd}|}|}nd}d}d|f}| jr7tj||||||d}q	||||||d}q	| jdurd| |||}|d d |d d }}||||||fS ||||||fS )aD   Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            det_pos: pos encoding for det token
            input_mask: padding mask for inputs
            cross_attn: whether to use cross attn [det x [det + patch]]
        r   r"   )r%   Nr$   r!   g      Yr   size)valueTF)r4   rY   rZ   )r7   rB   npceilr;   r*   rK   r%   slicerb   rA   r8   rW   Zmasked_fillfloatrm   interpolatetoboolr5   r(   rn   rl   	enumerater   r=   r>   r   
checkpointr   )r   r   r=   r>   rt   
input_maskrY   r<   rq   rr   Zimg_maskZh_slicesZw_slicesZcnthwZmask_windowsru   Z_HZ_Wrs   rZ   r4   Zn_blkZblkZ_cross_attnZ_cross_attn_mask_posZx_downWhWwr   r   r   r   ^  s   





	

zBasicLayer.forward)Fr~   r   r   r   r   r     s    /r   c                       s2   e Zd ZdZ				d
 fdd	Zdd	 Z  ZS )
PatchEmbedaE   Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    r&   r)   `   Nc                    sX   t    t|}|| _|| _|| _tj||||d| _|d ur'||| _	d S d | _	d S )N)Zkernel_sizeZstride)
r   r	   r   
patch_sizein_chans	embed_dimr
   ZConv2drP   rz   )r   r   r   r   ri   r   r   r   r	     s   

zPatchEmbed.__init__c              
   C   s   |  \}}}}|| jd  dkr#t|d| jd || jd   f}|| jd  dkr@t|ddd| jd || jd   f}| |}| jdurp| d| d}}|ddd}| |}|ddd| j	||}|S )zForward function.r"   r   Nr!   r)   r$   )
r   r   rm   rn   rP   rz   r/   rV   r8   r   )r   r   r\   r=   r>   r   r   r   r   r   r     s   $


zPatchEmbed.forward)r&   r)   r   Nr^   r   r   r   r   r     s    
r   c                       s   e Zd ZdZddddg dg ddd	d
ddddejdd
g dddf fdd	Zdd Zej	j
dd ZdddgfddZdd Zd! fdd	Zdd  Z  ZS )"SwinTransformera   Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any args.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
       r&   r)   r   )r!   r!      r!   )r)   r         r`   ra   TNr   g?F)r"   r!   r)   r$   c                    s  t    || _t|| _ | _|| _|| _|| _|| _	t
|| | jr%|nd d| _| jr\t|}t|}|d |d  |d |d  g}ttd |d |d | _t| jdd tj|d| _dd td|t|D }t | _t| jD ]F}tt d	|  || || |||	|
|||t|d | t|d |d   ||| jk rtnd || jd k rd nd
|d}| j| q{ fddt| jD }|| _|D ]}||| }d| }| || q|    d S )N)r   r   r   ri   r   r"   rG   rH   )pc                 S   s   g | ]}|  qS r   )item)r   r   r   r   r   r   R  s    z,SwinTransformer.__init__.<locals>.<listcomp>r!   T)r(   r   rI   r;   rc   rS   rT   r   rO   rf   ri   r   r   r   c                    s   g | ]
}t  d |  qS )r!   )rB   r   r   r   r   r   m  s    rz   )!r   r	   pretrain_img_sizelen
num_layersr   ape
patch_normout_indicesfrozen_stagesr   patch_embedr   r
   rJ   r*   rK   absolute_pos_embedr   r   pos_dropZlinspacerM   r   layersr   r   rB   rw   appendnum_featuresZ
add_module_freeze_stages)r   r   r   r   r   ZdepthsrI   r;   rc   rS   rT   Z	drop_rateZattn_drop_rateZdrop_path_rateri   r   r   r   r   r   patches_resolutionZdprZi_layerlayerr   Z
layer_namer   r   r   r	     st   


&

zSwinTransformer.__init__c                 C   s   | j dkr| j  | j D ]}d|_q| j dkr!| jr!d| j_| j dkrI| j  td| j d D ]}| j	| }|  | D ]}d|_qBq3d S d S )Nr   Fr"   r!   )
r   r   eval
parametersZrequires_gradr   r   r   r   r   )r   paramr   mr   r   r   r   z  s    




zSwinTransformer._freeze_stagesc                 C   s   ddhS )Ndet_pos_embed	det_tokenr   r   r   r   r   no_weight_decay  s   zSwinTransformer.no_weight_decayd   rx   c                    s   | _ | _ttd| jd  _t jdd _| _	td||}t|dd}tj| _
 fddtt jd D  _|dkrO j j	 | _dt jt j   _ jD ]}||_|jd	urp||j_|jD ]}||_t||j|_qsqb|d
krd	 jd _d	S d	S )a*   A function to add necessary (leanable) variables to Swin Transformer for object detection

            Args:
                method: vidt or vidt_wo_neck
                det_token_num: the number of object to detect, i.e., number of object queries
                pos_dim: the channel dimension of positional encodings for [DET] and [PATCH] tokens
                cross_indices: the indices where to use the [DET X PATCH] cross-attention
                    there are four possible stages in [0, 1, 2, 3]. 3 indicates Stage 4 in the ViDT paper.
        r"   r   rG   rH   c                    s   g | ]	} j |d   qS )r"   )r   r   r   r   r   r     s    z0SwinTransformer.finetune_det.<locals>.<listcomp>Zvidtr!   NZvidt_wo_neckr$   )methodrl   r
   rJ   r*   rK   r   r   r   pos_dimr   r   r   Znum_channelsr   cross_indicesr   mask_divisorr   r   r   r(   ro   )r   r   rl   r   r   r   r   blockr   r   r   finetune_det  s<   



zSwinTransformer.finetune_detc              	   C   s  |j d |j d |j d }}}| |}|d|d}}|ddd}| |}| j|dd}| j}t	j
|d  || j || j fdtjd }g }	t| jD ]l}
| j|
 }|
| jv rhdnd	}tj||gdd
}|||||||d\}}}}}}|ddd| j ddf |dd| j dddf }}|
dkr|ddd| j ddf |||ddddd}|	| qZ|	||||ddddd |dd| j dddf ddd}|ddd}|	\}}}}||||||fS )a   Forward function.

            Args:
                x: input rgb images
                mask: input padding masks [0: rgb values, 1: padded values]

            Returns:
                patch_outs: multi-scale [PATCH] tokens (four scales are used)
                    these tokens are the first input of the neck decoder
                det_tgt: final [DET] tokens obtained at the last stage
                    this tokens are the second input of the neck decoder
                det_pos: the learnable pos encoding for [DET] tokens.
                    these encodings are used to generate reference points in deformable attention
        r   r!   r)   r"   r$   Nr   TFr'   )r   rt   rY   )r7   r   r   r/   rV   r   r   r{   r   rm   r   r   r   r   r*   r   r   r   r   r   r0   rl   r8   r9   r   )r   r   r1   r<   r\   r   r   r   rt   Z
patch_outsZstager   rY   Zx_outr=   r>   Z	patch_outZdet_tgtZ
features_0Z
features_1Z
features_2Z
features_3r   r   r   r     s\   "




&	

"(zSwinTransformer.forwardc                    s   t t| | |   dS )z?Convert the model into training mode while keep layers freezed.N)r   r   trainr   )r   moder   r   r   r     s   zSwinTransformer.trainc                 C   sn   d}|| j  7 }t| jD ]
\}}|| 7 }q|| j| jd  | jd  d| j  7 }|| j| j 7 }|S )Nr   r"   r!   )r   flopsr   r   r   r   r   Znum_classes)r   r   r   r   r   r   r   r     s   
zSwinTransformer.flops)T)r   r   r   r   r
   rv   r	   r   r*   Zjitignorer   r   r   r   r   r    r   r   r   r   r     s@    `

9Or   )mathosnumpyr   r*   Ztorch.nnr
   Ztorch.nn.functionalZ
functionalrm   Ztorch.utils.checkpointutilsr   Ztimm.models.layersr   r   r   Moduler   pir5   rA   rC   rD   r_   rw   r   r   r   r   r   r   r   <module>   s.   
- L > 02