o
    *j                     @   s   d dl Z d dlmZ d dlZd dlZd dlmZ d dlm	Z	 d dl
mZmZ d dlmZ dd Zd	d
 Zdd Zdd ZG dd deZ	 G dd deZG dd dZ	 G dd deZG dd deZdS )    N)Counter)ngrams)f1_score)ontologyutils)clean_slot_valuesc                 C   sH   | |kp#| |v p#|| v p#|   d |  d kp#|   d |  d kS )Nr   splitab r   c/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/modelscope/trainers/nlp/space/eval.pysimilar   s
   (r   c           	      C   sr   g }g d}| D ]}d}|D ]	}t ||rd}q|s|| q|D ]}d}|D ]}||v r0d}q(|s6 dS q"dS )N)Ztemperatureweekzest ZquickZreminderZnearFT)r   append)	r   r   Zjunks_aZuseless_constraintiZflgjjunkitemr   r   r   setsub!   s*   

r   c                 C   s&   t | t |} }t| |ot|| S N)setr   r   r   r   r   setsim7   s   r   c                 C   sB   t | } t |}i }dD ]}t|| |d}||d|< q|S )N)micro)Zy_trueZy_predZaveragezf1_{})nparrayr   format)predslabelsresultsZavg_nameZmy_f1_scorer   r   r   DA_evaluate<   s   

r"   c                   @   s   e Zd Zdd Zdd ZdS )
BLEUScorerc                 C   s   d S r   r   )selfr   r   r   __init__L   s   zBLEUScorer.__init__c              
      s  g dg d d}d}g d}|D ]\}}dd |D }dd |D }|D ]}t dD ]T}tt||d }	t|	 }
|  |
7  < i |D ]}tt||d }|	D ]}t|d|| |< qVqItfd	d
|	 D } |  t| 7  < q,ddg}|D ]$}|d dkr nt	t
|t
| }||d k r||d< t
||d< q||d 7 }|t
|7 }q&qd||krdntdt|t|  } fddt dD }tdd
 t||D }|t| }|d S )Nr   r   r   r   r   )      ?r'   r'   r'   c                 S      g | ]}|  qS r   r	   ).0hypr   r   r   
<listcomp>Z       z$BLEUScorer.score.<locals>.<listcomp>c                 S   r(   r   r	   )r)   refr   r   r   r+   [   r,         c                 3   s&    | ]\}}|t | | fV  qd S r   )min)r)   ngcount)
max_countsr   r   	<genexpr>l      $ z#BLEUScorer.score.<locals>.<genexpr>i  gHz>c                    s,   g | ]}t  | t |    qS r   )float)r)   r   )
clip_countr2   p0r   r   r+      s   , c                 s   s&    | ]\}}|r|t | V  qd S r   )mathlog)r)   wZp_nr   r   r   r4      r5   d   )ranger   r   sumvaluesmaxgetdictitemsabslenr9   expr6   fsumzip)r$   Zparallel_corpusrcweightsZhypsrefsr*   r   ZhypcntsZcntr-   Zrefcntsr1   ZclipcntZ	bestmatchdiffbpZp_nssbleur   )r7   r2   r3   r8   r   scoreO   sZ   
 $zBLEUScorer.scoreN)__name__
__module____qualname__r%   rQ   r   r   r   r   r#   H   s    r#   c                   @   sX   e Zd Zdd Zdd ZdddZddd	Z		
	dddZ	
	
	dddZdd Z	dS )MultiWOZEvaluatorc                 K   sz   || _ tj| _| j j| _| j j| _t | _	g | _
tj D ]\}}|D ]}| j
|d |  q#qg d| _|d | _d S )N-phoneaddressZpostcode	referenceiddata_dir)readerr   all_domainsdomainsdataall_datatestZ	test_datar#   bleu_scorerZall_info_slotinformable_slotsrC   r   requestablesdb_dir)r$   r]   kwargsdZs_listrO   r   r   r   r%      s   


