diffengine.datasets.hf_datasets

Module Contents

Classes

HFDataset

Dataset for huggingface datasets.

HFDatasetPreComputeEmbs

Dataset for huggingface datasets.

class diffengine.datasets.hf_datasets.HFDataset(dataset, image_column='image', caption_column='text', csv='metadata.csv', pipeline=(), cache_dir=None)[source]

Bases: torch.utils.data.Dataset

Dataset for huggingface datasets.

Args:

dataset (str): Dataset name or path to dataset.

image_column (str): Image column name. Defaults to ‘image’.

caption_column (str): Caption column name. Defaults to ‘text’.

csv (str): Caption csv file name when loading local folder. Defaults to ‘metadata.csv’.

pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.

cache_dir (str, optional): The directory where the downloaded datasets will be stored. Defaults to None.

__len__()[source]

Get the length of dataset.

Returns:

int: The length of the filtered dataset.

Return type:

int

__getitem__(idx)[source]

Get item.

Get the idx-th image and data information of dataset after self.pipeline.

Args:

idx (int): The index of self.data_list.

Returns:

dict: The idx-th image and data information of dataset after self.pipeline.

Parameters:

idx (int) –

Return type:

dict

Parameters:
  • dataset (str) –

  • image_column (str) –

  • caption_column (str) –

  • csv (str) –

  • pipeline (collections.abc.Sequence) –

  • cache_dir (str | None) –
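
The way `__len__` and `__getitem__` interact with `pipeline` can be illustrated with a minimal, library-free sketch. Everything below is hypothetical: `ToyPipelineDataset`, the two transform functions, and the record keys are stand-ins chosen for illustration and are not part of diffengine. The sketch assumes the pipeline is a sequence of callables applied in order to a dict, which is one plausible reading of the description above rather than a statement about the actual implementation.

```python
from collections.abc import Callable, Sequence


class ToyPipelineDataset:
    """Hypothetical stand-in that mimics the documented interface."""

    def __init__(
        self,
        records: Sequence[dict],
        pipeline: Sequence[Callable[[dict], dict]] = (),
    ) -> None:
        self.records = records
        self.pipeline = pipeline

    def __len__(self) -> int:
        # Number of available records.
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        # Copy so that transforms do not mutate the stored record.
        result = dict(self.records[idx])
        # Apply each transform in order, feeding the output to the next one.
        for transform in self.pipeline:
            result = transform(result)
        return result


def lowercase_caption(item: dict) -> dict:
    # Hypothetical transform: normalise the caption text.
    item["text"] = item["text"].lower()
    return item


def add_caption_length(item: dict) -> dict:
    # Hypothetical transform: attach a derived field.
    item["text_len"] = len(item["text"])
    return item


records = [
    {"img": "a.png", "text": "A Cat"},
    {"img": "b.png", "text": "Two Dogs"},
]
toy = ToyPipelineDataset(records, pipeline=(lowercase_caption, add_caption_length))
first = toy[0]
```

With an empty `pipeline` (the documented default), `__getitem__` in this sketch would return the record unchanged.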

class diffengine.datasets.hf_datasets.HFDatasetPreComputeEmbs(*args, model='stabilityai/stable-diffusion-xl-base-1.0', text_hasher='text', device='cuda', proportion_empty_prompts=0.0, **kwargs)[source]

Bases: HFDataset

Dataset for huggingface datasets.

The difference from HFDataset is:
  1. Text Encoder embeddings are pre-computed to save memory.

Args:

model (str): Pretrained model name of Stable Diffusion XL. Defaults to ‘stabilityai/stable-diffusion-xl-base-1.0’.

text_hasher (str): Text embeddings hasher name. Defaults to ‘text’.

device (str): Device used to compute embeddings. Defaults to ‘cuda’.

proportion_empty_prompts (float): The probability of replacing a caption with empty text. Defaults to 0.0.
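
As a rough illustration of what a `proportion_empty_prompts` setting could mean in practice, the sketch below replaces a caption with an empty string with a given probability. The helper `maybe_drop_caption` is hypothetical and is not a diffengine function; it only assumes that the parameter is a per-sample probability.

```python
import random


def maybe_drop_caption(
    caption: str, proportion_empty_prompts: float, rng: random.Random
) -> str:
    # rng.random() is in [0.0, 1.0), so 0.0 never drops and 1.0 always drops.
    if rng.random() < proportion_empty_prompts:
        return ""
    return caption


rng = random.Random(0)  # seeded so the sketch is reproducible
captions = ["a photo of a cat"] * 1000

kept_all = [maybe_drop_caption(c, 0.0, rng) for c in captions]
dropped_all = [maybe_drop_caption(c, 1.0, rng) for c in captions]
dropped_some = [maybe_drop_caption(c, 0.1, rng) for c in captions]

# Fraction of captions that ended up empty at a 10% setting.
empty_fraction = dropped_some.count("") / len(dropped_some)
```

Training on a share of empty captions is a common way to let a text-conditioned diffusion model also learn an unconditional mode; whether this class is used for that purpose depends on the training configuration.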

__getitem__(idx)[source]

Get item.

Get the idx-th image and data information of dataset after self.train_transforms.

Args:

idx (int): The index of self.data_list.

Returns:

dict: The idx-th image and data information of dataset after self.train_transforms.

Parameters:

idx (int) –

Return type:

dict

Parameters:
  • model (str) –

  • text_hasher (str) –

  • device (str) –

  • proportion_empty_prompts (float) –
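
The idea of pre-computing text embeddings can be sketched without any deep-learning dependency. In the hypothetical example below, `fake_text_encoder`, `ToyPrecomputedDataset`, and the `prompt_embeds` key are all assumptions made for illustration; a real setup would call an actual text encoder on the configured `device`. The point is only that the encoder runs once per caption at construction time, so item lookups afterwards never need it.

```python
import hashlib


def fake_text_encoder(caption: str) -> list[float]:
    # Hypothetical stand-in for a real text encoder: deterministic and cheap.
    digest = hashlib.sha256(caption.encode("utf-8")).digest()
    return [byte / 255.0 for byte in digest[:4]]


class ToyPrecomputedDataset:
    """Hypothetical dataset that encodes every caption once, up front."""

    def __init__(self, captions: list[str]) -> None:
        self.encoder_calls = 0
        # All embeddings are computed here; the encoder is not needed later.
        self.embeddings = [self._encode(c) for c in captions]

    def _encode(self, caption: str) -> list[float]:
        self.encoder_calls += 1
        return fake_text_encoder(caption)

    def __len__(self) -> int:
        return len(self.embeddings)

    def __getitem__(self, idx: int) -> dict:
        # Lookup only: no encoder call happens at access time.
        return {"prompt_embeds": self.embeddings[idx]}


ds = ToyPrecomputedDataset(["a cat", "a dog"])
item_a = ds[0]
item_a_again = ds[0]
item_b = ds[1]
```

Because the encoder is only needed during construction in this sketch, it could be released afterwards, which is the kind of memory saving the class description refers to.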