o
    0jki                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlmZ d dlmZm	Z	 d dl
ZddlmZ ddlmZ ddlmZ ddlmZ d	d
lmZ d	dlmZ ddlmZmZmZ ddlmZ ddlmZm Z m!Z! ddl"m#Z# ddlm$Z$m%Z% G dd deZ&G dd deZ'G dd deZ(dS )    N)Path)ListOptional   )logging)require_genai_client_plugin)TemporaryDeviceChanger)import_paddle   )DocVLMBatchSampler)is_bfloat16_available   )GenAIClientPredictorLocalModelPredictorTransformersPredictor)get_model_paths   )$PADDLEOCR_VL_GENAI_CLIENT_BATCH_SIZEPADDLEOCR_VL_LOCAL_BATCH_SIZEPADDLEOCR_VL_MAX_NEW_TOKENSDocVLMResult)format_doc_vlm_result_dictis_in_groupc                       s   e Zd ZdZ fddZdd Zdd Zdd	 Zd
d Z								d de	e
 dee dee dee dee dee dee dee dee fddZdd Zdd Zdd Zdd Z  ZS )!DocVLMLocalPredictorzBDocVLM predictor for local model inference (Paddle dynamic graph).c                    s   t  j|i | |dd}|dkr|  | j_n|| j_t| jr'd| _nd| _| j	d
i |\| _
| _t| jdrW| jjtkrYtdt| j dt d t| j_d	S d	S d	S )zInitializes DocVLMPredictor.
        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        
batch_sizeZbfloat16Zfloat32PaddleOCR-VLzCurrently, the z) local model only supports batch size of z!. The batch size will be updated.N )super__init__get_determine_batch_sizebatch_samplerr   r   devicedtype_buildinfer	processorr   
model_namer   r   warningrepr)selfargskwargsbs	__class__r   k/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/paddlex/inference/models/doc_vlm/predictor.pyr    1   s$   

zDocVLMLocalPredictor.__init__c                 C   
   t | jS )zBuilds and returns an DocVLMBatchSampler instance.

        Returns:
            DocVLMBatchSampler: An instance of DocVLMBatchSampler.
        r   r)   r,   r   r   r2   _build_batch_samplerO   s   
z)DocVLMLocalPredictor._build_batch_samplerc                 C      t S )zlReturns the result class, DocVLMResult.

        Returns:
            type: The DocVLMResult class.
        r   r5   r   r   r2   _get_result_classW   s   z&DocVLMLocalPredictor._get_result_classc           	      K   s2  ddl m}m}m}m} |  }t| jdrF|ddr!t	
d t| j |j| j| jd}W d   ||fS 1 s=w   Y  ||fS t| jd	r|ddrWt	
d
 t| j9 t| j}d|v ru|j| j| j|jjddd}n|j| j| j|jjd}W d   ||fS W d   ||fS 1 sw   Y  ||fS t| jdr|ddrt	
d
 t| j |j| j| jd}W d   ||fS 1 sw   Y  ||fS t| jdr|ddrt	
d t| j |j| j| jdd}W d   ||fS 1 sw   Y  ||fS td| j d)a  Build the model, and correspounding processor on the configuration.

        Returns:
            model: An instance of Paddle model, could be either a dynamic model or a static model.
            processor: The correspounding processor for the model.
        r   )#PaddleOCRVLForConditionalGenerationPPChart2TableInferencePPDocBee2InferencePPDocBeeInference	PP-DocBeeZuse_hpipFz>The PP-DocBee series does not support `use_hpip=True` for now.)r%   NPP-Chart2TablezCThe PP-Chart2Table series does not support `use_hpip=True` for now.ZsafetensorsT)r%   pad_token_idZuse_safetensorsconvert_from_hf)r%   r?   
PP-DocBee2r   zAThe PaddleOCR-VL series does not support `use_hpip=True` for now.)r%   r@   zModel z is not supported.)Zmodelingr9   r:   r;   r<   build_processorr   r)   r!   warningswarnr   r$   from_pretrained	model_dirr%   r   	tokenizerZeos_token_idNotImplementedError)	r,   r.   r9   r:   r;   r<   r(   modelZ
model_pathr   r   r2   r&   _   s   
33

*
**

		zDocVLMLocalPredictor._buildc                 C   s>   t | jdrt}td| j d| d |S td| j )Nr   zThe batch size of z is determined to be .z#Could not determine batch size for )r   r)   r   r   debugRuntimeError)r,   r   r   r   r2   r"      s   z*DocVLMLocalPredictor._determine_batch_sizeNdatamax_new_tokensskip_special_tokensrepetition_penaltytemperaturetop_p
