Source code for diffengine.models.editors.esd.esd_xl_data_preprocessor

from typing import Union

import torch
from mmengine.model.base_model.data_preprocessor import BaseDataPreprocessor

from diffengine.registry import MODELS


@MODELS.register_module()
class ESDXLDataPreprocessor(BaseDataPreprocessor):
    """Data preprocessor that batches the pre-computed text embeddings.

    The dataloader is expected to deliver, under ``data["inputs"]``, one
    sequence of tensors per embedding key. ``forward`` stacks each of those
    sequences into a single batch tensor and then hands the result to
    :class:`BaseDataPreprocessor` for the remaining generic processing.
    """

    def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
        """Preprocess the data into the model input format.

        Each embedding entry of ``data["inputs"]`` is stacked along a new
        first dimension (``torch.stack`` default ``dim=0``), after which the
        parent ``forward`` is invoked on the updated ``data``.

        Note that ``data["inputs"]`` is updated in place: the stacked tensors
        replace the original per-sample sequences.

        Args:
        ----
            data (dict): Data returned by dataloader. ``data["inputs"]`` must
                contain ``prompt_embeds``, ``pooled_prompt_embeds``,
                ``null_prompt_embeds`` and ``null_pooled_prompt_embeds``.
            training (bool): Whether to enable training time augmentation.
                Passed through unchanged to the parent ``forward``.

        Returns:
        -------
            dict or list: Data in the same format as the model input.

        Raises:
        ------
            KeyError: If ``"inputs"`` or any of the embedding keys is missing.
        """
        inputs = data["inputs"]
        # Pre-computed text embeddings arrive as one tensor per sample;
        # collapse each of them into a single batch tensor. The order matches
        # the original hand-written sequence of assignments.
        for key in (
            "prompt_embeds",
            "pooled_prompt_embeds",
            "null_prompt_embeds",
            "null_pooled_prompt_embeds",
        ):
            inputs[key] = torch.stack(inputs[key])

        # Forward ``training`` instead of dropping it, so the parent sees the
        # same mode the caller requested.
        return super().forward(data, training)