Source code for diffengine.datasets.utils

# flake8: noqa: S311,ANN001
import random

import numpy as np
import torch


[docs]def encode_prompt_sdxl(batch, text_encoders, tokenizers, caption_column, proportion_empty_prompts: float = 0.0, *, is_train: bool = True) -> dict[str, torch.Tensor]: """Encode prompt for Stable Diffusion XL.""" # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt prompt_embeds_list = [] prompt_batch = batch[caption_column] captions = [] for caption in prompt_batch: if random.random() < proportion_empty_prompts: captions.append("") elif isinstance(caption, str): captions.append(caption) elif isinstance(caption, list | np.ndarray): # take a random caption if there are multiple captions.append(random.choice(caption) if is_train else caption[0]) with torch.no_grad(): for tokenizer, text_encoder in zip( tokenizers, text_encoders, strict=True): text_inputs = tokenizer( captions, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final # text encoder pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.hidden_states[-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) return { "prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu(), }