o
    *j                     @   s8  d Z ddlmZmZmZ ddlZddlZddlZddlZddl	Z	ddl
Z
ddlZddlZddlmZ ddlmZ ddlmZ ddlmZ e ZejZejZdd	 Zd
d ZeejjjedZG dd dejZ G dd dejZ!G dd dejZ"G dd dejZ#G dd dejZ$G dd dejZ%G dd dejZ&G dd dejZ'G dd dejZ(G dd  d ejZ)G d!d" d"ejZ*G d#d$ d$ejZ+G d%d& d&ejZ,G d'd( d(ejZ-G d)d* d*ejZ.G d+d, d,ejZ/G d-d. d.ejZ0G d/d0 d0ejZ1G d1d2 d2ejZ2G d3d4 d4e2Z3G d5d6 d6ejZ4dS )7zPyTorch BERT model.    )absolute_importdivisionprint_functionN)nn)SpaceTCnConfig)	ModelFile)
get_loggerc                 C   s    | d dt | td   S )zImplementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    g      ?      ?g       @)torcherfmathsqrtx r   j/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/modelscope/models/nlp/space_T_cn/backbone.pygelu(   s    r   c                 C   s   | t |  S N)r
   Zsigmoidr   r   r   r   swish0   s   r   )r   relur   c                       s&   e Zd Zd fdd	Zdd Z  ZS )BertLayerNorm-q=c                    s<   t t|   tt|| _tt|| _	|| _
dS )zWConstruct a layernorm module in the TF style (epsilon inside the square root).
        N)superr   __init__r   	Parameterr
   Zonesweightzerosbiasvariance_epsilon)selfhidden_sizeeps	__class__r   r   r   9   s   
zBertLayerNorm.__init__c                 C   sN   |j ddd}|| dj ddd}|| t|| j  }| j| | j S )NT)Zkeepdim   )meanpowr
   r   r   r   r   )r   r   usr   r   r   forwardA   s   zBertLayerNorm.forward)r   __name__
__module____qualname__r   r*   __classcell__r   r   r"   r   r   7   s    r   c                       sF   e Zd ZdZ fddZ														dddZ  ZS )BertEmbeddingszLConstruct the embeddings from word, position and token_type embeddings.
    c                    s   t t|   t|j|j| _t|j|j| _	t|j
|j| _td|j| _td|j| _t|jdd| _t|j| _d S )N      r   r!   )r   r0   r   r   	EmbeddingZ
vocab_sizer    word_embeddingsZmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddingsmatch_type_embeddingstype_embeddingsr   	LayerNormDropouthidden_dropout_probdropoutr   configr"   r   r   r   L   s   zBertEmbeddings.__init__Nc              	   C   sr  | d}tj|tj|jd}|d|}|d u r t|}| |}| |}|d ur|d urt	
||	    }t	j
|td|	    }t|D ]0\}}| D ]'\}}|| | }|dkriqZtj|||d |d d f dd|||d d f< qZqR| |}| |}|| | }|d ur| |}||7 }|d ur| |}||7 }| |}| |}|S )N   )dtypedevicer   rA   dim)sizer
   ZarangelongrB   	unsqueezeZ	expand_as
zeros_liker5   nparraycpunumpytolistobject	enumerateitemsr&   r6   r7   r8   r9   r:   r=   )r   	input_ids
header_idstoken_type_idsmatch_type_idsl_hs
header_lenZtype_idxcol_dict_listidsheader_flatten_tokensheader_flatten_indexheader_flatten_outputtoken_column_idtoken_column_maskcolumn_start_indexheaders_lengthZ
seq_lengthZposition_idsZwords_embeddingsZheader_embeddingsZbiZcol_dictZkivilengthr6   r7   
embeddingsr8   r9   r   r   r   r*   \   sN   