min_pixels
max_pixels	use_cachec
                 K   s|  t dd |D sJ t|}t| jdr | jj|||d}n&| j|}|dur6tdt| j d |durFtdt| j d | 	|}i }|durV||d	< n
t| jdr`t
|d	< |duritd
 |durrtd |dur{td |	dur|	|d< t| j | jj|fi |}W d   n1 sw   Y  i }|dur||d< | jj|fi |}| ||}|S )a  
        Process a batch of data through the preprocessing, inference, and postprocessing.

        Args:
            data (List[dict]): A batch of input data, must be a dict (e.g. {"image": /path/to/image, "query": some question}).

        Returns:
            dict: A dictionary containing the raw sample information and prediction results for every instance of the batch.
        c                 s       | ]}t |tV  qd S N
isinstancedict.0ir   r   r2   	<genexpr>       z/DocVLMLocalPredictor.process.<locals>.<genexpr>r   )rS   rT   Nz/`min_pixels` is currently not supported by the z model and will be ignored.z/`max_pixels` is currently not supported by the rN   zW`repetition_penalty` is currently not supported by the local model and will be ignored.zP`temperature` is currently not supported by the local model and will be ignored.zJ`top_p` is currently not supported by the local model and will be ignored.rU   rO   )allcopyr   r)   r(   
preprocessrC   rD   r+   _switch_inputs_to_devicer   r   r$   r'   generatepostprocess_format_result_dict)r,   rM   rN   rO   rP   rQ   rR   rS   rT   rU   r.   src_datagenerate_kwargspredsZpostprocess_kwargsZresult_dictr   r   r2   process   s`   


zDocVLMLocalPredictor.processc                 K   s  ddl m}m}m}m} ddlm} ddlm}m	}m
}	m}
m}m}m}m} t| jdr;| }|| j}|||dS t| jdr|d	}d }t| jd
 }| rodd l}t|}||}W d    n1 sjw   Y  |tt| jd |d}|	||| jdS t| jdr| }|| j}|
||dS t| jdr|| j}tt| jd}|j| j|d}t| jd}||jdd|_|||dS t)Nr   )LlamaTokenizerMIXQwen2_5_TokenizerMIXQwen2TokenizerQWenTokenizer)ChatTemplater   )GOTImageProcessorPaddleOCRVLProcessorPPChart2TableProcessorPPDocBee2ProcessorPPDocBeeProcessorQwen2_5_VLImageProcessorQwen2VLImageProcessorSiglipImageProcessorr=   )image_processorrG   r>   i   zadded_tokens.jsonr   zqwen.tiktoken)
vocab_fileextra_special_tokens)rx   rG   r%   rA   r   ztokenizer.model)ry   zchat_template.jinjazutf-8)encoding)Zcommon.tokenizerrk   rl   rm   rn   Z common.tokenizer.tokenizer_utilsro   Z
processorsrp   rq   rr   rs   rt   ru   rv   rw   r   r)   rE   rF   r   existsjsonopenloadstrr%   Z_compile_jinja_template	read_textZchat_templaterH   )r,   r.   rk   rl   rm   rn   ro   rp   rq   rr   rs   rt   ru   rv   rw   rx   rG   rz   Zadded_tokens_filer}   fry   Zchat_template_filer   r   r2   rB     s\   (

z$DocVLMLocalPredictor.build_processorc                 C   s   t ||ddS )NTZadd_input_path)r   )r,   Zmodel_predsrg   r   r   r2   rf   H  s   z(DocVLMLocalPredictor._format_result_dictc           
         s   ddl }ddlm} |du rdS d| v rdS ||\}}tjdd}|du r;t| }d	dd	 t
|D }|d  D ]}| sOtd
| qBt|t krctd| d| d fdd	|D }	|d d	|	 S )z0infer the forward device for dynamic graph modelr   Nr   )parse_devicecpuZCUDA_VISIBLE_DEVICES,c                 S   s   g | ]}t |qS r   )r   r[   r   r   r2   
<listcomp>Z      zFDocVLMLocalPredictor._infer_dynamic_forward_device.<locals>.<listcomp>z?CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: zRequired gpu ids z0 even larger than the number of visible devices rJ   c                    s   g | ]} | qS r   r   )r\   idxZenv_device_idsr   r2   r   g  r   :)GPUtilutils.devicer   lowerosenvironr!   lenZgetGPUsjoinrangesplitisdigit
ValueErrormax)
r,   r$   r   r   Zdevice_typeZ
device_idsZcuda_visible_devicesZenv_gpu_numZenv_device_idZrst_global_gpu_idsr   r   r2   _infer_dynamic_forward_deviceK  s0   
z2DocVLMLocalPredictor._infer_dynamic_forward_devicec                    s.   t  jdu r
 S  fdd D }|S )z(Switch the input to the specified deviceNc                    s:   i | ]}|t  | jrj | jd n | qS ))Zplace)rY   ZTensorZ	to_tensorr$   )r\   k
