# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from PIL import Image

from ....utils.func_register import FuncRegister
from ...common.batch_sampler import ImageBatchSampler
from ..predictors import RunnerPredictor, TransformersPredictor
from .processors import (
    DetPad,
    DetPostProcess,
    Normalize,
    PadStride,
    ReadImage,
    Resize,
    ToBatch,
    ToCHWImage,
    WarpAffine,
)
from .result import DetResult
from .utils import STATIC_SHAPE_MODEL_LIST

# Model names grouped under the RT-DETR-L family in this module: the generic
# detector, two table-cell detectors and two document-layout detectors.
RTDETR_L_MODELS = [
    # generic detector
    "RT-DETR-L",
    # table cell detection
    "RT-DETR-L_wired_table_cell_det",
    "RT-DETR-L_wireless_table_cell_det",
    # document layout detection
    "PP-DocLayout_plus-L",
    "PP-DocBlockLayout",
]
# Alias (same list object, not a copy) naming the models associated with the
# transformers-backed detection predictor.
DET_TRANSFORMERS_MODELS = RTDETR_L_MODELS


class DetRunnerPredictor(RunnerPredictor):
    """Object detection predictor using inference runner."""

    # Maps preprocess op names found in ``self.config["Preprocess"]`` (e.g.
    # "Resize", "Pad") to the builder methods decorated with ``@register``.
    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

    def __init__(
        self,
        *args,
        img_size: Optional[Union[int, Tuple[int, int]]] = None,
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
        **kwargs,
    ):
        """Initializes DetPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            img_size (Optional[Union[int, Tuple[int, int]]], optional): The input image size (w, h).
                An int is expanded to (img_size, img_size). When set, it replaces the
                configured Resize op. Defaults to None.
            threshold (Optional[Union[float, dict]], optional): The threshold for filtering out
                low-confidence predictions. When None, the config's ``draw_threshold``
                (or 0.5) is used. Defaults to None.
            layout_nms (Optional[bool], optional): Whether to use layout-aware NMS. When None,
                the config's ``layout_nms`` value is used. Defaults to None.
            layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]], optional):
                The ratio of unclipping the bounding box. Defaults to None.
                If it's a single number (int or float), it is used for both width and height.
                If it's a tuple/list of two numbers, they are used for width and height respectively.
                A dict is passed through unchanged.
                If it's None, the config's ``layout_unclip_ratio`` value is used.
            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging
                bounding boxes; a string must be one of 'union', 'large', 'small'.
                Defaults to None.
            **kwargs: Arbitrary keyword arguments passed to the superclass.

        Raises:
            AssertionError: If ``img_size`` is given for a static-shape model, or a tuple/list
                argument does not have length 2, or ``layout_merge_bboxes_mode`` is an
                unknown string.
            ValueError: If ``img_size`` or ``layout_unclip_ratio`` has an unsupported type.
        """
        super().__init__(*args, **kwargs)

        if img_size is not None:
            assert (
                self.model_name not in STATIC_SHAPE_MODEL_LIST
            ), f"The model {self.model_name} is not supported set input shape"
            if isinstance(img_size, int):
                img_size = (img_size, img_size)
            elif isinstance(img_size, (tuple, list)):
                assert len(img_size) == 2, "The length of `img_size` should be 2."
            else:
                raise ValueError(
                    f"The type of `img_size` must be int or Tuple[int, int], but got {type(img_size)}."
                )

        if layout_unclip_ratio is not None:
            # Accept any real scalar (ints such as ``1`` included), as the
            # docstring promises; bools are excluded and still rejected below.
            if isinstance(layout_unclip_ratio, (int, float)) and not isinstance(
                layout_unclip_ratio, bool
            ):
                ratio = float(layout_unclip_ratio)
                layout_unclip_ratio = (ratio, ratio)
            elif isinstance(layout_unclip_ratio, (tuple, list)):
                assert (
                    len(layout_unclip_ratio) == 2
                ), "The length of `layout_unclip_ratio` should be 2."
            elif not isinstance(layout_unclip_ratio, dict):
                raise ValueError(
                    f"The type of `layout_unclip_ratio` must be float, Tuple[float, float] or Dict, but got {type(layout_unclip_ratio)}."
                )

        if layout_merge_bboxes_mode is not None:
            if isinstance(layout_merge_bboxes_mode, str):
                assert layout_merge_bboxes_mode in [
                    "union",
                    "large",
                    "small",
                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'] or a dict, but got {layout_merge_bboxes_mode}"

        self.img_size = img_size
        self.threshold = threshold
        self.layout_nms = layout_nms
        self.layout_unclip_ratio = layout_unclip_ratio
        self.layout_merge_bboxes_mode = layout_merge_bboxes_mode
        self.pre_ops, self.post_op = self._build()

    def _build_batch_sampler(self):
        """Return the batch sampler that feeds images to this predictor."""
        return ImageBatchSampler()

    def _get_result_class(self):
        """Return the result class used to wrap per-image outputs."""
        return DetResult

    def _build(self) -> Tuple:
        """Build the preprocessors and postprocessors based on the configuration.

        The preprocess chain is ``ReadImage`` followed by one op per entry of
        ``self.config["Preprocess"]`` (looked up by its ``type`` key in
        ``_FUNC_MAP``), and ends with a ``ToBatch`` op.

        Returns:
            tuple: A tuple ``(pre_ops, post_op)`` with the list of preprocessors
                and the postprocessor.
        """
        pre_ops = [ReadImage(format="RGB")]
        for cfg in self.config["Preprocess"]:
            func = self._FUNC_MAP[cfg["type"]]
            # Build the op kwargs from a copy: popping "type" would mutate
            # ``self.config`` in place, so rebuilding from the same config
            # would raise ``KeyError: 'type'``.
            op_kwargs = {key: value for key, value in cfg.items() if key != "type"}
            op = func(self, **op_kwargs)
            # Builders may return a falsy value to skip an op.
            if op:
                pre_ops.append(op)
        pre_ops.append(self.build_to_batch())
        if self.img_size is not None:
            # A user-provided size replaces a configured Resize placed right
            # after ReadImage (or is inserted there if there is none).
            if isinstance(pre_ops[1], Resize):
                pre_ops.pop(1)
            pre_ops.insert(1, self.build_resize(self.img_size, False, 2))

        post_op = self.build_postprocess()

        return pre_ops, post_op

    def _format_output(self, pred: Sequence[Any]) -> List[dict]:
        """
        Transform batch outputs into a list of single image output.

        Args:
            pred (Sequence[Any]): The raw runner outputs.
                - When len(pred) == 4 (SOLOv2), ``pred[1]`` and ``pred[2]`` are packed
                  together as 'class_id' and ``pred[3]`` is used as 'masks';
                  ``pred[0]`` is ignored.
                - Otherwise ``pred[0]`` holds the boxes of all images concatenated along
                  the first axis and ``pred[1]`` the number of boxes per image. When
                  len(pred) == 3 (Instance Segmentation), ``pred[2]`` holds the masks,
                  concatenated the same way.

        Returns:
            List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks' (for SOLOv2),
                or 'boxes' and 'masks' (for Instance Segmentation), or just 'boxes' if no masks are provided.
        """
        if len(pred) == 4:
            # Adapt to SOLOv2.
            # NOTE(review): exactly one entry is returned, which appears to
            # assume a batch size of 1.
            return [
                {
                    "class_id": np.array([pred[1], pred[2]]),
                    "masks": np.array(pred[3]),
                }
            ]

        # Adapt to Instance Segmentation when a mask output is present.
        has_masks = len(pred) == 3
        pred_box = []
        pred_mask = []
        box_idx_start = 0
        for np_boxes_num in pred[1]:
            box_idx_end = box_idx_start + np_boxes_num
            pred_box.append(pred[0][box_idx_start:box_idx_end])
            if has_masks:
                pred_mask.append(pred[2][box_idx_start:box_idx_end])
            box_idx_start = box_idx_end

        if has_masks:
            return [
                {"boxes": np.asarray(boxes), "masks": np.asarray(masks)}
                for boxes, masks in zip(pred_box, pred_mask)
            ]
        return [{"boxes": np.array(boxes)} for boxes in pred_box]

    def process(
        self,
        batch_data: List[Any],
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: bool = False,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
    ):
        """
        Process a batch of data through the preprocessing, inference, and postprocessing.

        Per-call options fall back to the values resolved at construction time.
        ``threshold`` falls back only when it is None. The other options are
        combined with ``or``, so a falsy per-call value (e.g. ``layout_nms=False``)
        cannot override a truthy construction-time value.

        Args:
            batch_data: A batch object exposing ``instances``, ``input_paths`` and
                ``page_indexes`` (e.g. produced by ``ImageBatchSampler``).
            threshold (Optional[Union[float, dict]], optional): The threshold for filtering out
                low-confidence predictions.
            layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
            layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]], optional):
                The ratio of unclipping the bounding box.
            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging
                bounding boxes. Defaults to None.

        Returns:
            dict: A dictionary with keys 'input_path', 'page_index', 'input_img' (the original
                images) and 'boxes' (the postprocessed detections of every image in the batch).
        """
        datas = batch_data.instances
        # preprocess (every op except the final ToBatch)
        for pre_op in self.pre_ops[:-1]:
            datas = pre_op(datas)

        # use `ToBatch` format batch inputs
        batch_inputs = self.pre_ops[-1](datas)

        # do infer
        batch_preds = self.runner(batch_inputs)

        # process a batch of predictions into a list of single image result
        preds_list = self._format_output(batch_preds)
        # postprocess
        boxes = self.post_op(
            preds_list,
            datas,
            threshold=threshold if threshold is not None else self.threshold,
            layout_nms=layout_nms or self.layout_nms,
            layout_unclip_ratio=layout_unclip_ratio or self.layout_unclip_ratio,
            layout_merge_bboxes_mode=layout_merge_bboxes_mode
            or self.layout_merge_bboxes_mode,
        )

        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": [data["ori_img"] for data in datas],
            "boxes": boxes,
        }

    @register("Resize")
    def build_resize(self, target_size, keep_ratio=False, interp=2):
        """Build a ``Resize`` op.

        Args:
            target_size: Target size as (w, h); reversed before being passed to
                ``Resize`` (presumably (h, w) there — confirm against ``Resize``).
            keep_ratio: Whether to keep the aspect ratio.
            interp: Interpolation name, or an int code 0-4 mapped to
                NEAREST/LINEAR/BICUBIC/AREA/LANCZOS4.
        """
        assert target_size
        if isinstance(interp, int):
            interp = {
                0: "NEAREST",
                1: "LINEAR",
                2: "BICUBIC",
                3: "AREA",
                4: "LANCZOS4",
            }[interp]
        op = Resize(target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp)
        return op

    @register("NormalizeImage")
    def build_normalize(
        self,
        norm_type=None,
        mean=None,
        std=None,
        is_scale=True,
    ):
        """Build a ``Normalize`` op.

        Args:
            norm_type: "mean_std" (also used when None/"none"); any other value
                disables mean/std normalization (mean=0, std=1).
            mean: Per-channel mean; defaults to [0.485, 0.456, 0.406].
            std: Per-channel std; defaults to [0.229, 0.224, 0.225].
            is_scale: Whether to scale pixel values by 1/255 first.
        """
        # Fresh lists per call instead of shared mutable default arguments.
        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]
        scale = 1.0 / 255.0 if is_scale else 1
        if not norm_type or norm_type == "none":
            norm_type = "mean_std"
        if norm_type != "mean_std":
            mean = 0
            std = 1
        return Normalize(scale=scale, mean=mean, std=std)

    @register("Permute")
    def build_to_chw(self):
        """Build the HWC -> CHW layout op."""
        return ToCHWImage()

    @register("Pad")
    def build_pad(self, fill_value=None, size=None):
        """Build a ``DetPad`` op (defaults: fill 127.5 per channel, size [3, 640, 640])."""
        if fill_value is None:
            fill_value = [127.5, 127.5, 127.5]
        if size is None:
            size = [3, 640, 640]
        return DetPad(size=size, fill_value=fill_value)

    @register("PadStride")
    def build_pad_stride(self, stride=32):
        """Build a ``PadStride`` op with the given stride."""
        return PadStride(stride=stride)

    @register("WarpAffine")
    def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
        """Build a ``WarpAffine`` op."""
        return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)

    def build_to_batch(self):
        """Build the ``ToBatch`` op, choosing the model's required input keys.

        Models whose name contains any of the substrings below additionally
        receive an ``img_size`` input.
        """
        models_required_imgsize = [
            "DETR",
            "DINO",
            "RCNN",
            "YOLOv3",
            "CenterNet",
            "BlazeFace",
            "BlazeFace-FPN-SSH",
            "PP-DocLayout-L",
            "PP-DocLayout_plus-L",
            "PP-DocBlockLayout",
            "PP-DocLayoutV2",
        ]
        if any(name in self.model_name for name in models_required_imgsize):
            ordered_required_keys = (
                "img_size",
                "img",
                "scale_factors",
            )
        else:
            ordered_required_keys = ("img", "scale_factors")

        return ToBatch(ordered_required_keys=ordered_required_keys)

    def build_postprocess(self):
        """Fill unset options from the model config and build ``DetPostProcess``.

        Only options that are still None are taken from the config, so values
        passed explicitly to ``__init__`` (including ``layout_nms=False``) win.
        """
        if self.threshold is None:
            self.threshold = self.config.get("draw_threshold", 0.5)
        # ``is None`` (not falsiness) so an explicit ``layout_nms=False`` is
        # not overridden by the config; matches DetTransformersPredictor.
        if self.layout_nms is None:
            self.layout_nms = self.config.get("layout_nms", None)
        if self.layout_unclip_ratio is None:
            self.layout_unclip_ratio = self.config.get("layout_unclip_ratio", None)
        if self.layout_merge_bboxes_mode is None:
            self.layout_merge_bboxes_mode = self.config.get(
                "layout_merge_bboxes_mode", None
            )
        return DetPostProcess(labels=self.config["label_list"])


