# Source code for diffengine.datasets.hf_controlnet_datasets

# flake8: noqa: TRY004,S311
import os
import random
from collections.abc import Sequence
from pathlib import Path

import numpy as np
from datasets import load_dataset
from mmengine.dataset.base_dataset import Compose
from PIL import Image
from torch.utils.data import Dataset

from diffengine.registry import DATASETS

# Override Pillow's module-wide pixel-count limit (one billion pixels) so that
# very large training images can be opened.
# NOTE: this mutates global PIL state at import time and therefore affects
# every ``Image.open`` call in the process, not just this dataset.
Image.MAX_IMAGE_PIXELS = 1_000_000_000


@DATASETS.register_module()
class HFControlNetDataset(Dataset):
    """Dataset for huggingface datasets used to train ControlNet.

    Each item pairs a target image, a condition image and a caption, and is
    passed through ``pipeline`` before being returned.

    Args:
    ----
        dataset (str): Dataset name or path to dataset. If the path exists
            locally, the ``csv`` file inside it is loaded with the ``"csv"``
            builder; otherwise ``dataset`` is passed to ``load_dataset``
            directly (e.g. a Hugging Face Hub name).
        image_column (str): Image column name. Defaults to 'image'.
        condition_column (str): Condition column name for ControlNet.
            Defaults to 'condition'.
        caption_column (str): Caption column name. Defaults to 'text'.
        csv (str): Caption csv file name when loading local folder.
            Defaults to 'metadata.csv'.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        cache_dir (str, optional): The directory where the downloaded
            datasets will be stored. Defaults to None.
    """

    def __init__(self,
                 dataset: str,
                 image_column: str = "image",
                 condition_column: str = "condition",
                 caption_column: str = "text",
                 csv: str = "metadata.csv",
                 pipeline: Sequence = (),
                 cache_dir: str | None = None) -> None:
        # Kept so relative image paths stored in the dataset can be resolved.
        self.dataset_name = dataset
        if Path(dataset).exists():
            # load local folder: rows come from a csv file inside it
            data_file = os.path.join(dataset, csv)
            self.dataset = load_dataset(
                "csv", data_files=data_file, cache_dir=cache_dir)["train"]
        else:
            # load huggingface online
            self.dataset = load_dataset(dataset, cache_dir=cache_dir)["train"]
        self.pipeline = Compose(pipeline)
        self.image_column = image_column
        self.condition_column = condition_column
        self.caption_column = caption_column

    def __len__(self) -> int:
        """Get the length of dataset.

        Returns
        -------
            int: The number of rows in the loaded ``train`` split.
        """
        return len(self.dataset)

    def _load_image(self, image):  # noqa: ANN001, ANN202
        """Return ``image`` converted to an RGB image.

        String values are treated as paths relative to ``self.dataset_name``
        and opened from disk; any other value is expected to already be an
        image object exposing ``convert`` (e.g. a decoded PIL image).
        """
        if isinstance(image, str):
            # Open inside a context manager so the underlying file handle is
            # closed as soon as the RGB copy has been produced, instead of
            # relying on garbage collection (matters for long-running
            # dataloader workers).
            with Image.open(os.path.join(self.dataset_name, image)) as img:
                return img.convert("RGB")
        return image.convert("RGB")

    def __getitem__(self, idx: int) -> dict:
        """Get item.

        Get the idx-th image and data information of dataset after
        ``self.pipeline``.

        Args:
        ----
            idx (int): The index of the row in ``self.dataset``.

        Returns:
        -------
            dict: The idx-th image and data information of dataset after
            ``self.pipeline``. Before the pipeline runs, the dict has the
            keys ``"img"``, ``"condition_img"`` and ``"text"``.

        Raises:
        ------
            ValueError: If the caption value is neither a string nor a
                list/array of strings.
        """
        data_info = self.dataset[idx]
        image = self._load_image(data_info[self.image_column])
        condition_image = self._load_image(data_info[self.condition_column])

        caption = data_info[self.caption_column]
        if isinstance(caption, str):
            pass
        elif isinstance(caption, list | np.ndarray):
            # take a random caption if there are multiple
            caption = random.choice(caption)
        else:
            msg = (f"Caption column `{self.caption_column}` should "
                   "contain either strings or lists of strings.")
            raise ValueError(msg)

        result = {
            "img": image,
            "condition_img": condition_image,
            "text": caption,
        }
        return self.pipeline(result)