o
    )jD                     @   s  d Z ddlZddlZddlmZ ddlm  mZ ddlm	Z	 ddl
mZmZ ddlmZmZ ddlmZmZ dZG d	d
 d
ejZG dd dejZdd ZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZejej e	j!dG dd deZ"dS )zLibrary to support dual-path speech separation.

Authors
 * Cem Subakan 2020
 * Mirco Ravanelli 2020
 * Samuele Cornell 2020
 * Mirko Bronzi 2020
 * Jianyuan Zhong 2020
    N)Models)MODELS
TorchModel)	ModelFileTasks   )MossformerBlockGFSMNScaledSinuEmbedding:0yE>c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )	GlobalLayerNorma3  Calculate Global Layer Normalization.

    Args:
       dim : (int or list or torch.Size)
           Input shape from an expected input of size.
       eps : float
           A value added to the denominator for numerical stability.
       elementwise_affine : bool
          A boolean value that when set to True,
          this module has learnable per-element affine parameters
          initialized to ones (for weights) and zeros (for biases).

    Example:
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    r
   Tc                    s   t t|   || _|| _|| _| jrM|dkr-tt	| jd| _
tt| jd| _|dkrKtt	| jdd| _
tt| jdd| _d S d S | dd  | dd  d S )N   r      weightbias)superr   __init__dimepselementwise_affinenn	ParametertorchZonesr   Zzerosr   Zregister_parameter)selfr   shaper   r   	__class__ q/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/modelscope/models/audio/separation/m2/mossformer.pyr   /   s   zGlobalLayerNorm.__init__c                 C   s   |  dkr<tj|ddd}tj|| d ddd}| jr0| j||  t|| j  | j }n|| t|| j  }|  dkrytj|ddd}tj|| d ddd}| jrm| j||  t|| j  | j }|S || t|| j  }|S )zReturns the normalized tensor.

        Args:
            x : torch.Tensor
                Tensor of size [N, C, K, S] or [N, C, L].
        r   )r      T)Zkeepdimr   r   )r   r   r   )r   r   meanr   r   sqrtr   r   )r   xr   varr   r   r   forward@   s$   zGlobalLayerNorm.forward)r
   T__name__
__module____qualname____doc__r   r#   __classcell__r   r   r   r   r      s    r   c                       s.   e Zd ZdZd fdd	Z fddZ  ZS )CumulativeLayerNormaD  Calculate Cumulative Layer Normalization.

    Args:
       dim : int
        Dimension that you want to normalize.
       elementwise_affine : True
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    Tc                    s   t t| j||dd d S )Nr
   )r   r   )r   r*   r   )r   r   r   r   r   r   r   s   s   

zCumulativeLayerNorm.__init__c                    sx   |  dkr |dddd }t |}|dddd }|  dkr:t|dd}t |}t|dd}|S )zReturns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor size [N, C, K, S] or [N, C, L]
        r   r   r   r   r   )r   permute
contiguousr   r#   r   	transposer   r!   r   r   r   r#   w   s   
zCumulativeLayerNorm.forward)Tr$   r   r   r   r   r*   c   s    r*   c                 C   sL   | dkrt ||ddS | dkrt|ddS | dkr!tjd|ddS t|S )	z5Just a wrapper to select the normalization type.
    ZglnT)r   clnlnr   r
   r   )r   r*   r   Z	GroupNormZBatchNorm1d)normr   r   r   r   r   select_norm   s   
r3   c                       s*   e Zd ZdZd	 fdd	Zdd Z  ZS )
Encodera  Convolutional Encoder Layer.

    Args:
        kernel_size : int
            Length of filters.
        in_channels : int
            Number of  input channels.
        out_channels : int
            Number of output channels.

    Example:
    >>> x = torch.randn(2, 1000)
    >>> encoder = Encoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    r   @   r   c                    s4   t t|   tj||||d ddd| _|| _d S )Nr   r   F)in_channelsout_channelskernel_sizestridegroupsr   )r   r4   r   r   Conv1dconv1dr6   )r   r8   r7   r6   r   r   r   r      s   
zEncoder.__init__c                 C   s0   | j dkrtj|dd}| |}t|}|S )a  Return the encoded output.

        Args:
            x : torch.Tensor
                Input tensor with dimensionality [B, L].

        Returns:
            x : torch.Tensor
                Encoded tensor with dimensionality [B, N, T_out].
                where B = Batchsize
                      L = Number of timepoints
                      N = Number of filters
                      T_out = Number of timepoints at the output of the encoder
        r   r   )r6   r   	unsqueezer<   FZrelur.   r   r   r   r#      s
   