zMultiWOZEvaluator.__init__c                 C   8   i }|D ]}|d }||vrg ||< ||  | q|S Ndial_idr   r$   r`   dialsturnrk   r   r   r   	pack_dial      zMultiWOZEvaluator.pack_dialNc                 C   s,   |  |}| j|d|d\}}}}|||fS )NTsame_eval_as_cambridgefout)bleu_metriccontext_to_response_eval)r$   r`   rt   rP   successmatchZreq_offer_countsdial_numr   r   r   validation_metric   s   


z#MultiWOZEvaluator.validation_metricc           	      C   s   g g }}|D ]}|r|d d |vrq| |d  | |d  qdd |D }dd |D }|rN|rNz| jt||}W |S  tyM   d}Y |S w d}|S )	Nrk   .jsonresp_genrespc                 S      g | ]}|gqS r   r   r)   _r   r   r   r+          z1MultiWOZEvaluator.bleu_metric.<locals>.<listcomp>c                 S   r~   r   r   r   r   r   r   r+      r           )r   rc   rQ   rH   	Exception)	r$   r`   eval_dial_listgentruthrowwrap_generated
wrap_truthscr   r   r   ru      s0   
zMultiWOZEvaluator.bleu_metricFc              	   C   sR  |  |}i }| jD ]}d||d < d||d < q
d\}}	}
|D ]n}|r+|d |vr+q || }i }i }d|vrFdt| j d v rF|d }tjD ]}| j| d |rc| j| d }| |||}qI| D ]
}|| d ||< qh| j	||||||d\}}}}|	|7 }	|
|7 }
|d	7 }q |	t
|d
  d }|
t
|d
  d }||||fS )Nr   _total_offerr   r   r   r{   goalrequestablerr   r/   g|=r<   )rp   re   listra   keysr   r^   rA   
_parseGoal_evaluateGeneratedDialoguer6   )r$   r`   r   rs   rt   rn   countsreqry   Z	successesmatchesrk   dialreqsr   domain	true_goalrw   rx   statsZ	succ_rateZ
match_rater   r   r   rv      sF   





z*MultiWOZEvaluator.context_to_response_evalc           #      C   sf  | j }i }	i }
g }g }i }| D ]}g |
|< g |	|< || qt|D ]\}}|dkr/q%|durQ||d |d |d |d |d |d |d	 |d
 d |d	 }| D ]}|r}| jjrmdd |d  D }ndd |d  D }||vr}qYd|v sd|v r|dv r| jjs| jjs|d }n|d }| j	|}|
|r| jjj||| dd}ng }t|
| dkr|r||
|< || ||< n$d}|D ]}||
| vrd} nq|r|r||
|< || ||< nd|
|< |D ]F}|dkr%d|v r$|dv rd|d
 v sd|d
 v sd|d v r|	| d q|	| d qd| d |v r5|	| | qqYq%| D ]0}d|| d  v rKd|
|< |d!v rTd|
|< |d"krk|
| skd#|| d$ vrkd|
|< q<	 g d%g d%g d%g d%g d%g d%g d%d&}d}d}| D ]g}d}|dv r| jjj||| d  dd}t|
| tu rd'|
| v r|d(7 }d(}n,t|
| dkrtt|
| t|@ dkr|d(7 }d(}nd)|
| v r|d(7 }d(}||| d< d(|| d*< q|rt|t|  }n|t| krd+}nd,}|D ]&}|| D ]}||d-   d(7  < ||	| v r7||d.   d(7  < qq|dur|D ]F}d}d} t|| dkr_|d(7 }d(}||| d(< qB|| D ]}||	| v rp| d(7 } qc| t|| kr|d(7 }d(}||| d(< qB|rt|t| }ns|t|krd(}nid}nf|d+kr|D ]F}d}d} t|| dkr|d(7 }d(}||| d(< q|| D ]}||	| v r| d(7 } q| t|| kr|d(7 }d(}||| d(< q|rt|t| }n|t|krd(}nd}|dur-|dkr-|d d/ |||	d0i}!t|!}"||" |d1 ||||fS )2a<  Evaluates the dialogue created by the model.
            First we load the user goal of the dialogue, then for each turn
            generated by the system we look for key-words.
            For the Inform rate we look whether the entity was proposed.
            For the Success rate we look for requestables slotsr   Nturn_numZdspnuseraspnaspn_genr}   r|   pointer)r   Zturn_domainr   r   r   r}   r|   r   c                 S      g | ]}|d d qS r/   r   r   r)   rh   r   r   r   r+         z@MultiWOZEvaluator._evaluateGeneratedDialogue.<locals>.<listcomp>c                 S   r   r   r   r   r   r   r   r+   !  r   Zdspn_gen[value_name]z
[value_id])
restauranthotel
attractiontrainbspn_genbspnT)Zreturn_nameFrZ   z[value_reference])r   r   r   Zbookedok[value_]name
informable)taxipolicehospitalr   r[   r   r   )r   r   r   r   r   r   r   _namer/   z_name]   g      ?r   r   r   rk   )r:   real_requestablesprovided_requestables
)re   r   r   	enumerater]   Zuse_true_domain_for_ctr_evalr
   Zuse_true_curr_bspnZuse_true_bspn_for_ctr_evalZbspan_to_constraint_dictrA   dbZ
queryJsonsrE   typestrr   r6   jsondumpswrite)#r$   Zdialogr   r   r   Zsoft_accrs   rt   re   r   Zvenue_offeredZdomains_in_goalr:   Zbspansr   tro   Zsent_tZdom_predr   Zconstraint_dictZvenuesflagZvenr   r   rx   rw   Z
match_statZgoal_venuesrequestZsuccess_statZdomain_successsampleliner   r   r   r      s  




B

	








z,MultiWOZEvaluator._evaluateGeneratedDialoguec                 C   sb  i ||< i g g d||< d|| v r|dkr=d|| v r%|| d  d d|| v r<d|| d v r<|| d  d n+d|| v rY|| d D ]}|d	v rX|| d  | qId|| v rh|| d  d || d  D ].\}}t| j|||\}}t| d
krddd | j|D 	 }||| d |< qpd|| v r|| d || d< |S )z(Parses user goal into dictionary format.)r   r   bookinginfor   bookr   rZ   Zreqtr[   rW   r/    c                 S   s   g | ]}|j qS r   )text)r)   tokenr   r   r   r+     r   z0MultiWOZEvaluator._parseGoal.<locals>.<listcomp>r   r   )
r   rC   r   rf   rE   r
   joinr]   Znlpstrip)r$   r   r   r   rO   vZs_Zv_r   r   r   r     s:   zMultiWOZEvaluator._parseGoalr   )NFN)FFN)
rR   rS   rT   r%   rp   rz   ru   rv   r   r   r   r   r   r   rU      s    
	

/
 ~rU   c                   @   s`   e Zd Zdd Zdd Zdd Zddd	Z	
	dddZdddZdddZ	dd Z
