# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
from PIL import Image

from ....utils import logging
from ....utils.deps import class_requires_deps, is_dep_available
from ....utils.fonts import (
    ARABIC_FONT,
    CYRILLIC_FONT,
    DEVANAGARI_FONT,
    EL_FONT,
    KANNADA_FONT,
    KOREAN_FONT,
    LATIN_FONT,
    SIMFANG_FONT,
    TAMIL_FONT,
    TELUGU_FONT,
    TH_FONT,
)
from ....utils.func_register import FuncRegister
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ..predictors import RunnerPredictor, TransformersPredictor
from .processors import (
    CTCLabelDecode,
    OCRReisizeNormImg,
    ToBatch,
    validate_text_rec_image_array,
)
from .result import TextRecResult

if is_dep_available("python-bidi"):
    from bidi.algorithm import get_display


# Names of PP-OCRv5 text recognition models grouped under the "transformers"
# label. NOTE(review): presumably the models that TextRecTransformersPredictor
# can serve; the list is not referenced elsewhere in this part of the file, so
# confirm against its callers.
TEXT_REC_TRANSFORMERS_MODELS = [
    "PP-OCRv5_server_rec",
    "PP-OCRv5_mobile_rec",
    "eslav_PP-OCRv5_mobile_rec",
    "korean_PP-OCRv5_mobile_rec",
    "latin_PP-OCRv5_mobile_rec",
    "en_PP-OCRv5_mobile_rec",
    "th_PP-OCRv5_mobile_rec",
    "el_PP-OCRv5_mobile_rec",
    "arabic_PP-OCRv5_mobile_rec",
    "te_PP-OCRv5_mobile_rec",
    "ta_PP-OCRv5_mobile_rec",
    "devanagari_PP-OCRv5_mobile_rec",
    "cyrillic_PP-OCRv5_mobile_rec",
]


def get_text_rec_vis_font(model_name):
    """Select the visualization font that matches a text recognition model.

    Args:
        model_name: Name of the text recognition model.

    Returns:
        The font constant associated with the model, or ``None`` when the
        name is not one of the models handled here.
    """
    font = None

    if model_name.startswith(("PP-OCR", "en_PP-OCR")):
        # Every model whose name begins with one of these prefixes shares
        # the same font.
        font = SIMFANG_FONT
    elif model_name in {"latin_PP-OCRv3_mobile_rec", "latin_PP-OCRv5_mobile_rec"}:
        font = LATIN_FONT
    elif model_name in {
        "cyrillic_PP-OCRv3_mobile_rec",
        "cyrillic_PP-OCRv5_mobile_rec",
        "eslav_PP-OCRv5_mobile_rec",
    }:
        font = CYRILLIC_FONT
    elif model_name in {"korean_PP-OCRv3_mobile_rec", "korean_PP-OCRv5_mobile_rec"}:
        font = KOREAN_FONT
    elif model_name == "th_PP-OCRv5_mobile_rec":
        font = TH_FONT
    elif model_name == "el_PP-OCRv5_mobile_rec":
        font = EL_FONT
    elif model_name in {"arabic_PP-OCRv3_mobile_rec", "arabic_PP-OCRv5_mobile_rec"}:
        font = ARABIC_FONT
    elif model_name == "ka_PP-OCRv3_mobile_rec":
        font = KANNADA_FONT
    elif model_name in {"te_PP-OCRv3_mobile_rec", "te_PP-OCRv5_mobile_rec"}:
        font = TELUGU_FONT
    elif model_name in {"ta_PP-OCRv3_mobile_rec", "ta_PP-OCRv5_mobile_rec"}:
        font = TAMIL_FONT
    elif model_name in {
        "devanagari_PP-OCRv3_mobile_rec",
        "devanagari_PP-OCRv5_mobile_rec",
    }:
        font = DEVANAGARI_FONT

    # Unrecognized model names fall through with no font selected.
    return font


