# flake8: noqa: S311,ANN001
import random
import numpy as np
import torch
[docs]def encode_prompt_sdxl(batch,
text_encoders,
tokenizers,
caption_column,
proportion_empty_prompts: float = 0.0,
*,
is_train: bool = True) -> dict[str, torch.Tensor]:
"""Encode prompt for Stable Diffusion XL."""
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
prompt_embeds_list = []
prompt_batch = batch[caption_column]
captions = []
for caption in prompt_batch:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, list | np.ndarray):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
with torch.no_grad():
for tokenizer, text_encoder in zip(
tokenizers, text_encoders, strict=True):
text_inputs = tokenizer(
captions,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final
# text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
return {
"prompt_embeds": prompt_embeds.cpu(),
"pooled_prompt_embeds": pooled_prompt_embeds.cpu(),
}