o
    "j$V                     @   s   d dl Z d dlmZ d dlZddlmZ G dd dZG dd deZG d	d
 d
eZdd Z	G dd deZ
G dd deZG dd deZdddZdd fddZG dd deZdS )    N)Iterable   )	frameworkc                   @   (   e Zd ZdZdd Zdd Zdd ZdS )	Dataseta  
    An abstract class to encapsulate methods and behaviors of datasets.

    All datasets in map-style(dataset samples can be get by a given key)
    should be a subclass of `paddle.io.Dataset`. All subclasses should
    implement following methods:

    :code:`__getitem__`: get sample from dataset with a given index. This
    method is required by reading dataset sample in :code:`paddle.io.DataLoader`.

    :code:`__len__`: return dataset sample number. This method is required
    by some implements of :code:`paddle.io.BatchSampler`

    see :code:`paddle.io.DataLoader`.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> from paddle.io import Dataset

            >>> # define a random dataset
            >>> class RandomDataset(Dataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([784]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = RandomDataset(10)
            >>> for i in range(len(dataset)):
            ...     image, label = dataset[i]
            ...     # do something
    c                 C      d S N selfr	   r	   ]/var/www/html/Deteccion_Ine/venv/lib/python3.10/site-packages/paddle/io/dataloader/dataset.py__init__A      zDataset.__init__c                 C      t dd| jj)N'{}' not implement in class {}__getitem__NotImplementedErrorformat	__class____name__r   idxr	   r	   r   r   D      zDataset.__getitem__c                 C   r   )Nr   __len__r   r
   r	   r	   r   r   J   r   zDataset.__len__Nr   
__module____qualname____doc__r   r   r   r	   r	   r	   r   r      s
    )r   c                   @   s0   e Zd ZdZdd Zdd Zdd Zdd	 Zd