zBertEmbeddings.forward)NNNNNNNNNNNNNNr,   r-   r.   __doc__r   r*   r/   r   r   r"   r   r0   H   s$    r0   c                       s.   e Zd Z fddZdd ZdddZ  ZS )	BertSelfAttentionc                    s   t t|   |j|j dkrtd|j|jf |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _d S )Nr   LThe hidden size (%d) is not a multiple of the number of attention heads (%d))r   rf   r   r    num_attention_heads
ValueErrorintattention_head_sizeall_head_sizer   Linearquerykeyvaluer;   attention_probs_dropout_probr=   r>   r"   r   r   r      s    
zBertSelfAttention.__init__c                 C   6   |  d d | j| jf }|j| }|ddddS Nr$   r   r%   r@      rF   rh   rk   viewpermuter   r   Znew_x_shaper   r   r   transpose_for_scores   
   
z&BertSelfAttention.transpose_for_scoresNc                 C   s   |  |}| |}| |}| |}| |}| |}	t||dd}
|
t| j	 }
|
| }
t
jdd|
}| |}t||	}|dddd }| d d | jf }|j| }|S )Nr$   rD   r   r%   r@   rt   )rn   ro   rp   ry   r
   matmul	transposer   r   rk   r   Softmaxr=   rw   
contiguousrF   rl   rv   )r   hidden_statesattention_maskschema_link_matrixmixed_query_layermixed_key_layermixed_value_layerquery_layer	key_layervalue_layerattention_scoresattention_probscontext_layernew_context_layer_shaper   r   r   r*      s,   








zBertSelfAttention.forwardr   r,   r-   r.   r   ry   r*   r/   r   r   r"   r   rf      s    rf   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )!BertSelfAttentionWithRelationsRATzd
    Adapted from https://github.com/microsoft/rat-sql/blob/master/ratsql/models/transformer.py
    c                    s   t t|   |j|j dkrtd|j|jf |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	d|j|j | _t	d|j|j | _d S Nr   rg      )r   r   r   r    rh   ri   rj   rk   rl   r   rm   rn   ro   rp   r;   rq   r=   r4   relation_k_embrelation_v_embr>   r"   r   r   r      s,   

z*BertSelfAttentionWithRelationsRAT.__init__c                 C   rr   rs   ru   rx   r   r   r   ry      rz   z6BertSelfAttentionWithRelationsRAT.transpose_for_scoresc                 C   sL  |  |}| |}| |}| |}| |}| |}	| |}
| |}t|	|
dd}|dd}|		dddd}t||}|	dddd}|| t
| j }|| }tjdd|}| |}t||}|	dddd}t||}|	dddd}|| }|	dddd }| dd | jf }|j| }|S )	7
        relation is [batch, seq len, seq len]
        r$   r{   r   r%   r@   rt   rD   N)rn   ro   rp   r   r   ry   r
   r|   r}   rw   r   r   rk   r   r~   r=   r   rF   rl   rv   )r   r   r   relationr   r   r   Z
relation_kZ
relation_vr   r   r   r   Zrelation_k_tZquery_layer_tZrelation_attention_scoresZrelation_attention_scores_tmerged_attention_scoresr   r   Zattention_probs_tZcontext_relationZcontext_relation_tZmerged_context_layerr   r   r   r   r*      sp   




z)BertSelfAttentionWithRelationsRAT.forward)r,   r-   r.   re   r   ry   r*   r/   r   r   r"   r   r      s
    r   c                       ,   e Zd Z fddZdd Zdd Z  ZS ))BertSelfAttentionWithRelationsTableformerc                    s   t t|   |j|j dkrtd|j|jf |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	d| j| _d S r   )r   r   r   r    rh   ri   rj   rk   rl   r   rm   rn   ro   rp   r;   rq   r=   r4   schema_link_embeddingsr>   r"   r   r   r   ;  s"   