input_dictZpaddler,   r   r2   
<dictcomp>p  s    zADocVLMLocalPredictor._switch_inputs_to_device.<locals>.<dictcomp>)r	   r$   )r,   r   Zrst_dictr   r   r2   rc   j  s   
z-DocVLMLocalPredictor._switch_inputs_to_deviceNNNNNNNN)__name__
__module____qualname____doc__r    r6   r8   r&   r"   r   rZ   r   intboolfloatrj   rB   rf   r   rc   __classcell__r   r   r0   r2   r   .   sL    L	

PCr   c                       s   e Zd ZdZ fddZdd Zdd Zdd	d
Zdd Zdd Z								dde
e dee dee dee dee dee dee dee fddZdd Z  ZS )DocVLMGenAIClientPredictorz<DocVLM predictor for remote GenAI inference via GenAIClient.c                    sz   | dd }| dd}|d u s|stdt j||d |dd}|dkr1t| jd	r1t}n|dkr7d}|| j_	d S )
Nengine_configr)    z4DocVLMGenAIClientPredictor requires `engine_config`.)r)   r   r   r   r   r   )
popr   r   r    r!   r   r)   r   r#   r   )r,   r-   r.   r   r)   r/   r0   r   r2   r    ~  s   z#DocVLMGenAIClientPredictor.__init__c                 C   r3   rW   r4   r5   r   r   r2   r6        
z/DocVLMGenAIClientPredictor._build_batch_samplerc                 C   r7   rW   r   r5   r   r   r2   r8        z,DocVLMGenAIClientPredictor._get_result_classNc                 k   s    | j |fi |E d H  d S rW   )apply)r,   inputr   r.   r   r   r2   __call__  s   z#DocVLMGenAIClientPredictor.__call__c                 k   s    | |fi |E dH  dS )z.Alias for __call__ for pipeline compatibility.Nr   )r,   r   r.   r   r   r2   predict  s   z"DocVLMGenAIClientPredictor.predictc                 +   s~    ddl m} | |D ]0}t|ts|g}| j|fi |}tt|dg D ]  fdd|	 D }||V  q)qd S )Nr   r   resultc                    s(   i | ]\}}|t |tr|  n|qS r   )rY   list)r\   r   vr   r   r2   r     s    z4DocVLMGenAIClientPredictor.apply.<locals>.<dictcomp>)
r   r   r#   rY   r   rj   r   r   r!   items)r,   r   r.   r   Z	instancespredsingler   r   r2   r     s   

z DocVLMGenAIClientPredictor.applyrM   rN   rO   rP   rQ   rR   rS   rT   c	              
   K   s.   t   | j||||||||d}
t|
|ddS )N)rN   rO   rP   rQ   rR   rS   rT   Tr   )r   _genai_client_processr   )r,   rM   rN   rO   rP   rQ   rR   rS   rT   r.   ri   r   r   r2   rj     s   
z"DocVLMGenAIClientPredictor.processc	              
   C   s  | j }	g }
|	jdkrd}nd}z|D ]}|d }t|trz|ds)|dr,|}nddlm} ||:}|d	}t	
 !}|j||d
 d|  dt| d }W d    n1 sew   Y  W d    n1 stw   Y  nSt|tjrdd l}ddlm} |||j}||}t	
 !}|j||d
 d|  dt| d }W d    n1 sw   Y  n	tdt| |	jdkr|d u rdn||d u rdn|d}nd|d u rdn|i}|d ur||d< |	jdv rd}nd}|d ur
|||< nt| jdrt||< i |d< |d ur/|	jdv r+||d d< ntd|d ur:||d d< |d ure|	jdkrZ|d di |d d< ||d d d< ntt |	j d  |d ur|	jdkr|d di |d d< ||d d d!< ntt |	j d" |	j!d#d$d%|id&d'|d( d)gd*gfd+d,d-|}|
"| qg }|
D ]}|# }|"|j$d j%j& q|W S  t'y   |
D ]}|( s|)  qԂ w ).Nllama-cpp-serverZPNGZJPEGimagezhttp://zhttps://r   )ImageRGB)formatzdata:image/z;base64,asciizNot supported image type: fastdeploy-serverr   )rQ   rR   rQ   rR   )mlx-vlm-serverr   Z
max_tokensZmax_completion_tokensr   Z
extra_body)r   vllm-serverzsglang-serverr   r   rO   zNot supportedrP   r   Zmm_processor_kwargsrS   z does not support `min_pixels`.rT   z does not support `max_pixels`.user	image_urlurl)typer   textqueryr   r   ZrolecontentTiX  )Zreturn_futuretimeout)*Zgenai_clientbackendrY   r   
startswithZPILr   r~   convertioBytesIOsaver   base64	b64encodegetvaluedecodenpZndarraycv2ZcvtColorZCOLOR_BGR2RGBZ	fromarray	TypeErrorr   r   r)   r   r   r!   rC   rD   r+   Zcreate_chat_completionappendr   choicesmessager   	Exceptiondonecancel)r,   rM   rN   rO   rP   rQ   rR   rS   rT   clientZfuturesZimage_formatitemr   r   r   Zimgbufr   r.   Zmax_tokens_namefutureresultsr   r   r   r2   r     s   











	