S )IterableDataseta  
    An abstract class to encapsulate methods and behaviors of iterable datasets.

    All datasets in iterable-style (can only get sample one by one sequentially, like
    a Python iterator) should be a subclass of :ref:`api_paddle_io_IterableDataset` . All subclasses should
    implement following methods:

    :code:`__iter__`: yield sample sequentially. This method is required by reading dataset sample in :ref:`api_paddle_io_DataLoader` .

    .. note::
        do not implement :code:`__getitem__` and :code:`__len__` in IterableDataset, should not be called either.

    see :ref:`api_paddle_io_DataLoader` .

    Examples:

        .. code-block:: python
            :name: code-example1

            >>> import numpy as np
            >>> from paddle.io import IterableDataset

            >>> # define a random dataset
            >>> class RandomDataset(IterableDataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __iter__(self):
            ...         for i in range(self.num_samples):
            ...             image = np.random.random([784]).astype('float32')
            ...             label = np.random.randint(0, 9, (1, )).astype('int64')
            ...             yield image, label
            ...
            >>> dataset = RandomDataset(10)
            >>> for img, label in dataset:
            ...     # do something
            ...     ...

    When :attr:`num_workers > 0`, each worker has a different copy of the dataset object and
    will yield whole dataset samples, which means samples in dataset will be repeated in
    :attr:`num_workers` times. If it is required for each sample to yield only once, there
    are two methods to configure different copy in each worker process to avoid duplicate data
    among workers as follows. In both the methods, worker information that can be getted in
    a worker process by `paddle.io.get_worker_info` will be needed.

    splitting data copy in each worker in :code:`__iter__`

        .. code-block:: python
            :name: code-example2

            >>> import math
            >>> import paddle
            >>> import numpy as np
            >>> from paddle.io import IterableDataset, DataLoader, get_worker_info

            >>> class SplitedIterableDataset(IterableDataset):
            ...     def __init__(self, start, end):
            ...         self.start = start
            ...         self.end = end
            ...
            ...     def __iter__(self):
            ...         worker_info = get_worker_info()
            ...         if worker_info is None:
            ...             iter_start = self.start
            ...             iter_end = self.end
            ...         else:
            ...             per_worker = int(
            ...                 math.ceil((self.end - self.start) / float(
            ...                     worker_info.num_workers)))
            ...             worker_id = worker_info.id
            ...             iter_start = self.start + worker_id * per_worker
            ...             iter_end = min(iter_start + per_worker, self.end)
            ...
            ...         for i in range(iter_start, iter_end):
            ...             yield np.array([i])
            ...
            >>> dataset = SplitedIterableDataset(start=2, end=9)
            >>> dataloader = DataLoader(
            ...     dataset,
            ...     num_workers=2,
            ...     batch_size=1,
            ...     drop_last=True)
            ...
            >>> for data in dataloader:
            ...     print(data) # doctest: +SKIP("The output depends on the environment.")
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[2]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[3]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[4]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[5]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[6]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[7]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[8]])

    splitting data copy in each worker by :code:`worker_init_fn`

        .. code-block:: python
            :name: code-example3

            >>> import math
            >>> import paddle
            >>> import numpy as np
            >>> from paddle.io import IterableDataset, DataLoader, get_worker_info

            >>> class RangeIterableDataset(IterableDataset):
            ...     def __init__(self, start, end):
            ...         self.start = start
            ...         self.end = end
            ...
            ...     def __iter__(self):
            ...         for i in range(self.start, self.end):
            ...             yield np.array([i])
            ...
            >>> dataset = RangeIterableDataset(start=2, end=9)

            >>> def worker_init_fn(worker_id):
            ...     worker_info = get_worker_info()
            ...
            ...     dataset = worker_info.dataset
            ...     start = dataset.start
            ...     end = dataset.end
            ...     num_per_worker = int(
            ...         math.ceil((end - start) / float(worker_info.num_workers)))
            ...
            ...     worker_id = worker_info.id
            ...     dataset.start = start + worker_id * num_per_worker
            ...     dataset.end = min(dataset.start + num_per_worker, end)
            ...
            >>> dataloader = DataLoader(
            ...     dataset,
            ...     num_workers=2,
            ...     batch_size=1,
            ...     drop_last=True,
            ...     worker_init_fn=worker_init_fn)
            ...
            >>> for data in dataloader:
            ...     print(data) # doctest: +SKIP("The output depends on the environment.")
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[2]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[3]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[4]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[5]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[6]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[7]])
            Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
                [[8]])

    c                 C   r   r   r	   r
   r	   r	   r   r      r   zIterableDataset.__init__c                 C   r   )Nr   __iter__r   r
   r	   r	   r   r       r   zIterableDataset.__iter__c                 C   r   )N/'{}' should not be called for IterableDataset{}r   RuntimeErrorr   r   r   r   r	   r	   r   r      r   zIterableDataset.__getitem__c                 C   r   )Nr!   r   r"   r
   r	   r	   r   r     r   zIterableDataset.__len__N)r   r   r   r   r   r    r   r   r	   r	   r	   r   r   Q   s     !r   c                   @   r   )	TensorDataseta  
    Dataset defined by a list of tensors.

    Each tensor should be in shape of [N, ...], while N is the sample number,
    and ecah tensor contains a field of sample, :code:`TensorDataset` retrieve
    each sample by indexing tensors in the 1st dimension.

    Args:
        tensors(list|tuple): A list/tuple of tensors with same shape in the 1st dimension.

    Returns:
        Dataset: a Dataset instance wrapping tensors.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> from paddle.io import TensorDataset


            >>> input_np = np.random.random([2, 3, 4]).astype('float32')
            >>> input = paddle.to_tensor(input_np)
            >>> label_np = np.random.random([2, 1]).astype('int32')
            >>> label = paddle.to_tensor(label_np)

            >>> dataset = TensorDataset([input, label])

            >>> for i in range(len(dataset)):
            ...     input, label = dataset[i]
            ...     # do something
    c                    s8   t  stdt fdd D sJ d | _d S )Nz1TensorDataset con only be used in imperative modec                 3   s(    | ]}|j d   d  j d  kV  qdS )r   N)shape.0Ztensortensorsr	   r   	<genexpr>0  s    
z)TensorDataset.__init__.<locals>.<genexpr>z0tensors not have same shape of the 1st dimension)r   Zin_dynamic_moder#   allr)   )r   r)   r	   r(   r   r   +  s   