zEncoder.forward)r   r5   r   r$   r   r   r   r   r4      s    r4   c                       s,   e Zd ZdZ fddZ fddZ  ZS )Decodera  A decoder layer that consists of ConvTranspose1d.

    Args:
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of  input channels.
    out_channels : int
        Number of output channels.


    Example:
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    c                    s   t t| j|i | d S )N)r   r@   r   )r   argskwargsr   r   r   r      s   zDecoder.__init__c                    sr   |  dvrtd| jt |  dkr|nt|d}t|  dkr2tj|dd}|S t|}|S )a  Return the decoded output.

        Args:
            x : torch.Tensor
                Input tensor with dimensionality [B, N, L].
                    where, B = Batchsize,
                           N = number of filters
                           L = time points
        )r   r   z{} accept 3/4D tensor as inputr   r   r=   )	r   RuntimeErrorformatr%   r   r#   r   r>   Zsqueezer.   r   r   r   r#      s   $
zDecoder.forwardr$   r   r   r   r   r@      s    r@   c                       s6   e Zd ZdZ						d fdd		Zd
d Z  ZS )MossFormerMa9  This class implements the transformer encoder.

    Args:
        num_blocks : int
            Number of mossformer blocks to include.
        d_model : int
            The dimension of the input embedding.
        attn_dropout : float
            Dropout for the self-attention (Optional).
        group_size: int
            the chunk size
        query_key_dim: int
            the attention vector dimension
        expansion_factor: int
            the expansion factor for the linear projection in conv module
        causal: bool
            true for causal / false for non causal

    Example:
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = TransformerEncoder_MossFormerM(num_blocks=8, d_model=512)
    >>> output, _ = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    NF            @皙?c              	      s6   t    t|||||||d| _tj|dd| _d S )N)r   depth
group_sizequery_key_dimexpansion_factorcausalattn_dropoutgư>r1   )r   r   r   mossformerMr   	LayerNormr2   )r   
num_blocksd_modelrN   rK   rL   rM   rO   r   r   r   r   "  s   
zMossFormerM.__init__c                 C   s   |  |}| |}|S )a  
        Args:
            src : torch.Tensor
                Tensor shape [B, L, N],
                where, B = Batchsize,
                       L = time points
                       N = number of filters
                The sequence to the encoder layer (required).
        )rP   r2   )r   srcoutputr   r   r   r#   6  s   


zMossFormerM.forward)NFrF   rG   rH   rI   r$   r   r   r   r   rE     s    rE   c                       s.   e Zd ZdZ		d fdd	Zdd Z  ZS )	ComputationBlocka&  Computation block for dual-path processing.

    Args:
        num_blocks : int
            Number of mossformer blocks to include.
         out_channels : int
            Dimensionality of inter/intra model.
         norm : str
            Normalization type.
         skip_around_intra : bool
            Skip connection around the intra layer.

    Example:
    ---------
        >>> comp_block = ComputationBlock(64)
        >>> x = torch.randn(10, 64, 100)
        >>> x = comp_block(x)
        >>> x.shape
        torch.Size([10, 64, 100])
    r0   Tc                    sF   t t|   t||d| _|| _|| _|d ur!t||d| _d S d S )N)rR   rS   r   )	r   rV   r   rE   	intra_mdlskip_around_intrar2   r3   
intra_norm)r   rR   r7   r2   rX   r   r   r   r   \  s   zComputationBlock.__init__c                 C   sd   |j \}}}|ddd }| |}|ddd }| jdur'| |}| jr.|| }|}|S )ad  Returns the output tensor.

        Args:
            x : torch.Tensor
                Input tensor of dimension [B, N, S].

        Returns:
            out: torch.Tensor
                Output tensor of dimension [B, N, S].
                where, B = Batchsize,
                   N = number of filters
                   S = sequence time index
        r   r   r   N)r   r+   r,   rW   r2   rY   rX   )r   r!   BNSZintraoutr   r   r   r#   o  s   