z2BertSelfAttentionWithRelationsTableformer.__init__c                 C   rr   rs   ru   rx   r   r   r   ry   N  rz   z>BertSelfAttentionWithRelationsTableformer.transpose_for_scoresc                 C   s   |  |}| |}| |}| |}|dddd}| |}| |}	| |}
t||	dd}|t	
| j }|| }|| }tjdd|}| |}t||
}|dddd }| dd | jf }|j| }|S )	r   r   rt   r@   r%   r$   r{   rD   N)rn   ro   rp   r   rw   ry   r
   r|   r}   r   r   rk   r   r~   r=   r   rF   rl   rv   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r*   T  sF   




z1BertSelfAttentionWithRelationsTableformer.forwardr   r   r   r"   r   r   9  s    r   c                       $   e Zd Z fddZdd Z  ZS )BertSelfOutputc                    sB   t t|   t|j|j| _t|jdd| _t	|j
| _d S Nr   r3   )r   r   r   r   rm   r    denser   r:   r;   r<   r=   r>   r"   r   r   r        zBertSelfOutput.__init__c                 C   &   |  |}| |}| || }|S r   r   r=   r:   r   r   input_tensorr   r   r   r*        

zBertSelfOutput.forwardr+   r   r   r"   r   r         r   c                       (   e Zd Zd fdd	ZdddZ  ZS )	BertAttentionnonec                    sR   t t|   |dkrt|| _|dkrt|| _|dkr"t|| _t|| _d S )Nr   ratadd)	r   r   r   rf   r   r   r   r   outputr   r?   schema_link_moduler"   r   r   r     s   


zBertAttention.__init__Nc                 C   s   |  |||}| ||}|S r   )r   r   )r   r   r   r   Zself_outputattention_outputr   r   r   r*     s
   zBertAttention.forwardr   r   r+   r   r   r"   r   r     s    
r   c                       r   )BertIntermediatec                    sH   t t|   t|j|j| _t|j	t
rt|j	 | _d S |j	| _d S r   )r   r   r   r   rm   r    intermediate_sizer   
isinstance
hidden_actstrACT2FNintermediate_act_fnr>   r"   r   r   r     s   
zBertIntermediate.__init__c                 C   s   |  |}| |}|S r   )r   r   r   r   r   r   r   r*     s   

zBertIntermediate.forwardr+   r   r   r"   r   r     r   r   c                       r   )
BertOutputc                    sB   t t|   t|j|j| _t|jdd| _	t
|j| _d S r   )r   r   r   r   rm   r   r    r   r   r:   r;   r<   r=   r>   r"   r   r   r     r   zBertOutput.__init__c                 C   r   r   r   r   r   r   r   r*     r   zBertOutput.forwardr+   r   r   r"   r   r     r   r   c                       r   )		BertLayerr   c                    s4   t t|   t||d| _t|| _t|| _d S Nr   )	r   r   r   r   	attentionr   intermediater   r   r   r"   r   r   r     s   
zBertLayer.__init__Nc                 C   s(   |  |||}| |}| ||}|S r   )r   r   r   )r   r   r   r   r   Zintermediate_outputZlayer_outputr   r   r   r*     s   
zBertLayer.forwardr   r   r+   r   r   r"   r   r     s    r   c                       s(   e Zd Z fddZ	dddZ  ZS )SqlBertEncoderc                    s8   t t|   t| t fddt|D | _d S )Nc                       g | ]}t  qS r   copydeepcopy.0_layerr   r   
<listcomp>      z+SqlBertEncoder.__init__.<locals>.<listcomp>)r   r   r   r   r   
ModuleListranger   )r   Zlayersr?   r"   r   r   r     s
   
zSqlBertEncoder.__init__Tc                 C   s:   g }| j D ]}|||}|r|| q|s|| |S r   r   append)r   r   r   output_all_encoded_layersall_encoder_layerslayer_moduler   r   r   r*     s   



