Source code for diffengine.datasets.hf_datasets

# flake8: noqa: TRY004,S311
import functools
import gc
import os
import random
from collections.abc import Sequence
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from datasets.fingerprint import Hasher
from mmengine.dataset.base_dataset import Compose
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)

from diffengine.datasets.utils import encode_prompt_sdxl
from diffengine.registry import DATASETS

Image.MAX_IMAGE_PIXELS = 1000000000


@DATASETS.register_module()
class HFDataset(Dataset):
    """Image/caption dataset built on the Hugging Face ``datasets`` library.

    If ``dataset`` is an existing local path, the captions are read from the
    csv file ``<dataset>/<csv>``; otherwise ``dataset`` is handed to
    ``load_dataset`` as is. In both cases the ``"train"`` split is used.

    Args:
    ----
        dataset (str): Dataset name or path to dataset.
        image_column (str): Image column name. Defaults to 'image'.
        caption_column (str): Caption column name. Defaults to 'text'.
        csv (str): Caption csv file name when loading local folder.
            Defaults to 'metadata.csv'.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        cache_dir (str, optional): The directory where the downloaded
            datasets will be stored. Defaults to None.
    """

    def __init__(self,
                 dataset: str,
                 image_column: str = "image",
                 caption_column: str = "text",
                 csv: str = "metadata.csv",
                 pipeline: Sequence = (),
                 cache_dir: str | None = None) -> None:
        self.dataset_name = dataset

        if Path(dataset).exists():
            # Local folder: the csv file sits directly inside the folder.
            splits = load_dataset(
                "csv",
                data_files=os.path.join(dataset, csv),
                cache_dir=cache_dir,
            )
        else:
            # Not a local path: let ``load_dataset`` resolve the name.
            splits = load_dataset(dataset, cache_dir=cache_dir)
        self.dataset = splits["train"]

        self.pipeline = Compose(pipeline)
        self.image_column = image_column
        self.caption_column = caption_column

    def __len__(self) -> int:
        """Return the number of samples.

        Returns
        -------
            int: The length of the loaded ``"train"`` split.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> dict:
        """Build the ``idx``-th sample and run it through ``self.pipeline``.

        Args:
        ----
            idx (int): Index into ``self.dataset``.

        Returns:
        -------
            dict: Output of ``self.pipeline`` applied to
            ``{"img": <RGB image>, "text": <caption>}``.

        Raises:
        ------
            ValueError: If the caption is neither a string nor a
                list / ``np.ndarray``.
        """
        record = self.dataset[idx]

        img = record[self.image_column]
        if isinstance(img, str):
            # String entries are file paths relative to the dataset folder.
            img = Image.open(os.path.join(self.dataset_name, img))
        img = img.convert("RGB")

        text = record[self.caption_column]
        if isinstance(text, (list, np.ndarray)):
            # Several captions available: pick one at random.
            text = random.choice(text)
        elif not isinstance(text, str):
            msg = (f"Caption column `{self.caption_column}` should "
                   "contain either strings or lists of strings.")
            raise ValueError(msg)

        return self.pipeline({"img": img, "text": text})
@DATASETS.register_module()
class HFDatasetPreComputeEmbs(HFDataset):
    """Dataset for huggingface datasets with pre-computed text embeddings.

    The difference from HFDataset is

    1. pre-compute Text Encoder embeddings to save memory. The tokenizers
       and text encoders are only loaded inside ``__init__``, used once to
       map the whole dataset, and released afterwards.

    Args:
    ----
        model (str): pretrained model name of stable diffusion xl.
            Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
        text_hasher (str): Text embeddings hasher name. It is hashed with
            ``Hasher.hash`` and passed to ``Dataset.map`` as
            ``new_fingerprint``. Defaults to 'text'.
        device (str): Device used to compute embeddings. Defaults to 'cuda'.
        proportion_empty_prompts (float): Forwarded unchanged to
            ``encode_prompt_sdxl``; presumably the probability of replacing
            a caption with an empty prompt (confirm against that helper).
            Defaults to 0.0.
        *args: Positional arguments forwarded to ``HFDataset.__init__``.
        **kwargs: Keyword arguments forwarded to ``HFDataset.__init__``.
    """

    def __init__(self,
                 *args,
                 model: str = "stabilityai/stable-diffusion-xl-base-1.0",
                 text_hasher: str = "text",
                 device: str = "cuda",
                 proportion_empty_prompts: float = 0.0,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)

        tokenizer_one = AutoTokenizer.from_pretrained(
            model, subfolder="tokenizer", use_fast=False)
        tokenizer_two = AutoTokenizer.from_pretrained(
            model, subfolder="tokenizer_2", use_fast=False)

        text_encoder_one = CLIPTextModel.from_pretrained(
            model, subfolder="text_encoder").to(device)
        text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
            model, subfolder="text_encoder_2").to(device)

        new_fingerprint = Hasher.hash(text_hasher)
        compute_embeddings_fn = functools.partial(
            encode_prompt_sdxl,
            text_encoders=[text_encoder_one, text_encoder_two],
            tokenizers=[tokenizer_one, tokenizer_two],
            proportion_empty_prompts=proportion_empty_prompts,
            caption_column=self.caption_column,
        )
        self.dataset = self.dataset.map(
            compute_embeddings_fn,
            batched=True,
            batch_size=32,
            new_fingerprint=new_fingerprint)

        # The partial keeps the encoders and tokenizers alive through its
        # keyword arguments, so it has to be dropped together with the local
        # names; otherwise the models are still referenced when the cache is
        # emptied below and their memory cannot be released at that point.
        del compute_embeddings_fn
        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
        gc.collect()
        torch.cuda.empty_cache()

    def __getitem__(self, idx: int) -> dict:
        """Get item.

        Get the idx-th image and pre-computed text embeddings of the dataset
        after ``self.pipeline``.

        Args:
        ----
            idx (int): Index into ``self.dataset``.

        Returns:
        -------
            dict: Output of ``self.pipeline`` applied to a dict with the keys
            ``"img"`` (RGB image), ``"prompt_embeds"`` and
            ``"pooled_prompt_embeds"`` (read from the mapped dataset row).
        """
        data_info = self.dataset[idx]

        image = data_info[self.image_column]
        if isinstance(image, str):
            # String entries are file paths relative to the dataset folder.
            image = Image.open(os.path.join(self.dataset_name, image))
        image = image.convert("RGB")

        result = {
            "img": image,
            "prompt_embeds": data_info["prompt_embeds"],
            "pooled_prompt_embeds": data_info["pooled_prompt_embeds"],
        }
        return self.pipeline(result)