diffengine.datasets.hf_dreambooth_datasets
Module Contents
Classes
HFDreamBoothDataset | DreamBooth dataset for Hugging Face datasets.
- class diffengine.datasets.hf_dreambooth_datasets.HFDreamBoothDataset(dataset, instance_prompt, image_column='image', dataset_sub_dir=None, class_image_config=None, class_prompt=None, pipeline=(), csv=None, cache_dir=None)
Bases: torch.utils.data.Dataset

DreamBooth dataset for Hugging Face datasets.
Args:
- dataset (str): Dataset name.
- instance_prompt (str): The prompt with an identifier specifying the instance.
- image_column (str): Image column name. Defaults to 'image'.
- dataset_sub_dir (str, optional): Dataset sub-directory name. Defaults to None.
- class_image_config (dict, optional): Settings for generating class images, with the following keys:
  - model (str): Pretrained Stable Diffusion model used to create the class images. Defaults to 'runwayml/stable-diffusion-v1-5'.
  - data_dir (str): Folder containing the class images. Defaults to 'work_dirs/class_image'.
  - num_images (int): Minimum number of class images for the prior preservation loss. If data_dir does not already contain enough images, the missing ones are sampled with class_prompt. Defaults to 200.
  - recreate_class_images (bool): Whether to recreate all class images. Defaults to True.
- class_prompt (str, optional): The prompt describing images in the same class as the provided instance images. Defaults to None.
- pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
- csv (str, optional): Name of the CSV file listing image paths when loading a local folder. If None, the dataset is loaded from image folders. Defaults to None.
- cache_dir (str, optional): Directory where downloaded datasets are stored. Defaults to None.
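diffengine builds on MMEngine, so datasets are usually declared as config dictionaries rather than constructed by hand. The snippet below is a hedged sketch of such an entry, assuming the class is registered in diffengine's dataset registry and can be looked up by the string name 'HFDreamBoothDataset'. The dataset name 'diffusers/dog-example', both prompts, and the empty pipeline are illustrative placeholders; the class_image_config values only restate the defaults documented above.

```python
# Hypothetical MMEngine-style config entry for HFDreamBoothDataset.
# All values are illustrative; class_image_config restates the documented defaults.
class_image_config = dict(
    model="runwayml/stable-diffusion-v1-5",  # model used to sample class images
    data_dir="work_dirs/class_image",        # folder holding the class images
    num_images=200,                          # minimum number of class images
    recreate_class_images=True,              # regenerate every class image
)

train_dataset = dict(
    type="HFDreamBoothDataset",              # assumes registry lookup by name
    dataset="diffusers/dog-example",         # placeholder dataset name
    instance_prompt="a photo of sks dog",    # "sks" is a placeholder identifier token
    class_prompt="a photo of dog",           # prompt for prior-preservation images
    class_image_config=class_image_config,
    image_column="image",
    pipeline=(),                             # empty tuple, the documented default
)
```

Keeping class_image_config as a separate dictionary makes it easy to reuse the same class-image settings across several experiment configs.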
- generate_class_image(class_image_config)
Generate class images for prior preservation loss.
- Parameters:
class_image_config (dict) – Class image settings, using the same keys as the class_image_config argument described above.
- Return type:
None
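The num_images and recreate_class_images options decide how much sampling generate_class_image has to do. The helper below is a hypothetical sketch of that bookkeeping, not diffengine's code: the function name num_class_images_to_generate is invented for illustration, and it assumes that every regular file in data_dir counts as one existing class image and that recreating means discarding all of them.

```python
from pathlib import Path


def num_class_images_to_generate(data_dir: str, num_images: int,
                                 recreate_class_images: bool) -> int:
    """Hypothetical helper: how many class images still need to be sampled.

    Assumes every regular file in ``data_dir`` is one existing class image.
    """
    folder = Path(data_dir)
    if recreate_class_images or not folder.is_dir():
        # Start from scratch: every requested image must be sampled.
        return num_images
    existing = sum(1 for p in folder.iterdir() if p.is_file())
    # Only top up the folder; never return a negative count.
    return max(0, num_images - existing)
```

For example, with 150 images already in data_dir, num_images=200 and recreate_class_images=False, the helper returns 50; with recreate_class_images=True it returns 200.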
- Parameters (HFDreamBoothDataset constructor):
dataset (str)
instance_prompt (str)
image_column (str)
dataset_sub_dir (str | None)
class_image_config (dict | None)
class_prompt (str | None)
pipeline (collections.abc.Sequence)
csv (str | None)
cache_dir (str | None)
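Because the class derives from torch.utils.data.Dataset, it is consumed as a map-style dataset through __len__ and __getitem__. The toy class below is a hypothetical, torch-free sketch of the DreamBooth pairing idea only: each row's image from image_column is returned together with the same instance_prompt, and the result is passed through each pipeline transform in order. The class name ToyDreamBoothDataset, the output keys img and text, and the list-of-dicts rows standing in for a loaded Hugging Face dataset are all assumptions.

```python
from typing import Any, Callable, Dict, List, Sequence


class ToyDreamBoothDataset:
    """Hypothetical sketch: pairs each image with one fixed instance prompt.

    ``rows`` stands in for a loaded Hugging Face dataset (a list of dicts).
    """

    def __init__(self, rows: List[Dict[str, Any]], instance_prompt: str,
                 image_column: str = "image",
                 pipeline: Sequence[Callable[[dict], dict]] = ()) -> None:
        self.rows = rows
        self.instance_prompt = instance_prompt
        self.image_column = image_column
        self.pipeline = list(pipeline)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> dict:
        # Every sample shares the same instance prompt (assumed "text" key).
        result = {"img": self.rows[idx][self.image_column],
                  "text": self.instance_prompt}
        # Apply the transforms one after another, like a composed pipeline.
        for transform in self.pipeline:
            result = transform(result)
        return result
```

For instance, ToyDreamBoothDataset([{"image": "dog0.jpg"}, {"image": "dog1.jpg"}], "a photo of sks dog")[1] yields {"img": "dog1.jpg", "text": "a photo of sks dog"}; a real dataset would hold decoded images rather than file names.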