zSqlBertEncoder.forward)Tr+   r   r   r"   r   r     s    	r   c                       s.   e Zd Zd fdd	Z			d	ddZ  ZS )
BertEncoderr   c                    s>   t t|   t||d t fddt|jD | _d S )Nr   c                    r   r   r   r   r   r   r   r     r   z(BertEncoder.__init__.<locals>.<listcomp>)	r   r   r   r   r   r   r   Znum_hidden_layersr   r   r"   r   r   r     s
   
zBertEncoder.__init__NTc                 C   s<   g }| j D ]}||||}|r|| q|s|| |S r   r   )r   r   r   all_schema_link_matrixall_schema_link_maskr   r   r   r   r   r   r*     s   


zBertEncoder.forwardr   )NNTr+   r   r   r"   r   r     s    	r   c                       r   )
BertPoolerc                    s.   t t|   t|j|j| _t | _d S r   )	r   r   r   r   rm   r    r   ZTanh
activationr>   r"   r   r   r     s   zBertPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r   r   Zfirst_token_tensorpooled_outputr   r   r   r*   	  s   

zBertPooler.forwardr+   r   r   r"   r   r         r   c                       r   )BertPredictionHeadTransformc                    sR   t t|   t|j|j| _t|jt	rt
|j n|j| _t|jdd| _d S r   )r   r   r   r   rm   r    r   r   r   r   r   transform_act_fnr   r:   r>   r"   r   r   r     s   
z$BertPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r   r:   r   r   r   r   r*     s   


z#BertPredictionHeadTransform.forwardr+   r   r   r"   r   r     s    r   c                       r   )BertLMPredictionHeadc                    sZ   t t|   t|| _tj|d|ddd| _|| j_	t
t|d| _d S )Nr@   r   F)r   )r   r   r   r   	transformr   rm   rF   decoderr   r   r
   r   r   r   r?   Zbert_model_embedding_weightsr"   r   r   r   $  s   

zBertLMPredictionHead.__init__c                 C   s   |  |}| || j }|S r   )r   r   r   r   r   r   r   r*   2  s   
zBertLMPredictionHead.forwardr+   r   r   r"   r   r   "  s    r   c                       r   )BertOnlyMLMHeadc                    s   t t|   t||| _d S r   )r   r   r   r   predictionsr   r"   r   r   r   :  s   
zBertOnlyMLMHead.__init__c                 C      |  |}|S r   )r   )r   sequence_outputprediction_scoresr   r   r   r*   ?     
zBertOnlyMLMHead.forwardr+   r   r   r"   r   r   8  r   r   c                       r   )BertOnlyNSPHeadc                    s"   t t|   t|jd| _d S Nr%   )r   r   r   r   rm   r    seq_relationshipr>   r"   r   r   r   F  s   zBertOnlyNSPHead.__init__c                 C   r   r   )r   )r   r   seq_relationship_scorer   r   r   r*   J  r   zBertOnlyNSPHead.forwardr+   r   r   r"   r   r   D  s    r   c                       r   )BertPreTrainingHeadsc                    s.   t t|   t||| _t|jd| _d S r   )	r   r   r   r   r   r   rm   r    r   r   r"   r   r   r   Q  s
   zBertPreTrainingHeads.__init__c                 C   s   |  |}| |}||fS r   )r   r   )r   r   r   r   r   r   r   r   r*   W  s   

zBertPreTrainingHeads.forwardr+   r   r   r"   r   r   O  r   r   c                       s:   e Zd ZdZ fddZdd Ze		d	ddZ  ZS )
PreTrainedBertModelz An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    c                    s:   t t|   t|tstd| jj| jj|| _	d S )NzParameter config in `{}(config)` should be an instance of class `SpaceTCnConfig`. To create a model from a Google pretrained model use `model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`)
r   r   r   r   r   ri   formatr#   r,   r?   )r   r?   inputskwargsr"   r   r   r   b  s   

zPreTrainedBertModel.__init__c                 C   s|   t |tjtjfr|jjjd| jjd nt |t	r'|j
j  |jjd t |tjr:|j
dur<|j
j  dS dS dS )z! Initialize the weights.
        g        )r&   Zstdr	   N)r   r   rm   r4   r   dataZnormal_r?   Zinitializer_ranger   r   Zzero_Zfill_)r   moduler   r   r   init_bert_weightsl  s   