dd ZdS )GenericEvaluatorc                 C   s   || _ i | _d S r   )r]   Zmetric_dictr$   r]   r   r   r   r%     s   
zGenericEvaluator.__init__c                 C   ri   rj   rl   rm   r   r   r   rp     rq   zGenericEvaluator.pack_dialc                 C   s   t d)Nz"Please specify the evaluator first)
ValueError)r$   r!   r   r   r   run_metrics  s   zGenericEvaluator.run_metricsrP   c           	      C   sn   g g }}|D ]}| | |d  | | |d  qdd |D }dd |D }t t||}|S )Nr|   r}   c                 S   r~   r   r   r   r   r   r   r+   &  r   z0GenericEvaluator.bleu_metric.<locals>.<listcomp>c                 S   r~   r   r   r   r   r   r   r+   '  r   )r   cleanr#   rQ   rH   )	r$   r`   r   r   r   r   r   r   r   r   r   r   ru      s   
zGenericEvaluator.bleu_metricFTc                 C   s\   i }| j D ]}d||< q| D ]\}}|r|dkrq|r'|dkr'|| jvr'q|||< q|S )  
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
        :param intersection: if true, only keeps the words that appear in th ontology
                                        we set intersection=True as in previous works
        :returns: normalized constraint dict
                      e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
         Zdontcare)rd   rC   entities_flat)r$   
constraintignore_dontcareintersection
normalizedrO   r   r   r   r   _normalize_constraint+  s   


z&GenericEvaluator._normalize_constraintc           
      C   s   | d}i }t|D ]>\}}| }t }|  D ]&}	|r;| jj| dkr0d|	v r/||	 q|	| jv r:||	 q||	 q||| jj| < q|S )N|av[value)r
   r   r   r   r]   Z	act_orderaddrequestable_slots)
