# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import os
import platform
from collections import defaultdict

import numpy as np

from .....utils.deps import function_requires_deps, is_dep_available
from .....utils.file_interface import custom_open
from .....utils.fonts import PINGFANG_FONT
from .....utils.logging import warning

if is_dep_available("opencv-contrib-python"):
    import cv2
if is_dep_available("matplotlib"):
    import matplotlib.pyplot as plt
    from matplotlib import font_manager
    from matplotlib.backends.backend_agg import FigureCanvasAgg


def simple_analyse(dataset_path, images_dict):
    """
    Count the samples of each split and collect the image paths of each split.

    Args:
        dataset_path (str): dataset directory expected to contain ``train.txt`` and
            ``val.txt`` (required) and optionally ``test.txt``; each line of these files
            is counted as one sample.
        images_dict (dict): mapping from split name ("train", "val", "test") to that
            split's image paths; a key is only read when the split's ``.txt`` file exists.

    Returns:
        On success, the tuple
        ``(message, train_cnt, val_cnt, test_cnt, train_imgs, val_imgs, test_imgs)``;
        a missing ``test.txt`` yields a count of 0 and an empty list for the test split.
        If ``train.txt`` or ``val.txt`` is missing, a 7-element list whose first item is
        an error message and whose remaining six items are ``None``.
    """
    tags = ("train", "val", "test")
    sample_cnts = defaultdict(int)
    img_paths = defaultdict(list)

    for tag in tags:
        file_list = os.path.join(dataset_path, f"{tag}.txt")
        if not os.path.exists(file_list):
            if tag in ("train", "val"):
                # Train and val are mandatory. The failure value is kept as a list
                # (not a tuple) so existing callers see the same type as before.
                return ["数据集不符合规范，请先通过数据校准"] + [None] * 6
            # The test split is optional.
            continue

        with custom_open(file_list, "r") as f:
            # Each line corresponds to a sample.
            sample_cnts[tag] = len(f.readlines())
        img_paths[tag] = images_dict[tag]

    return (
        "完成数据分析",
        *(sample_cnts[tag] for tag in tags),
        *(img_paths[tag] for tag in tags),
    )


def _read_label_lengths(label_file):
    """Return the label length of every well-formed line in ``label_file``.

    Each line is stripped and split on tabs. A line that does not yield exactly two
    fields (image path, label) is reported via ``warning`` and skipped.

    Args:
        label_file (str): path of an annotation file such as ``train.txt``.

    Returns:
        list[int]: ``len(label)`` for each accepted line, in file order.
    """
    with custom_open(label_file, "r") as f:
        lines = f.readlines()
    lengths = []
    for line in lines:
        fields = line.strip().split("\t")
        if len(fields) != 2:
            warning(f"Error in {fields}.")
            continue
        lengths.append(len(fields[1]))
    return lengths


def _bin_label_lengths(lengths, interval, cap):
    """Count ``lengths`` into consecutive ``(lo, hi]`` bins of width ``interval``.

    The bins cover ``(0, upper]``, where ``upper`` is ``min(cap, max(lengths))`` rounded
    up to a multiple of ``interval``. Every length up to the cap therefore lands in
    exactly one bin. Lengths above ``upper`` are counted in a trailing ``"> {upper}"``
    entry, which is only added when it is non-empty. Lengths of 0 are not counted.
    Empty-count bins are kept so that bar positions line up across splits.

    Returns:
        dict: bin label -> sample count, in ascending bin order.
    """
    max_length = min(cap, max(lengths, default=0))
    num_bins = math.ceil(max_length / interval)
    upper = num_bins * interval
    counts = [0] * num_bins
    overflow = 0
    for length in lengths:
        if length > upper:
            overflow += 1
        elif length > 0:
            # (lo, hi] bins: length 1..interval -> bin 0, and so on.
            counts[(length - 1) // interval] += 1
    labels_cnt = {
        f"{idx * interval}-{(idx + 1) * interval}": cnt
        for idx, cnt in enumerate(counts)
    }
    if overflow:
        labels_cnt[f"> {upper}"] = overflow
    return labels_cnt


@function_requires_deps("matplotlib", "opencv-contrib-python")
def deep_analyse(dataset_path, output, datatype="MSTextRecDataset"):
    """Plot a histogram of label lengths for the train and val splits.

    Reads ``train.txt`` and ``val.txt`` under ``dataset_path`` (one
    ``<image>\\t<label>`` per line), bins each split's label lengths independently,
    draws both splits as side-by-side bars, and writes ``histogram.png`` into
    ``output``.

    Args:
        dataset_path (str): directory containing ``train.txt`` and ``val.txt``.
        output (str): directory the ``histogram.png`` image is written to.
        datatype (str): ``"LaTeXOCRDataset"`` selects formula-length bins (width 20,
            cap 768) and a larger figure; any other value selects text-length bins
            (width 5, cap 100).

    Returns:
        dict: ``{"histogram": "check_dataset/histogram.png"}``. The returned path is
        hard-coded and does not depend on ``output``.
    """
    if datatype == "LaTeXOCRDataset":
        cap, interval = 768, 20
        figsize = (15, 9)
        xlabel_name = "公式长度区间"
    else:
        cap, interval = 100, 5
        figsize = (10, 5)
        xlabel_name = "文本字长度区间"

    # Bin each split with its own dict so bins from one split cannot leak into another.
    split_counts = {}
    for tag in ("train", "val"):
        lengths = _read_label_lengths(os.path.join(dataset_path, f"{tag}.txt"))
        split_counts[tag] = _bin_label_lengths(lengths, interval, cap)

    cnts_train = list(split_counts["train"].values())
    cnts_val = list(split_counts["val"].values())
    x_train = np.arange(len(cnts_train))
    x_val = np.arange(len(cnts_val))
    # Tick labels come from the split with more bins; train wins ties (max keeps the first).
    classes_max = list(max(split_counts.values(), key=len))
    x_max = np.arange(len(classes_max))

    bar_width = 0.3
    os_system = platform.system().lower()
    if os_system == "windows":
        plt.rcParams["font.sans-serif"] = "FangSong"
        font = None
    else:
        font = font_manager.FontProperties(fname=PINGFANG_FONT.path, size=15)

    fig, ax = plt.subplots(figsize=figsize, dpi=120)
    try:
        ax.bar(x_train, cnts_train, width=bar_width, label="train")
        ax.bar(x_val + bar_width, cnts_val, width=bar_width, label="val")
        ax.set_xticks(x_max + bar_width / 2)
        ax.set_xticklabels(classes_max, rotation=90)
        ax.legend(prop={"size": 18})
        ax.set_xlabel(xlabel_name, fontproperties=font, fontsize=12)
        ax.set_ylabel("图片数量", fontproperties=font, fontsize=12)

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        px_width, px_height = fig.get_size_inches() * fig.get_dpi()
        rgba = np.asarray(canvas.buffer_rgba()).reshape(
            int(px_height), int(px_width), 4
        )
        # Matplotlib renders RGBA, but cv2.imwrite expects BGR channel order.
        # Reverse the RGB channels, drop alpha, and copy out of the canvas buffer.
        bgr = np.ascontiguousarray(rgba[:, :, 2::-1])
    finally:
        # Release the pyplot-managed figure so repeated calls do not accumulate figures.
        plt.close(fig)

    fig1_path = os.path.join(output, "histogram.png")
    cv2.imwrite(fig1_path, bgr)

    return {"histogram": os.path.join("check_dataset", "histogram.png")}