zComputationBlock.forward)r0   Tr$   r   r   r   r   rV   F  s    rV   c                       s6   e Zd ZdZ						d fdd	Zd	d
 Z  ZS )MossFormerMaskNeta  The dual path model which is the basis for dualpathrnn, sepformer, dptnet.

    Args:
        in_channels : int
            Number of channels at the output of the encoder.
        out_channels : int
            Number of channels that would be inputted to the intra and inter blocks.
        norm : str
            Normalization type.
        num_spks : int
            Number of sources (speakers).
        skip_around_intra : bool
            Skip connection around intra.
        use_global_pos_enc : bool
            Global positional encodings.
        max_length : int
            Maximum sequence length.

    Example:
    ---------
    >>> mossformer_block = MossFormerM(1, 64, 8)
    >>> mossformer_masknet = MossFormerMaskNet(64, 64, intra_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
       r0   r   T N  c	           	         s   t t|   || _|| _t||d| _tj||ddd| _	|| _
| j
r)t|| _t||||d| _tj||| dd| _tj||ddd| _t | _t | _tt||dt | _tt||dt | _d S )Nr   r   F)r   )rX   )r8   )r   r^   r   num_spksrR   r3   r2   r   r;   conv1d_encoderuse_global_pos_encr	   pos_encrV   mdl
conv1d_outconv1_decoderZPReLUpreluZReLU
activationZ
SequentialZTanhrU   ZSigmoidoutput_gate)	r   r6   r7   rR   r2   ra   rX   rc   
max_lengthr   r   r   r     s<   




zMossFormerMaskNet.__init__c           	      C   s   |  |}| |}| jr$|}|dd}| |}|dd}|| }| |}| |}| |}|j\}}}|	|| j
 d|}| || | }| |}|j\}}}|	|| j
||}| |}|dd}|S )a  Returns the output tensor.

        Args:
            x : torch.Tensor
                Input tensor of dimension [B, N, S].

        Returns:
            out : torch.Tensor
                Output tensor of dimension [spks, B, N, S]
                where, spks = Number of speakers
                   B = Batchsize,
                   N = number of filters
                   S = the number of time frames
        r   r   )r2   rb   rc   r-   rd   re   rh   rf   r   viewra   rU   rj   rg   ri   )	r   r!   baseZembrZ   _r\   r[   Lr   r   r   r#     s(   







zMossFormerMaskNet.forward)r_   r0   r   TTr`   r$   r   r   r   r   r^     s     )r^   )module_namec                       sL   e Zd ZdZ									dd	ef fd
dZdd ZdddZ  ZS )MossFormer2ziLibrary to support MossFormer speech separation.

    Args:
        model_dir (str): the model path.
       r_      r0   r   Tr`   	model_dirc              
      sf   t  j|g|R i | || _t||dd| _t|||||||	|
d| _t|d||d dd| _d S )Nr   )r8   r7   r6   )r6   r7   rR   r2   ra   rX   rc   rk   r   F)r6   r7   r8   r9   r   )	r   r   ra   r4   encr^   mask_netr@   dec)r   ru   r6   r7   rR   r8   r2   ra   rX   rc   rk   rA   rB   r   r   r   r     s,   
zMossFormer2.__init__c                    s     |} |}t|g j }|| tj fddt jD dd}|d}|d}||krEt	|ddd|| f}|S |d d d |d d f }|S )Nc                    s    g | ]}  | d qS )rl   )rx   r>   ).0ir   Zsep_xr   r   
<listcomp>C  s     z'MossFormer2.forward.<locals>.<listcomp>rl   r=   r   r   )
rv   rw   r   stackra   catrangesizer?   pad)r   inputr!   maskZ
est_sourceZT_originZT_estr   r{   r   r#   ;  s   



zMossFormer2.forwardNc                 C   s@   |s| j }|std}| jtjtj|tj	|ddd d S )Ncpu)Zmap_locationF)strict)
ru   r   deviceZload_state_dictloadospathjoinr   ZTORCH_MODEL_FILE)r   Z	load_pathr   r   r   r   load_check_pointN  s   

zMossFormer2.load_check_point)	rs   rs   r_   rt   r0   r   TTr`   )NN)	r%   r&   r'   r(   strr   r#   r   r)   r   r   r   r   rr     s     "rr   )#r(   r   r   Ztorch.nnr   Ztorch.nn.functionalZ
functionalr?   Zmodelscope.metainfor   Zmodelscope.modelsr   r   Zmodelscope.utils.constantr   r   Zmossformer_blockr   r	   ZEPSModuler   rQ   r*   r3   r4   ZConvTranspose1dr@   rE   rV   r^   Zregister_moduleZspeech_separationZ)speech_mossformer2_separation_temporal_8krr   r   r   r   r   <module>   s.   
G-8/AK~