z%PreTrainedBertModel.init_bert_weightsNc                    sZ  |}d}t j|r|}n*t }td|| t	|d}	|	
| W d   n1 s0w   Y  |}t j|t}
t|
}td| | |g|R i |}du rft j|t}t|g }g } D ]$}d}d|v r||dd}d|v r|dd}|r|| || qnt||D ]\}}||< qg g g  td	d dur_d fdd	|t|drd
ndd tdkrtd|jj t  tddd td|jj t  tdkr#td|jj t  tddd td|jj t  |r+t | |S )a  
        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object)
                to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        Nz)extracting archive file {} to temp dir {}zr:gzzModel config {}gammar   betar   	_metadata c              	      sh   d u ri n	 |d d i }| ||d  | j D ]\}}|d ur1||| d  q d S )Nr$   T.)getZ_load_from_state_dictZ_modulesrQ   )r   prefixZlocal_metadatanamechild
error_msgsloadmetadataZmissing_keys
state_dictZunexpected_keysr   r   r     s   
z1PreTrainedBertModel.from_pretrained.<locals>.loadZbertzbert.)r   r   z7Weights of {} not initialized from pretrained model: {}z
**********zWARNING missing weightsz0Weights from pretrained model not used in {}: {}zWARNING unexpected weights)r   )!ospathisdirtempfilemkdtemploggerinfor   tarfileopen
extractalljoinCONFIG_NAMEr   Zfrom_json_fileWEIGHTS_NAMEr
   r   keysreplacer   zippopgetattrr   r   hasattrlenr#   r,   printshutilrmtree)clsZpretrained_model_namer   	cache_dirr   r   Zresolved_archive_filetempdirZserialization_dirarchiveconfig_filer?   modelZweights_pathZold_keysZnew_keysro   Znew_keyold_keyr   r   r   from_pretrainedz  s   





z#PreTrainedBertModel.from_pretrained)NN)	r,   r-   r.   re   r   r   classmethodr  r/   r   r   r"   r   r   ]  s    
r   c                       sR   e Zd ZdZd	 fdd	Z																			d
ddZ  ZS )SpaceTCnModelaK  SpaceTCnModel model ("Bidirectional Embedding Representations from a Transformer pretrained on STAR-T-CN").

    Params:
        config: a SpaceTCnConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output
            as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLF`) to train on the Next-Sentence task (see BERT's paper).

    Example:
        >>> # Already been converted into WordPiece token ids
        >>> input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
        >>> input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
        >>> token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

        >>> config = modeling.SpaceTCnConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        >>>     num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

        >>> model = modeling.SpaceTCnModel(config=config)
        >>> all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    r   c                    sB   t t| | t|| _t||d| _t|| _| 	| j
 d S r   )r   r  r   r0   rc   r   encoderr   poolerapplyr   r   r"   r   r   r     s   

zSpaceTCnModel.__init__NTc                 C   s   |d u r	t |}|d u rt |}|dd}d| d }| |||||||	|
||||||||}| j|||||d}|d }| |}|sM|d }||fS )Nr@   r%   r	   g     )r   r   r   r$   )r
   Z	ones_likerI   rH   rc   r  r  )r   rR   rS   Ztoken_order_idsrT   r   rU   rV   rW   Ztype_idsrX   rY   rZ   r[   r\   r]   r^   r_   r`   r   r   r   Zextended_attention_maskZembedding_outputZencoded_layersr   r   r   r   r   r*     s0   