@class_requires_deps("python-bidi")
class TextRecRunnerPredictor(RunnerPredictor):
    """Text recognition predictor that runs the model through ``self.runner``.

    The pre-processing ops are assembled from the ``PreProcess`` section of
    the model config and the decoder from its ``PostProcess`` section.
    """

    # Transform-op name (as written in the config) -> builder method.
    # Presumably populated by the ``@register(...)`` decorators below.
    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

    def __init__(self, *args, input_shape=None, return_word_box=False, **kwargs):
        """Create the predictor and build its pre/post-processing ops.

        Args:
            *args: Forwarded to the base predictor.
            input_shape: Forwarded to the resize/normalize op.
            return_word_box: Default for the ``return_word_box`` flag handed
                to the decoder; it is OR-ed with the per-call flag in
                ``process``.
            **kwargs: Forwarded to the base predictor.
        """
        super().__init__(*args, **kwargs)
        self.input_shape = input_shape
        self.return_word_box = return_word_box
        self.vis_font = self.get_vis_font()
        self.pre_tfs, self.post_op = self._build()

    def _build_batch_sampler(self):
        """Return the batch sampler used by this predictor."""
        return ImageBatchSampler()

    def _get_result_class(self):
        """Return the result class used by this predictor."""
        return TextRecResult

    def _build(self):
        """Build the pre-processing ops and the post-processing op.

        Returns:
            A ``(pre_tfs, post_op)`` tuple, where ``pre_tfs`` maps an op name
            to a callable and ``post_op`` decodes the model output.
        """
        pre_tfs = {"Read": ReadImage(format="RGB")}
        for cfg in self.config["PreProcess"]["transform_ops"]:
            # Each entry is a single-key mapping: {op_name: op_arguments}.
            tf_key = list(cfg.keys())[0]
            assert (
                tf_key in self._FUNC_MAP
            ), f"Unsupported transform op in PreProcess config: {tf_key!r}"
            func = self._FUNC_MAP[tf_key]
            args = cfg.get(tf_key, {})
            name, op = func(self, **args) if args else func(self)
            # Builders for training-only ops return (None, None); skip those.
            if op:
                pre_tfs[name] = op
        pre_tfs["ToBatch"] = ToBatch()

        post_op = self.build_postprocess(**self.config["PostProcess"])
        return pre_tfs, post_op

    def process(self, batch_data, return_word_box=False):
        """Run text recognition on one batch.

        Args:
            batch_data: Batch providing ``instances``, ``input_paths`` and
                ``page_indexes``.
            return_word_box: Per-call flag; OR-ed with the instance-level
                ``return_word_box`` before being handed to the decoder.

        Returns:
            A dict with the keys ``input_path``, ``page_index``,
            ``input_img``, ``rec_text``, ``rec_score`` and ``vis_font``.
        """
        batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data.instances)
        for i, img in enumerate(batch_raw_imgs):
            validate_text_rec_image_array(img, index=i)

        # Width/height ratio of every raw image, in input order.
        all_wh_ratios = [img.shape[1] / float(img.shape[0]) for img in batch_raw_imgs]

        batch_imgs = self.pre_tfs["ReisizeNorm"](imgs=batch_raw_imgs)
        x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
        batch_preds = self.runner(x=x)

        rec_image_shape = next(
            op["RecResizeImg"]["image_shape"]
            for op in self.config["PreProcess"]["transform_ops"]
            if "RecResizeImg" in op
        )
        _, img_h, img_w = rec_image_shape[:3]

        # The images are fed to the runner in input order, so the ratios are
        # passed in that same order: entry i then describes the image behind
        # prediction i. (Passing them sorted by ratio would pair predictions
        # with the wrong image for unsorted batches.)
        # The list is capped at the sampler's batch size; presumably a batch
        # never holds more images than that.
        end_img_no = min(len(batch_raw_imgs), self.batch_sampler.batch_size)
        wh_ratio_list = all_wh_ratios[:end_img_no]

        # The configured aspect ratio is the lower bound for the maximum.
        max_wh_ratio = img_w / img_h
        for wh_ratio in wh_ratio_list:
            max_wh_ratio = max(max_wh_ratio, wh_ratio)

        texts, scores = self.post_op(
            batch_preds,
            return_word_box=return_word_box or self.return_word_box,
            wh_ratio_list=wh_ratio_list,
            max_wh_ratio=max_wh_ratio,
        )
        if self.model_name in (
            "arabic_PP-OCRv3_mobile_rec",
            "arabic_PP-OCRv5_mobile_rec",
        ):
            # Arabic output goes through python-bidi's ``get_display``; when an
            # item is a tuple only its first element (the text) is converted.
            texts = [
                (get_display(s[0]), s[1]) if isinstance(s, tuple) else get_display(s)
                for s in texts
            ]
        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": batch_raw_imgs,
            "rec_text": texts,
            "rec_score": scores,
            "vis_font": [self.vis_font] * len(batch_raw_imgs),
        }

    @register("DecodeImage")
    def build_readimg(self, channel_first, img_mode):
        """Build the image reading op for the ``DecodeImage`` config entry.

        Only ``channel_first == False`` is accepted.
        """
        assert (
            channel_first == False
        ), "DecodeImage with channel_first=True is not supported"
        return "Read", ReadImage(format=img_mode)

    @register("RecResizeImg")
    def build_resize(self, image_shape, **kwargs):
        """Build the resize/normalize op for the ``RecResizeImg`` config entry.

        Extra config arguments are accepted and ignored.
        """
        return "ReisizeNorm", OCRReisizeNormImg(
            rec_image_shape=image_shape, input_shape=self.input_shape
        )

    def build_postprocess(self, **kwargs):
        """Build the decoder described by the ``PostProcess`` config section.

        Raises:
            Exception: If the configured ``name`` is not ``CTCLabelDecode``.
        """
        if kwargs.get("name") == "CTCLabelDecode":
            return CTCLabelDecode(
                character_list=kwargs.get("character_dict"),
            )
        else:
            raise Exception(
                f"Unsupported PostProcess op: {kwargs.get('name')!r}; "
                "only 'CTCLabelDecode' is supported."
            )

    # The two builders below deliberately produce no op: ``_build`` skips any
    # builder whose returned op is falsy. They share the name ``foo``, so the
    # class attribute ends up bound to the second one; each is still passed to
    # ``register`` under its own config key.
    @register("MultiLabelEncode")
    def foo(self, *args, **kwargs):
        """Placeholder builder for ``MultiLabelEncode``; yields no op."""
        return None, None

    @register("KeepKeys")
    def foo(self, *args, **kwargs):
        """Placeholder builder for ``KeepKeys``; yields no op."""
        return None, None

    def get_vis_font(self):
        """Return the visualization font matching ``self.model_name``."""
        return get_text_rec_vis_font(self.model_name)