z0DocVLMGenAIClientPredictor._genai_client_processrW   )NNNNNNN)r   r   r   r   r    r6   r8   r   r   r   r   rZ   r   r   r   r   rj   r   r   r   r   r0   r2   r   {  sB    
	
r   c                       s   e Zd ZdZ fddZdd Zdd Zdd	 Z	
	
	
	
	
	
	
	
ddee	 de
e de
e de
e de
e de
e de
e de
e de
e fddZdd Z  ZS )DocVLMTransformersPredictorz5DocVLM predictor backed by Hugging Face transformers.c                    s:   t  j|i | | jjdkrt| j_|  \| _| _d S )Nr   )r   r    r#   r   r   r&   r(   r'   )r,   r-   r.   r0   r   r2   r    V  s   z$DocVLMTransformersPredictor.__init__c                 C   r3   rW   r4   r5   r   r   r2   r6   \  r   z0DocVLMTransformersPredictor._build_batch_samplerc                 C   r7   rW   r   r5   r   r   r2   r8   _  r   z-DocVLMTransformersPredictor._get_result_classc                 C   s,   ddl m}m} | |}| |}||fS )Nr   )AutoModelForImageTextToTextAutoProcessor)Ztransformersr   r   Z_load_pretrained_processorZ_load_pretrained_model)r,   r   r   r(   rI   r   r   r2   r&   b  s   

z"DocVLMTransformersPredictor._buildNrM   rN   rO   rP   rQ   rR   rS   rT   rU   c
                 K   sx  ddl m} tdd |D sJ t|}g }g }|D ]-}||d }dd|dd|d	d
dgdg}| jj|ddd}|| || qt| j	drvdt
| jjji}|d urc||d d< |d urm||d d< | j|||d}n| j||d}d|d ur|nti}|d ur||d< |d ur||d< |d ur||d< |	d ur|	|d< | ||}| j|||d}t||ddS )Nr   )fetch_imagec                 s   rV   rW   rX   r[   r   r   r2   r^   x  r_   z6DocVLMTransformersPredictor.process.<locals>.<genexpr>r   r   )r   r   r   r   r   r   r   FT)tokenizeZadd_generation_promptr   sizeZshortest_edgeZlongest_edge)imagesr   images_kwargs)r   r   rN   rP   rQ   rR   rU   )model_inputsrO   r   )Zprocessors.commonr   r`   ra   r!   r(   Zapply_chat_templater   r   r)   rZ   rx   r   Zpreprocess_imagesr   rd   re   r   )r,   rM   rN   rO   rP   rQ   rR   rS   rT   rU   r.   r   rg   r   Ztextsr   r   messagespromptr   r   rh   Zgenerated_idsri   r   r   r2   rj   i  s`   
	
z#DocVLMTransformersPredictor.processc                K   s>   |d }dd t ||D }| jj||d u rdn|dd}|S )N	input_idsc                 S   s    g | ]\}}|t |d  qS rW   )r   )r\   r   Z
output_idsr   r   r2   r     s    z;DocVLMTransformersPredictor.postprocess.<locals>.<listcomp>TF)rO   Zclean_up_tokenization_spaces)zipr(   Zbatch_decode)r,   Zoutputsr   rO   r.   Z
prompt_idsZgenerated_ids_trimmedri   r   r   r2   re     s   z'DocVLMTransformersPredictor.postprocessr   )r   r   r   r   r    r6   r8   r&   r   rZ   r   r   r   r   rj   re   r   r   r   r0   r2   r   S  sD    
	

Kr   ))r   ra   r   r   rC   pathlibr   typingr   r   numpyr   utilsr   Z
utils.depsr   r   r   Zutils.import_guardr	   Zcommon.batch_samplerr   Z
utils.miscr   Z
predictorsr   r   r   Zutils.model_pathsr   	constantsr   r   r   r   r   r   r   r   r   r   r   r   r   r2   <module>   s2     O Y