zSpaceTCnModel.forwardr   )NNNNNNNNNNNNNNNNNNTrd   r   r   r"   r   r    s.    +r  c                       r   )Seq2SQLc                    s   t t|   || _|| _|| _|| _|
| _|| _|| _	|| _
|| _|	| _t||	| _t||	| _t||| _t||	| _t|| || | _t||	 ||	 | _t|d| _t||d | _t||d | _t||	d | _d S )Nrt   r@   )r   r   r   iShSZlsdrrB   	n_agg_ops
n_cond_opsn_action_opsmax_select_nummax_where_numr   rm   w_sss_modelw_sse_model
s_ht_modelwc_ht_modelselect_agg_model
w_op_model
conn_modelaction_model
slen_model
wlen_model)r   r!  r"  ZlSr#  r%  r$  r&  r'  r(  rB   r"   r   r   r   ^  s2   

zSeq2SQL.__init__c                 C   s
   || _ d S r   )rB   )r   rB   r   r   r   
set_device|  s   
zSeq2SQL.set_devicec           *   	      s"  t |}t |}	t||    }t||    }t||    }t |     g }
g }g }g }g }g }g }g }t|D ]\|d  |
d  |d  fddtj	D }|d j	  fddtj
D }|| || dd t| d D | d g||    }||  fd	dtd
| dD fddt|	t D 7 | qPtj|
tjdj}
tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}tj|tjdj}t|}t|jgj}t|jgj}t|jgj}t|jgj}t|j	jgj}t|j
jgj}t||d jgj}t||	jgj}t|D ]}|| d
|
| ||d d f< || d
|| ||d d f< || d
|| ||d d f< || d
|| ||d d f< || d
||d d f ||d d d d f< || d
||d d f ||d d d d f< || d
||d d f ||d d d d f< || d
||d d f ||d d d d f< q|dj|d} |dj|j
d }!|dj|j	d }"|dj|jd }#|djj	 |dj}$|dj|dj	dd}%|dj|dj	dd}&|dj|dj	dd}' |dj|dj
dd}(!|djj
 |dj"})|#|(|)| |%|$|&|'f|!|"ffS )Nr@   r%   rt   c                    s   g | ]} d  | qS )   r   r   i)elemr   r   r         z#Seq2SQL.forward.<locals>.<listcomp>r4  c                    s   g | ]} d  j  | qS )   )r(  r5  )r7  r   r   r   r     s    c                 S   s   g | ]}|qS r   r   r5  r   r   r   r     s    c                    s   g | ]}  | qS r   r   r5  )column_indexibr   r   r     r8  r   c                    s   g | ]} d  qS )r   r   r   )indexr   r   r     s    rC   r$   )#maxrJ   rK   rL   rM   rN   rP   r   r   r(  r'  r  r
   ZtensorrG   torB   r   r!  Zindex_selectr/  Zreshaper1  r2  r0  r&  r.  r%  r,  r}   r)  r*  r+  r-  r$  )*r   Z
wemb_layerZl_nrV   start_indexr:  tokensrY   Zmax_l_nZmax_l_hsZ
conn_indexZ
slen_indexZ
wlen_indexZaction_indexZwhere_op_indexZselect_agg_indexZheader_pos_indexZquery_indexZwoiZsaiZqilistZheader_indexZbSZconn_embZslen_embZwlen_embZ
action_embZwo_embZsa_embZqv_embZht_embr6  Zs_ccoZs_slenZs_wlenZs_actionZ	wo_outputZ	wc_outputZwv_ssZwv_seZ	sc_outputZ	sa_outputr   )r:  r7  r;  r<  r   r   r*     s  

$

 "
    .2


zSeq2SQL.forward)r,   r-   r.   r   r3  r*   r/   r   r   r"   r   r   \  s    r   )5re   
__future__r   r   r   r   r   r   r  r  r   rM   rJ   r
   r   Z.modelscope.models.nlp.space_T_cn.configurationr   Zmodelscope.utils.constantr   Zmodelscope.utils.loggerr   r  ZCONFIGURATIONr  ZTORCH_MODEL_BIN_FILEr  r   r   Z
functionalr   r   Moduler   r0   rf   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r   r   r   r   r   <module>   sT   L:kJ s