r$   r   r   Z	aspn_listr   r   r   seqZword_setr;   r   r   r   _normalize_actD  s"   



zGenericEvaluator._normalize_actc                 C   s  d\}}}}di d}}}	| j D ]}
d||
< q|D ]}|r-| |d }| |d }n| j|d dd}| j|d dd}d|d	 voJd
|d	 v}|rx| D ]\}}||| v r`|d7 }qQ|d7 }qQ| D ]\}}||| vrw|d7 }qi|r|r|	d7 }	| j D ]}
||
 ||
 kr||
  d7  < q||kr|d7 }|dr|dr|d |d kr|d7 }q||| d  ||| d  }}d| | || d  }||	 }||	 }|D ]
}
||
  |	  < q||||||fS )Nr&   r   :0yE>r   r   F)r   thankr   byer/   Zdb_genZdb_matchr   )rd   r   rC   rA   )r$   r`   	normalizetpfpfnZ
db_correctZ	goal_accrZ	slot_accrtotalrO   r   r   r   validslotvalue	precisionrecallf1r   r   r   tracker_metricW  sV   




"zGenericEvaluator.tracker_metricc                 C   sl  |  |}d\}}}|D ]}t t }}|| }	t|	D ]V\}
}| |d  }| |d  }|D ]}d|v rR|drR|dkrR||dd d	d  q5|D ]}d|v rr|drr|dkrr||dd d	d  qUq|D ]}||v r|d7 }qv|d7 }qv|D ]
}||vr|d7 }qq||| d
  ||| d
  }}d| | || d
  }|||fS )Nr   r|   r}   r   r   r   r/   r   r   r   r   )rp   r   r   r   r
   endswithr   )r$   r`   rn   r   r   r   rk   Z	truth_reqZgen_reqr   r   ro   Zresp_gen_tokenZ
resp_tokenr;   r   r   r   r   r   r   r   request_metric  sF   



"
zGenericEvaluator.request_metricc                 C   s  ddddddddd}}}| j D ]}d\||< ||< ||< d\|d| < |d| < |d| < q|D ]}| |d }| |d }d|d voPd	|d v}	|	r|d
 D ]3}
|
|d
 v rv|d  d7  < ||
ru||
  d7  < qX|d  d7  < ||
r||
  d7  < qX|d
 D ]}
|
|d
 vr|d  d7  < ||
r||
  d7  < qd|vrq5|d D ]3}||d v r|d  d7  < ||r||  d7  < q|d  d7  < ||r||  d7  < q|d D ]}||d vr|d  d7  < ||r||  d7  < qq5i }| D ]5\}}|| || ||  d  || || ||  d  }}d| | || d  }|||g||< q|S )Nr   )all_sall_vr   
[value_%s]r   r   r   r   r   r   r   r/   asr   r   r   )r   r   rA   rC   )r$   r`   r   r   r   rO   r   r   r   r   r   r   resultkr   r   r   r   r   r   r   
act_metric  s~   







 zGenericEvaluator.act_metricN)rP   FT)F)T)rR   rS   rT   r%   rp   r   ru   r   r   r   r   r   r   r   r   r   r     s    	



-r   c                       sD   e Zd Z fddZdd Zdd Zdd Zd	d
 Zdd Z  Z	S )CamRestEvaluatorc                    >   t  | | | jj\| _| _| jjj| _| jjj	| _	d S r   
superr%   get_entitiesr]   Zontology_pathr   entitiy_to_slot_dictZotlgrd   r   r   	__class__r   r   r%        zCamRestEvaluator.__init__c                 C   s   i }|  |}| |\}}}}}}	| |}
| |\}}}||d< |
|d< ||d< ||d< ||d< |||f|d< |	|d< |S )NrP   rx   req_f1
joint_goal	slot_accuslot-p/r/f1db_accru   r   match_metricr   r$   r!   ZmetricsrP   prI   r   Zgoal_accZslot_accr  rx   r  Zreq_pZreq_rr   r   r   r     s   

