# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
from collections import defaultdict

import numpy as np
from PIL import Image, ImageOps

from .....utils.deps import function_requires_deps, is_dep_available
from .....utils.file_interface import custom_open

if is_dep_available("opencv-contrib-python"):
    import cv2
if is_dep_available("matplotlib"):
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_agg import FigureCanvasAgg


# show data samples
def _parse_det_label(label_str):
    """Parse one serialized text-detection annotation.

    Args:
        label_str (str): JSON text expected to decode to a list of objects,
            each with a ``"points"`` entry (a list of [x, y] pairs) and a
            string ``"transcription"``.

    Returns:
        list | None: one ``np.ndarray`` of points per object, or ``None`` if
        the annotation does not follow that schema.
    """
    try:
        label = json.loads(label_str)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError: malformed JSON is an
        # invalid annotation, not a crash.
        return None
    if not isinstance(label, list):
        return None

    boxes = []
    for item in label:
        if not isinstance(item, dict):
            return None
        if "points" not in item or "transcription" not in item:
            return None
        try:
            box = np.array(item["points"])
        except ValueError:
            # Ragged point lists cannot form a rectangular array.
            return None
        # Check ndim first: empty or 1-D "points" have no shape[1].
        if box.ndim != 2 or box.shape[1] != 2:
            return None
        if not isinstance(item["transcription"], str):
            return None
        boxes.append(box)
    return boxes


def simple_analyse(dataset_path, max_recorded_sample_cnts=20, show_label=True):
    """
    Analyse the dataset samples, returning no more than
    max_recorded_sample_cnts image paths and labels per split.

    Reads ``train.txt``, ``val.txt`` and (optionally) ``test.txt`` from
    ``dataset_path``. Each non-empty line holds a tab-separated image path and
    JSON label; in ``test.txt`` the label column may be omitted.

    Args:
        dataset_path (str): dataset root directory.
        max_recorded_sample_cnts (int, optional): maximum number of samples per
            split whose paths and labels are returned. Default: 20.
        show_label (bool, optional): if True, return images with the boxes
            drawn on them (via ``show_label_img``); otherwise return
            ``{"img_path": ..., "box": [...]}`` dicts. Default: True.

    Returns:
        tuple: on success, ``("完成数据分析", n_train, n_val, n_test,
        train_imgs, val_imgs, test_imgs, train_labels, val_labels,
        test_labels)``.
        list: on failure, an error message followed by nine ``None`` values
        (a required split file is missing or an annotation is malformed).
    """
    tags = ["train", "val", "test"]
    sample_cnts = defaultdict(int)
    img_paths = defaultdict(list)
    lab_paths = defaultdict(list)
    lab_infos = defaultdict(list)
    res = [None] * 9
    delim = "\t"

    for tag in tags:
        file_list = os.path.join(dataset_path, f"{tag}.txt")
        if not os.path.exists(file_list):
            if tag in ("train", "val"):
                res.insert(0, "数据集不符合规范，请先通过数据校准")
                return res
            # test.txt is optional.
            continue

        with custom_open(file_list, "r") as f:
            all_lines = f.readlines()

        # Each line corresponds to a sample (blank lines are counted too).
        sample_cnts[tag] = len(all_lines)
        # Train/val samples must be labeled; test samples may omit the label.
        valid_num_parts = (2,) if tag in ("train", "val") else (1, 2)

        for line in all_lines:
            content = line.strip("\n")
            if len(content) < 1:
                continue
            parts = content.split(delim)
            if len(parts) not in valid_num_parts and len(content) > 1:
                res.insert(0, "数据集的标注文件不符合规范")
                return res

            if len(parts) == 2:
                img_path, lab_path = parts
            else:
                img_path, lab_path = parts[0], None

            # Only the first max_recorded_sample_cnts samples are parsed and
            # recorded; the remaining lines are still format-checked above.
            if len(img_paths[tag]) >= max_recorded_sample_cnts:
                continue

            img_path = os.path.join(dataset_path, img_path)
            # Reset per sample so an unlabeled line never inherits the
            # previous sample's boxes.
            boxes = []
            if lab_path is not None:
                boxes = _parse_det_label(lab_path)
                if boxes is None:
                    res.insert(0, "数据集的标注文件不符合规范")
                    return res

            img_paths[tag].append(img_path)
            if show_label:
                # Unlabeled samples are drawn with no boxes.
                lab_paths[tag].append(show_label_img(img_path, boxes))
            else:
                lab_infos[tag].append({"img_path": img_path, "box": boxes})

    labels = lab_paths if show_label else lab_infos
    return (
        "完成数据分析",
        *(sample_cnts[tag] for tag in tags),
        *(img_paths[tag] for tag in tags),
        *(labels[tag] for tag in tags),
    )


