diffengine.datasets.hf_dpo_datasets

Module Contents

Classes

HFDPODataset

DPO dataset for Hugging Face datasets.

class diffengine.datasets.hf_dpo_datasets.HFDPODataset(dataset, image_columns=None, caption_column='text', label_column='label_0', csv='metadata.csv', pipeline=(), split='train', cache_dir=None)

Bases: torch.utils.data.Dataset

DPO (Direct Preference Optimization) dataset for Hugging Face datasets.

Args:

  • dataset (str): Dataset name, or path to a local dataset folder.
  • image_columns (list[str]): Image column names. Defaults to ['image'].
  • caption_column (str): Caption column name. Defaults to 'text'.
  • label_column (str): Name of the label column indicating whether image_columns[0] is better than image_columns[1]. Defaults to 'label_0'.
  • csv (str): Name of the caption CSV file used when loading a local folder. Defaults to 'metadata.csv'.
  • pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
  • split (str): Dataset split. Defaults to 'train'.
  • cache_dir (str, optional): Directory where downloaded datasets are stored. Defaults to None.
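The docstring says only that csv names the caption file read when dataset points at a local folder. As a minimal sketch, assuming such a folder holds a metadata.csv with one row per preference pair, the rows could be read with the standard library as shown here. The image column names (image, image2), the file names, the folder name my_dpo_data, and the step that joins image file names onto the folder path are all assumptions for illustration; read_metadata is a hypothetical helper, not diffengine's own loading code.

```python
import csv
import io
import os

# Hypothetical metadata.csv contents. "text" and "label_0" match the documented
# caption_column and label_column defaults; "image"/"image2" are assumed names.
METADATA = """image,image2,text,label_0
a.png,b.png,a red bicycle,1
c.png,d.png,a snowy cabin,0
"""


def read_metadata(folder, csv_text, image_columns=("image", "image2"),
                  label_column="label_0"):
    """Parse CSV rows, resolving image file names relative to `folder`."""
    rows = []
    for row in csv.DictReader(io.StringIO(csv_text)):
        for column in image_columns:
            # Assumption: image cells hold file names relative to the folder.
            row[column] = os.path.join(folder, row[column])
        # CSV cells are strings; convert the preference label to an int.
        row[label_column] = int(row[label_column])
        rows.append(row)
    return rows


rows = read_metadata("my_dpo_data", METADATA)
print(len(rows))           # 2
print(rows[1]["label_0"])  # 0
```

Each row then carries both candidate images, the caption, and the label that says which image is preferred.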

__len__()

Get the length of the dataset.

Returns:

The length of the filtered dataset.

Return type:

int

__getitem__(idx)

Get item.

Get the idx-th image and data information of the dataset after processing by self.pipeline.

Args:

idx (int): The index of self.data_list.

Returns:

dict: The idx-th image and data information of the dataset after processing by self.pipeline.
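A minimal sketch of the item flow described above, using in-memory rows instead of a Hugging Face dataset. Several details are assumptions rather than facts taken from diffengine: that a truthy label_0 means image_columns[0] is the preferred image (so the pair is swapped otherwise, putting the preferred image first), the result keys img and text, and the use of plain callables applied in order as the pipeline. ToyDPODataset, lowercase_text, and count_images are hypothetical names.

```python
from typing import Callable, Sequence


# Hypothetical transforms standing in for pipeline entries; each takes and
# returns a results dict.
def lowercase_text(results: dict) -> dict:
    results["text"] = results["text"].lower()
    return results


def count_images(results: dict) -> dict:
    results["num_images"] = len(results["img"])
    return results


class ToyDPODataset:
    """Hypothetical stand-in for a DPO dataset over in-memory rows."""

    def __init__(self, rows, image_columns=("image", "image2"),
                 caption_column="text", label_column="label_0",
                 pipeline: Sequence[Callable[[dict], dict]] = ()):
        self.rows = rows
        self.image_columns = list(image_columns)
        self.caption_column = caption_column
        self.label_column = label_column
        self.pipeline = pipeline

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> dict:
        row = self.rows[idx]
        images = [row[column] for column in self.image_columns]
        # Assumption: a falsy label means image_columns[1] is preferred, so
        # reverse the pair to keep the preferred image first.
        if not row[self.label_column]:
            images = images[::-1]
        results = {"img": images, "text": row[self.caption_column]}
        # Apply each pipeline transform in order.
        for transform in self.pipeline:
            results = transform(results)
        return results


rows = [
    {"image": "a.png", "image2": "b.png", "text": "A Red Bicycle", "label_0": 1},
    {"image": "c.png", "image2": "d.png", "text": "A Snowy Cabin", "label_0": 0},
]
ds = ToyDPODataset(rows, pipeline=(lowercase_text, count_images))
print(len(ds))  # 2
print(ds[1])    # {'img': ['d.png', 'c.png'], 'text': 'a snowy cabin', 'num_images': 2}
```

Normalizing the pair so the preferred image always comes first is one way to give a downstream preference loss a fixed (preferred, rejected) order, regardless of how the source columns are arranged.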


Return type:

dict

Parameters:
  • dataset (str)
  • image_columns (list[str] | None)
  • caption_column (str)
  • label_column (str)
  • csv (str)
  • pipeline (collections.abc.Sequence)
  • split (str)
  • cache_dir (str | None)