zCamRestEvaluator.run_metricsc                 C   sd   g }i }t t|dd  }|d D ]}||d |  |d | D ]}|||< q&q||fS )Nzutf-8)encodingr   )r   loadsopenreadlowerextend)r$   entity_pathr   r   Zraw_entitiesrO   r   r   r   r   r     s   
zCamRestEvaluator.get_entitiesc                 C   "   |s|sdS |r
|sdS t ||S NTFr   r$   
truth_consgen_consr   r   r   constraint_same  
   
z CamRestEvaluator.constraint_samec                 C   s   |  |}d\}}|D ]Y}|| }ddddd }}t|D ]"\}	}
d|
d v r1| j|
d dd}d|
d	 v r@| j|
d
 dd}q|sN| j|d d dd}t| g dkrd||kr`|d7 }|d7 }q|| S )Nr   r   r   )123r   r|   r   Tr   r}   r   r   )r   r   r   r/   )rp   r   r   r   r?   r$   r`   rn   rx   r   rk   r   r  r  r   ro   r   r   r   r
    s2   
zCamRestEvaluator.match_metricc                 C   r   | | jj dd}| d| jj d}| jj d| d| jj }| j D ]\}}t||d| }q)|S Nr   r   r   replacer]   Zsos_r_tokenZeos_r_tokenr   rC   r   Zclean_replacer$   r}   r   r   r   r   r   r   .  s   zCamRestEvaluator.clean)
rR   rS   rT   r%   r   r   r  r
  r   __classcell__r   r   r  r   r     s    r   c                       sR   e Zd Z fddZdd Z		dddZd	d
 Zdd Zdd Zdd Z	  Z
S )KvretEvaluatorc                    r   r   r   r   r  r   r   r%   <  r  zKvretEvaluator.__init__c                 C   s   i }|  |}| j|dd\}}}}}}	| |}
| |\}}}||d< |
|d< ||d< ||d< ||d< |||f|d< |	|d	< |S )
NT)r   rP   rx   r  r  r  r  r  r	  r  r   r   r   r   C  s   

zKvretEvaluator.run_metricsFTc           	      C   s|   g d}i }| j D ]}d||< q	| D ]'\}}|D ]}d||d }q|r0|| jvr0q|| j v r:|||< q	 q|S )r   )ZgoodZgreatZquickestZshortestZrouter   ZfastestnearestnextZclosestZwayZmileZactivityr   Zappointmentr   r   )rd   rC   r   r%  r
   r   )	r$   r   r   r   r   r   rO   r   r   r   r   r   r   U  s   



z$KvretEvaluator._normalize_constraintc                 C   s4   g }i }| j j}|D ]}||vr|| q
||fS r   )r]   Zentity_dictr   )r$   r  r   r   rO   r   r   r   r   w  s   
zKvretEvaluator.get_entitiesc                 C   r  r  r  r  r   r   r   r    r  zKvretEvaluator.constraint_samec                 C   s  |  |}d\}}|D ]v}|| }ddddddddddddd }}t|D ]"\}	}
d|
d v r9| j|
d dd}d|
d	 v rH| j|
d
 dd}q&|sV| j|d d dd}t| dgd krdd | D }dd | D }| ||r}|d7 }|d7 }q|| S )Nr  r   )r  r  r  45678910Z11r   r|   r   Tr   r}   r   r      c                 S      g | ]}|r|qS r   r   r)   xr   r   r   r+     r,   z/KvretEvaluator.match_metric.<locals>.<listcomp>c                 S   r3  r   r   r4  r   r   r   r+     r,   r/   )rp   r   r   r   r?   r  r!  r   r   r   r
    sP   
zKvretEvaluator.match_metricc                 C   r"  r#  r$  r&  r   r   r   r     s   zKvretEvaluator.cleanr   )rR   rS   rT   r%   r   r   r   r  r
  r   r'  r   r   r  r   r(  :  s    
"
)r(  )r9   collectionsr   r   numpyr   Z	nltk.utilr   Zsklearn.metricsr   Zmodelscope.utils.nlp.spacer   r   Z(modelscope.utils.nlp.space.clean_datasetr   r   r   r   r"   objectr#   rU   r   r   r(  r   r   r   r   <module>   s.   ?    VQ