@function_requires_deps("opencv-contrib-python")
def show_label_img(img_path, dt_boxes):
    """Draw OCR text-detection polygons onto an image.

    Args:
        img_path (str): path of the image, loaded with ``cv2.imread``.
        dt_boxes (iterable): polygons. Each one is converted to an int32
            array of (x, y) points and outlined as a closed green polyline,
            3 px thick.

    Returns:
        np.ndarray: the annotated image with its channel axis reversed.
        ``cv2.imread`` yields BGR, so the result is RGB.
    """
    canvas = cv2.imread(img_path)
    green = (0, 255, 0)
    for polygon in dt_boxes:
        points = np.asarray(polygon).astype(np.int32).reshape(-1, 2)
        cv2.polylines(canvas, [points], True, color=green, thickness=3)
    rgb = canvas[:, :, ::-1]
    return rgb


def _render_histogram(values, color, xlabel):
    """Render a histogram of ``values`` and return it as an RGB uint8 array.

    One bin is used per 10 units of the value range, with a minimum of one
    bin. This keeps narrow or empty distributions from making ``hist`` fail
    on ``bins=0``.
    """
    if values:
        bins = max(1, int((max(values) - min(values)) // 10))
    else:
        bins = 1
    fig, ax = plt.subplots()
    try:
        ax.hist(values, bins=bins, rwidth=0.8, color=color)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("number")
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # buffer_rgba() is already (height, width, 4). Copy it so the data
        # outlives the figure, then drop the alpha channel.
        rgb = np.array(canvas.buffer_rgba())[:, :, :3]
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    return rgb


@function_requires_deps("matplotlib", "opencv-contrib-python")
def deep_analyse(dataset_path, output):
    """Plot the relative size distribution of text-detection boxes.

    Every sample listed by ``simple_analyse`` (all splits) is loaded. Each
    box's width and height are divided by its image's EXIF-corrected width
    and height and scaled by 1000. The two resulting histograms are saved
    side by side in ``<output>/histogram.png``.

    Args:
        dataset_path (str): dataset root accepted by ``simple_analyse``.
        output (str): directory that receives ``histogram.png``; it is
            created if missing.

    Returns:
        dict: ``{"histogram": "check_dataset/histogram.png"}``.

    Raises:
        ValueError: if ``simple_analyse`` reports the dataset as invalid.
    """
    sample_results = simple_analyse(
        dataset_path, max_recorded_sample_cnts=float("inf"), show_label=False
    )
    if sample_results[0] != "完成数据分析":
        # On failure simple_analyse returns [message, None * 9].
        raise ValueError(sample_results[0])
    lab_infos = sample_results[-3] + sample_results[-2] + sample_results[-1]

    ratios_w = []  # box width / image width * 1000
    ratios_h = []  # box height / image height * 1000
    for info in lab_infos:
        with Image.open(info["img_path"]) as im:
            # .size is (width, height) after applying EXIF orientation.
            img_w, img_h = ImageOps.exif_transpose(im).size
        for box in info["box"]:
            box = np.array(box).astype(np.int32).reshape(-1, 2)
            box_w, box_h = np.max(box, axis=0) - np.min(box, axis=0)
            ratios_w.append(box_w / img_w * 1000)
            ratios_h.append(box_h / img_h * 1000)

    width_hist = _render_histogram(ratios_w, "yellowgreen", "Width rate *1000")
    height_hist = _render_histogram(ratios_h, "pink", "Height rate *1000")

    os.makedirs(output, exist_ok=True)
    fig_path = os.path.join(output, "histogram.png")
    img_array = np.concatenate((width_hist, height_hist), axis=1)
    # Matplotlib renders RGB, but cv2.imwrite expects BGR channel order.
    cv2.imwrite(fig_path, cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
    return {"histogram": os.path.join("check_dataset", "histogram.png")}