class TextRecTransformersPredictor(TransformersPredictor):
    """Text recognition predictor backed by Hugging Face transformers."""

    def __init__(self, *args, return_word_box: bool = False, **kwargs):
        """Create the predictor and load the image processor and model.

        Args:
            *args: Forwarded to the base predictor.
            return_word_box: Stored on the instance, but not supported by
                this engine; a warning is logged when it is truthy.
            **kwargs: Forwarded to the base predictor.
        """
        super().__init__(*args, **kwargs)
        self.return_word_box = return_word_box
        if return_word_box:
            # Warn once at construction so an instance-level request for
            # word boxes is not dropped silently.
            self._warn_word_box_unsupported()
        self.vis_font = get_text_rec_vis_font(self.model_name)
        self.read_op = ReadImage(format="RGB")
        self.image_processor, self.infer = self._build()
        # NOTE(review): ``postprocess`` below decodes through the image
        # processor and does not use ``post_op``; confirm whether it is
        # consumed elsewhere.
        self.post_op = self._build_postprocess()

    @staticmethod
    def _warn_word_box_unsupported():
        """Log that word boxes cannot be produced by this engine."""
        logging.warning("transformers engine doesn't support `return_word_box`")

    def _build_batch_sampler(self):
        """Return the batch sampler used by this predictor."""
        return ImageBatchSampler()

    def _get_result_class(self):
        """Return the result class used by this predictor."""
        return TextRecResult

    def _build(self):
        """Load the pretrained image processor and text recognition model.

        Returns:
            An ``(image_processor, model)`` tuple.
        """
        # Imported lazily so the module can be imported without transformers.
        from transformers import AutoImageProcessor, AutoModelForTextRecognition

        image_processor = self._load_pretrained_processor(AutoImageProcessor)
        model = self._load_pretrained_model(AutoModelForTextRecognition)
        return image_processor, model

    def _build_postprocess(self):
        """Build a CTC decoder from the available character dictionary.

        The dictionary is taken from the image processor's
        ``character_list`` attribute, falling back to
        ``PostProcess.character_dict`` in the model config.

        Raises:
            RuntimeError: If neither source provides a character dictionary.
        """
        character_list = getattr(self.image_processor, "character_list", None)
        if not character_list:
            character_list = self.model_config.get("PostProcess", {}).get(
                "character_dict"
            )
        if not character_list:
            raise RuntimeError(
                f"{type(self.image_processor).__name__} does not provide "
                "the character dictionary required for text recognition decoding."
            )
        # Drop a leading "blank" entry before building the decoder
        # (presumably because the decoder adds its own blank token).
        if character_list[0] == "blank":
            character_list = character_list[1:]
        return CTCLabelDecode(character_list=character_list)

    def _get_rec_image_shape(self):
        """Return the recognition input shape as ``(3, height, width)``.

        Height and width come from the image processor's ``pad_size``,
        then its ``size``, and finally default to 48 and 320.
        """
        size = getattr(self.image_processor, "size", {}) or {}
        pad_size = getattr(self.image_processor, "pad_size", {}) or size
        img_h = int(pad_size.get("height", size.get("height", 48)))
        img_w = int(pad_size.get("width", size.get("width", 320)))
        return 3, img_h, img_w

    def process(self, batch_data, return_word_box: Optional[bool] = None):
        """Run text recognition on one batch.

        Args:
            batch_data: Batch providing ``instances``, ``input_paths`` and
                ``page_indexes``.
            return_word_box: Not supported by this engine; a warning is
                logged when it is truthy.

        Returns:
            A dict with the keys ``input_path``, ``page_index``,
            ``input_img``, ``rec_text``, ``rec_score`` and ``vis_font``.
        """
        if return_word_box:
            self._warn_word_box_unsupported()

        batch_raw_imgs = self.read_op(imgs=batch_data.instances)
        for i, img in enumerate(batch_raw_imgs):
            validate_text_rec_image_array(img, index=i)
        # The image processor is fed PIL images built from the raw arrays.
        images = [Image.fromarray(img) for img in batch_raw_imgs]

        model_inputs = self.preprocess_images(images=images)
        outputs = self.forward(model_inputs)
        texts, scores = self.postprocess(outputs)

        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": batch_raw_imgs,
            "rec_text": texts,
            "rec_score": scores,
            "vis_font": [self.vis_font] * len(batch_raw_imgs),
        }

    def postprocess(self, outputs, **kwargs):
        """Decode model outputs into texts and scores.

        Args:
            outputs: Model outputs handed to the image processor's
                ``post_process_text_recognition``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(texts, scores)`` tuple of lists.
        """
        # NOTE(review): TextRecRunnerPredictor passes Arabic output through
        # ``get_display``; no such step happens here. Confirm whether the
        # image processor already handles it.
        results = self.image_processor.post_process_text_recognition(outputs)
        texts = [r["text"] for r in results]
        scores = [r["score"] for r in results]

        return texts, scores