zTensorDataset.__init__c                    s   t  fdd| jD S )Nc                 3   s    | ]}|  V  qd S r   r	   r&   indexr	   r   r*   6  s    z,TensorDataset.__getitem__.<locals>.<genexpr>)tupler)   )r   r-   r	   r,   r   r   5  s   zTensorDataset.__getitem__c                 C   s   | j d jd S Nr   )r)   r%   r
   r	   r	   r   r   8     zTensorDataset.__len__Nr   r	   r	   r	   r   r$     s
    "
r$   c                 C   s(   | d u r| S t | ttfrt| S | gS r   )
isinstancelistr.   )valuer	   r	   r   to_list<  s
   r4   c                   @   r   )	ComposeDataseta  
    A Dataset which composes fields of multiple datasets.

    This dataset is used for composing fileds of multiple map-style
    datasets of same length.

    Args:
        datasets(list of Dataset): List of datasets to be composed.

    Returns:
        Dataset: A Dataset which composes fields of multiple datasets.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> from paddle.io import Dataset, ComposeDataset


            >>> # define a random dataset
            >>> class RandomDataset(Dataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([32]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = ComposeDataset([RandomDataset(10), RandomDataset(10)])
            >>> for i in range(len(dataset)):
            ...     image1, label1, image2, label2 = dataset[i]
            ...     # do something
    c                 C   s   t || _t| jdksJ dt| jD ]+\}}t|ts"J dt|tr+J d|dkr@t|t| j|d  ks@J dqd S )Nr   "input datasets shoule not be emptyz.each input dataset should be paddle.io.Datasetz'paddle.io.IterableDataset not supported   z"lengths of datasets should be same)r2   datasetslen	enumerater1   r   r   r   r8   idatasetr	   r	   r   r   m  s*   
zComposeDataset.__init__c                 C   s   t | jd S r/   )r9   r8   r
   r	   r	   r   r   |  s   zComposeDataset.__len__c                 C   s*   g }| j D ]}|t||  qt|S r   )r8   extendr4   r.   )r   r   sampler=   r	   r	   r   r     s   
zComposeDataset.__getitem__N)r   r   r   r   r   r   r   r	   r	   r	   r   r5   D  s
    (r5   c                   @   s    e Zd ZdZdd Zdd ZdS )ChainDataseta  
    A Dataset which chains multiple iterable-style datasets.

    This dataset is used for assembling multiple datasets which should
    be :ref:`api_paddle_io_IterableDataset`.

    Args:
        datasets(list of IterableDatasets): List of datasets to be chainned.

    Returns:
        paddle.io.IterableDataset: A Dataset which chains fields of multiple datasets.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> from paddle.io import IterableDataset, ChainDataset


            >>> # define a random dataset
            >>> class RandomDataset(IterableDataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __iter__(self):
            ...         for i in range(10):
            ...             image = np.random.random([32]).astype('float32')
            ...             label = np.random.randint(0, 9, (1, )).astype('int64')
            ...             yield image, label
            ...
            >>> dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
            >>> for image, label in iter(dataset):
            ...     # do something
            ...     ...

    c                 C   sJ   t || _t| jdksJ dt| jD ]\}}t|ts"J dqd S )Nr   r6   z3ChainDataset only support paddle.io.IterableDataset)r2   r8   r9   r:   r1   r   r;   r	   r	   r   r     s   
zChainDataset.__init__c                 c   s    | j D ]}|E d H  qd S r   )r8   )r   r=   r	   r	   r   r      s   
zChainDataset.__iter__N)r   r   r   r   r   r    r	   r	   r	   r   r@     s    'r@   c                   @   r   )	Subseta  
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset.
        indices (sequence): Indices in the whole set selected for subset.

    Returns:
        List[Dataset]: A Dataset which is the subset of the original dataset.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> from paddle.io import Subset

            >>> # example 1:
            >>> a = paddle.io.Subset(dataset=range(1, 4), indices=[0, 2])
            >>> print(list(a))
            [1, 3]

            >>> # example 2:
            >>> b = paddle.io.Subset(dataset=range(1, 4), indices=[1, 1])
            >>> print(list(b))
            [2, 2]
    c                 C   s   || _ || _d S r   r=   indices)r   r=   rC   r	   r	   r   r     s   