class DetTransformersPredictor(TransformersPredictor):
    """Object detection predictor backed by HuggingFace transformers."""

    def __init__(
        self,
        *args,
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
        **kwargs,
    ):
        """Initialize the predictor and load the HF image processor and model.

        Args:
            *args: Positional arguments passed to the superclass.
            threshold: Score threshold, either one value or a dict mapping class
                id to threshold. When None, the model config's ``draw_threshold``
                (or 0.5) is used.
            layout_nms: Whether to use layout-aware NMS. When None, taken from
                the model config's ``layout_nms``.
            layout_unclip_ratio: Box unclip ratio. When None, taken from the
                model config's ``layout_unclip_ratio``.
            layout_merge_bboxes_mode: Box merge mode. When None, taken from the
                model config's ``layout_merge_bboxes_mode``.
            **kwargs: Keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        self.layout_nms = layout_nms
        self.layout_unclip_ratio = layout_unclip_ratio
        self.layout_merge_bboxes_mode = layout_merge_bboxes_mode
        self.read_op = ReadImage(format="RGB")
        # ``_build`` also fills still-unset options above from the model config.
        self.image_processor, self.infer, self.labels = self._build()
        self.layout_postprocess = DetPostProcess(labels=self.labels)

    def _build_batch_sampler(self):
        """Return the batch sampler that feeds images to this predictor."""
        return ImageBatchSampler()

    def _get_result_class(self):
        """Return the result class used to wrap per-image outputs."""
        return DetResult

    def _build(self):
        """Load the image processor and detection model, and resolve labels.

        Returns:
            tuple: ``(image_processor, model, labels)``.
        """
        from transformers import AutoImageProcessor, AutoModelForObjectDetection

        image_processor = self._load_pretrained_processor(AutoImageProcessor)
        model = self._load_pretrained_model(AutoModelForObjectDetection)
        # Exposed so ``_resolve_labels`` can read ``model.config.id2label``
        # before ``self.infer`` is assigned.
        self._label_source_model = model
        return image_processor, model, self._resolve_labels()

    def _format_transformers_output(self, prediction: Dict[str, Any]) -> np.ndarray:
        """Convert one HF detection result to an (N, 6) float32 array.

        Each row is ``[class_id, score, xmin, ymin, xmax, ymax]``; an empty
        prediction yields a ``(0, 6)`` array.
        """
        boxes = prediction["boxes"].detach().cpu().numpy()
        scores = prediction["scores"].detach().cpu().numpy()
        labels = prediction["labels"].detach().cpu().numpy()
        if len(boxes) == 0:
            return np.empty((0, 6), dtype=np.float32)
        return np.concatenate(
            [
                labels[:, None].astype(np.float32, copy=False),
                scores[:, None].astype(np.float32, copy=False),
                boxes.astype(np.float32, copy=False),
            ],
            axis=1,
        )

    def _get_target_sizes(self, datas: List[dict]):
        """Return per-image target sizes for HF post-processing.

        ``ori_img_size`` is reversed, i.e. treated here as (width, height) and
        passed to HF as (height, width).
        """
        import torch

        return torch.tensor(
            [data["ori_img_size"][::-1] for data in datas], dtype=torch.int64
        )

    def _apply_category_threshold(
        self, boxes: np.ndarray, threshold: Optional[Union[float, dict]]
    ) -> np.ndarray:
        """Filter rows by a per-class threshold when ``threshold`` is a dict.

        Classes missing from the dict use 0.5. Non-dict thresholds (already
        applied by HF post-processing) leave ``boxes`` untouched.
        """
        if boxes.size == 0 or not isinstance(threshold, dict):
            return boxes
        selected = [box for box in boxes if box[1] > threshold.get(int(box[0]), 0.5)]
        if not selected:
            return np.empty((0, 6), dtype=np.float32)
        return np.asarray(selected, dtype=np.float32)

    def _to_paddlex_boxes(
        self, boxes: np.ndarray, img_size: Tuple[int, int]
    ) -> List[dict]:
        """Convert an (N, 6) array to PaddleX box dicts.

        Coordinates are clipped to the image (``img_size`` is (width, height))
        and boxes that become empty after clipping are dropped.
        """
        if boxes.size == 0:
            return []
        width, height = img_size
        results = []
        for box in boxes:
            cls_id = int(box[0])
            xmin, ymin, xmax, ymax = box[2:]
            xmin = max(0.0, min(float(xmin), float(width)))
            ymin = max(0.0, min(float(ymin), float(height)))
            xmax = max(0.0, min(float(xmax), float(width)))
            ymax = max(0.0, min(float(ymax), float(height)))
            if xmax <= xmin or ymax <= ymin:
                continue
            label = (
                self.labels[cls_id] if 0 <= cls_id < len(self.labels) else str(cls_id)
            )
            results.append(
                {
                    "cls_id": cls_id,
                    "label": label,
                    "score": float(box[1]),
                    "coordinate": [xmin, ymin, xmax, ymax],
                }
            )
        return results

    def _get_hf_threshold(
        self, threshold: Optional[Union[float, dict]]
    ) -> Tuple[Union[float, dict], float]:
        """Return ``(effective_threshold, hf_threshold)``.

        A dict threshold cannot be passed to HF, so HF gets 0.0 and the dict is
        applied afterwards by ``_apply_category_threshold``.
        """
        effective_threshold = threshold if threshold is not None else self.threshold
        if effective_threshold is None:
            effective_threshold = 0.5
        if isinstance(effective_threshold, dict):
            return effective_threshold, 0.0
        return effective_threshold, float(effective_threshold)

    def _get_layout_postprocess_kwargs(
        self,
        layout_nms: bool,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]],
        layout_merge_bboxes_mode: Optional[Union[str, dict]],
    ) -> Dict[str, Any]:
        """Merge per-call layout options with the instance defaults (via ``or``)."""
        return {
            "layout_nms": layout_nms or self.layout_nms,
            "layout_unclip_ratio": layout_unclip_ratio or self.layout_unclip_ratio,
            "layout_merge_bboxes_mode": layout_merge_bboxes_mode
            or self.layout_merge_bboxes_mode,
        }

    def _requires_layout_postprocess(
        self, layout_postprocess_kwargs: Dict[str, Any]
    ) -> bool:
        """Return True when any layout post-processing option is enabled."""
        return any(layout_postprocess_kwargs.values())

    def _postprocess_prediction(
        self,
        prediction: Dict[str, Any],
        data: Dict[str, Any],
        effective_threshold: Union[float, dict],
        layout_postprocess_kwargs: Dict[str, Any],
    ) -> List[dict]:
        """Turn one HF prediction into PaddleX boxes for one image."""
        formatted = self._format_transformers_output(prediction)
        formatted = self._apply_category_threshold(formatted, effective_threshold)
        if self._requires_layout_postprocess(layout_postprocess_kwargs):
            # Scores were already filtered above, hence threshold 0.0 here.
            return self.layout_postprocess.apply(
                formatted,
                data["ori_img_size"],
                0.0,
                **layout_postprocess_kwargs,
            )
        return self._to_paddlex_boxes(formatted, data["ori_img_size"])

    def process(
        self,
        batch_data: List[Any],
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: bool = False,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
    ):
        """Run read -> HF preprocess -> forward -> postprocess on a batch.

        Args:
            batch_data: A batch object exposing ``instances``, ``input_paths``
                and ``page_indexes``.
            threshold: Per-call score threshold (float or per-class dict).
            layout_nms: Per-call layout NMS flag (OR-ed with the default).
            layout_unclip_ratio: Per-call unclip ratio.
            layout_merge_bboxes_mode: Per-call merge mode.

        Returns:
            dict: Keys 'input_path', 'page_index', 'input_img', 'boxes'.

        Raises:
            RuntimeError: If the image processor lacks
                ``post_process_object_detection``.
        """
        if not hasattr(self.image_processor, "post_process_object_detection"):
            raise RuntimeError(
                f"{type(self.image_processor).__name__} does not support "
                "`post_process_object_detection`."
            )

        datas = self.read_op(batch_data.instances)
        images = [Image.fromarray(data["img"]) for data in datas]
        effective_threshold, hf_threshold = self._get_hf_threshold(threshold)

        model_inputs = self.preprocess_images(images=images)
        outputs = self.forward(model_inputs)
        predictions = self.postprocess(outputs, datas=datas, threshold=hf_threshold)

        layout_postprocess_kwargs = self._get_layout_postprocess_kwargs(
            layout_nms=layout_nms,
            layout_unclip_ratio=layout_unclip_ratio,
            layout_merge_bboxes_mode=layout_merge_bboxes_mode,
        )
        boxes = [
            self._postprocess_prediction(
                prediction=prediction,
                data=data,
                effective_threshold=effective_threshold,
                layout_postprocess_kwargs=layout_postprocess_kwargs,
            )
            for data, prediction in zip(datas, predictions)
        ]

        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": [data["ori_img"] for data in datas],
            "boxes": boxes,
        }

    def postprocess(self, outputs, *, datas, threshold, **kwargs):
        """Run HF ``post_process_object_detection`` with original image sizes."""
        predictions = self.image_processor.post_process_object_detection(
            outputs,
            threshold=threshold,
            target_sizes=self._get_target_sizes(datas),
        )

        return predictions

    def _resolve_labels(self):
        """Fill unset options from the model config and resolve label names.

        Labels come from the model config's ``label_list``, falling back to the
        HF model's ``config.id2label``.

        Returns:
            list: Label names indexed by class id.

        Raises:
            ValueError: If no label names can be found.
        """
        if self.threshold is None:
            self.threshold = self.model_config.get("draw_threshold", 0.5)
        if self.layout_nms is None:
            self.layout_nms = self.model_config.get("layout_nms", None)
        if self.layout_unclip_ratio is None:
            self.layout_unclip_ratio = self.model_config.get(
                "layout_unclip_ratio", None
            )
        if self.layout_merge_bboxes_mode is None:
            self.layout_merge_bboxes_mode = self.model_config.get(
                "layout_merge_bboxes_mode", None
            )

        labels = self.model_config.get("label_list")
        if not labels:
            # Prefer the freshly loaded model: ``self.infer`` is not assigned
            # yet during ``_build``, and any inherited ``infer`` attribute would
            # not carry ``config.id2label``. Use ``is None`` rather than ``or``
            # to avoid evaluating the truthiness of a model object.
            label_source = getattr(self, "_label_source_model", None)
            if label_source is None:
                label_source = getattr(self, "infer", None)
            id2label = getattr(getattr(label_source, "config", None), "id2label", None)
            if id2label:
                # Normalize keys to int so ordering is numeric, and fill gaps so
                # that list index == class id (``_to_paddlex_boxes`` indexes by id).
                id2label = {int(idx): name for idx, name in id2label.items()}
                labels = [
                    id2label.get(idx, str(idx)) for idx in range(max(id2label) + 1)
                ]
        if not labels:
            raise ValueError(
                "Unable to resolve label names for object detection model."
            )
        return labels
