# Source code for diffengine.datasets.samplers.batch_sampler

# based on https://github.com/open-mmlab/mmdetection/blob/f78af7785ada87f1ced75a2313746e4ba3149760/mmdet/datasets/samplers/batch_sampler.py#L12  # noqa
from collections.abc import Generator

import mmengine
import numpy as np
from torch.utils.data import BatchSampler, Sampler

from diffengine.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
    """Aspect ratio batch sampler.

    A sampler wrapper for grouping images with similar aspect ratio into a
    same batch.

    Indices produced by ``sampler`` are grouped by their bucket id
    (``self.bucket_ids[idx]``). A batch is emitted as soon as one bucket has
    collected ``batch_size`` indices, so every batch only contains indices
    that share a single bucket id.

    Args:
    ----
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        bucket_ids (str | None): The path of bucket ids, loaded with
            ``mmengine.load``. If ``None``, the bucket ids will be calculated
            automatically as ``img.size()[1] / img.size()[2]`` of each
            sample's ``inputs["img"]``. Default: ``None``.
        drop_last (bool): If ``True``, indices left in a bucket that do not
            fill a complete batch at the end of a pass are dropped. Otherwise
            every non-empty leftover bucket is yielded as a final, smaller
            batch. Default: ``False``.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: int,
                 bucket_ids: str | None = None,
                 *,
                 drop_last: bool = False) -> None:
        if not isinstance(sampler, Sampler):
            msg = ("sampler should be an instance of ``Sampler``, but "
                   f"got {sampler}")
            raise TypeError(msg)
        # ``bool`` is a subclass of ``int``; reject it explicitly, matching
        # the validation done by ``torch.utils.data.BatchSampler``.
        if (isinstance(batch_size, bool) or not isinstance(batch_size, int)
                or batch_size <= 0):
            msg = ("batch_size should be a positive integer value, but "
                   f"got batch_size={batch_size}")
            raise ValueError(msg)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Maps bucket id -> indices collected for that bucket during the
        # current pass of ``__iter__``.
        self._aspect_ratio_buckets: dict = {}

        if bucket_ids is not None:
            self.bucket_ids = mmengine.load(bucket_ids)
        else:
            # NOTE: this indexes every sample of the dataset
            # (``dataset[idx]``), which may be slow for large datasets;
            # passing a precomputed ``bucket_ids`` file avoids it.
            dataset = self.sampler.dataset
            self.bucket_ids = []
            for idx in range(len(dataset)):
                img_size = dataset[idx]["inputs"]["img"].size()
                # size()[1] / size()[2]; presumably height / width of a
                # (C, H, W) image tensor — TODO confirm.
                self.bucket_ids.append(img_size[1] / img_size[2])

    def __iter__(self) -> Generator:
        """Get the iterator of the sampler.

        Yields:
        ------
            list: A batch of indices from ``self.sampler`` that all share the
            same bucket id.
        """
        # Always start from empty buckets so that indices left over by an
        # earlier pass that was not run to completion (or that ran with
        # ``drop_last=True``) cannot leak into this pass.
        self._aspect_ratio_buckets = {}
        buckets = self._aspect_ratio_buckets
        for idx in self.sampler:
            bucket = buckets.setdefault(self.bucket_ids[idx], [])
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                bucket.clear()

        # yield the rest data (unless dropped) and reset the buckets
        if not self.drop_last:
            for leftover in buckets.values():
                if leftover:
                    yield leftover[:]
        self._aspect_ratio_buckets = {}

    def __len__(self) -> int:
        """Get the length of the sampler.

        The number of batches is computed per distinct bucket id in
        ``self.bucket_ids``: ``count // batch_size`` when ``drop_last`` is
        ``True``, and ``ceil(count / batch_size)`` otherwise.

        Returns
        -------
            int: The number of batches.
        """
        # NOTE(review): the counts come from ``self.bucket_ids`` (one entry
        # per dataset index), not from the indices ``self.sampler`` actually
        # yields. If the base sampler only covers a subset of the dataset
        # (e.g. one rank of a distributed sampler) this over-estimates the
        # number of batches — confirm against the samplers in use.
        _, counts = np.unique(self.bucket_ids, return_counts=True)
        bs = self.batch_size
        if self.drop_last:
            return int(sum(c // bs for c in counts))
        # Ceil division: a partial leftover batch per bucket is kept.
        return int(sum((c + bs - 1) // bs for c in counts))