zSubset.__init__c                 C   s   | j | j|  S r   rB   r   r	   r	   r   r     r0   zSubset.__getitem__c                 C   s
   t | jS r   )r9   rC   r
   r	   r	   r   r        
zSubset.__len__Nr   r	   r	   r	   r   rA     s
    rA   c                    sH   t |t krtdtt |  fddtt||D S )a$  
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator, optional): Generator used for the random permutation. Default is None then the DefaultGenerator is used in manual_seed().

    Returns:
        Datasets: A list of subset Datasets, which are the non-overlapping subsets of the original Dataset.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> paddle.seed(2023)
            >>> a_list = paddle.io.random_split(range(10), [3, 7])
            >>> print(len(a_list))
            2

            >>> # output of the first subset
            >>> for idx, v in enumerate(a_list[0]):
            ...     print(idx, v) # doctest: +SKIP("The output depends on the environment.")
            0 7
            1 6
            2 5

            >>> # output of the second subset
            >>> for idx, v in enumerate(a_list[1]):
            ...     print(idx, v) # doctest: +SKIP("The output depends on the environment.")
            0 1
            1 9
            2 4
            3 2
            4 0
            5 3
            6 8
    zDSum of input lengths does not equal the length of the input dataset!c                    s&   g | ]\}}t  || | qS r	   )rA   )r'   offsetlengthrB   r	   r   
<listcomp>  s    z random_split.<locals>.<listcomp>)sumr9   
ValueErrorpaddleZrandpermtolistzip_accumulate)r=   lengths	generatorr	   rB   r   random_split  s   +rP   c                 C   s   | | S r   r	   )xyr	   r	   r   <lambda>  s    rS   c                 c   sP    t | }zt|}W n
 ty   Y dS w |V  |D ]
}|||}|V  qdS )a  
    Return running totals

    Args:
        iterable: any iterable object for example dataset.
        y (x): one element in the iterable object.
        fn (x, y): Defaults to lambdax.

    Yields:
        yields total from beginning iterator to current iterator.

    Example code:

        .. code-block:: python

            >>> list(_accumulate([1, 2, 3, 4, 5]))
            [1, 3, 6, 10, 15]

            >>> import operator
            >>> list(_accumulate([1, 2, 3, 4, 5], operator.mul))
            [1, 2, 6, 24, 120]
    N)iternextStopIteration)iterablefnittotalelementr	   r	   r   rM     s   
rM   c                   @   s>   e Zd ZdZedd Zdee fddZdd Z	d	d
 Z
dS )ConcatDataseta  
    Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated

    Returns:
        Dataset: A Dataset which concatenated by multiple datasets.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> from paddle.io import Dataset, ConcatDataset


            >>> # define a random dataset
            >>> class RandomDataset(Dataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([32]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
            >>> for i in range(len(dataset)):
            ...     image, label = dataset[i]
            ...     # do something
    c                 C   s6   g d}}| D ]}t |}|||  ||7 }q|S r/   )r9   append)sequencerselr	   r	   r   cumsumf  s   

zConcatDataset.cumsumr8   c                 C   sP   t || _t| jdksJ d| jD ]}t|trJ dq| | j| _d S )Nr   z(datasets should not be an empty iterablez.ConcatDataset does not support IterableDataset)r2   r8   r9   r1   r   rc   cumulative_sizes)r   r8   dr	   r	   r   r   o  s   

zConcatDataset.__init__c                 C   s
   | j d S )N)rd   r
   r	   r	   r   r   z  rD   zConcatDataset.__len__c                 C   sf   |dk r| t | krtdt | | }t| j|}|dkr#|}n	|| j|d   }| j| | S )Nr   z8absolute value of index should not exceed dataset lengthr7   )r9   rI   bisectbisect_rightrd   r8   )r   r   Zdataset_idxZ
sample_idxr	   r	   r   r   }  s   zConcatDataset.__getitem__N)r   r   r   r   staticmethodrc   r   r   r   r   r   r	   r	   r	   r   r\   >  s    '
r\   r   )rg   typingr   rJ    r   r   r   r$   r4   r5   r@   rA   rP   rM   r\   r	   r	   r	   r   <module>   s   : 84B5
(8#