# Source code for diffengine.models.editors.ip_adapter.pipeline

# flake8: noqa
import torch
from diffusers import StableDiffusionXLPipeline


class StableDiffusionXLPipelineCustomIPAdapter(StableDiffusionXLPipeline):
    """Custom IP Adapter for the StableDiffusionXLPipeline class.

    The difference between this class and the original
    StableDiffusionXLPipeline class is that this class uses the hidden states
    from the `hidden_states_idx` layer of the image encoder to encode the
    image.

    Args:
        vae: Passed through unchanged to ``StableDiffusionXLPipeline``.
        text_encoder: Passed through unchanged.
        text_encoder_2: Passed through unchanged.
        tokenizer: Passed through unchanged.
        tokenizer_2: Passed through unchanged.
        unet: Passed through unchanged.
        scheduler: Passed through unchanged.
        image_encoder: Passed through unchanged. Defaults to None.
        feature_extractor: Passed through unchanged. Defaults to None.
        force_zeros_for_empty_prompt: Passed through unchanged.
            Defaults to True.
        add_watermarker: Passed through unchanged. Defaults to None.
        hidden_states_idx (int): Index into the image encoder's
            ``hidden_states`` used when ``encode_image`` is asked for hidden
            states. ``-1`` selects ``last_hidden_state`` instead.
            Defaults to -2.
    """

    def __init__(self,
                 vae,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 unet,
                 scheduler,
                 image_encoder=None,
                 feature_extractor=None,
                 force_zeros_for_empty_prompt=True,
                 add_watermarker=None,
                 hidden_states_idx: int = -2):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker)
        self.hidden_states_idx = hidden_states_idx

    def encode_image(self,
                     image,
                     device,
                     num_images_per_prompt,
                     output_hidden_states=None):
        """Encodes the image.

        Args:
            image: The input image to be encoded. Anything that is not a
                ``torch.Tensor`` is first converted with
                ``self.feature_extractor(image, return_tensors="pt")``.
            device: The device to be used for encoding.
            num_images_per_prompt: How many times each encoding is repeated
                along dim 0 (``repeat_interleave``).
            output_hidden_states: If truthy, return the hidden states of the
                ``hidden_states_idx`` layer. Otherwise return the encoder's
                ``image_embeds``. Defaults to None.

        Returns:
            tuple: ``(image_enc_hidden_states, uncond_image_enc_hidden_states)``
            when ``output_hidden_states`` is truthy. The unconditional states
            are the encoding of an all-zero image. Otherwise returns
            ``(image_embeds, uncond_image_embeds)``, where the unconditional
            embeds are zeros.
        """
        # Match the encoder's parameter dtype (e.g. fp16 weights).
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(
                image, return_tensors="pt").pixel_values
        image = image.to(device=device, dtype=dtype)

        if output_hidden_states:
            image_enc_hidden_states = self._encode_hidden_states(image)
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0)
            # For hidden states, the unconditional branch encodes an all-zero
            # image, rather than zeroing the features.
            uncond_image_enc_hidden_states = self._encode_hidden_states(
                torch.zeros_like(image))
            uncond_image_enc_hidden_states = (
                uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0))
            return image_enc_hidden_states, uncond_image_enc_hidden_states

        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)
        return image_embeds, uncond_image_embeds

    def _encode_hidden_states(self, pixel_values):
        """Run the image encoder and return the `hidden_states_idx` layer."""
        outputs = self.image_encoder(pixel_values, output_hidden_states=True)
        if self.hidden_states_idx == -1:
            return outputs.last_hidden_state
        return outputs.hidden_states[self.hidden_states_idx]
class StableDiffusionXLPipelineTimmIPAdapter(StableDiffusionXLPipeline):
    """Timm IP Adapter for the StableDiffusionXLPipeline class.

    The difference between this class and the original
    StableDiffusionXLPipeline class is that this class uses the timm library
    for the image encoder. ``image_encoder`` is called via
    ``forward_features``. ``feature_extractor`` is called on one image at a
    time, and its result is batched here.

    Args:
        vae: Passed through unchanged to ``StableDiffusionXLPipeline``.
        text_encoder: Passed through unchanged.
        text_encoder_2: Passed through unchanged.
        tokenizer: Passed through unchanged.
        tokenizer_2: Passed through unchanged.
        unet: Passed through unchanged.
        scheduler: Passed through unchanged.
        image_encoder: Passed through unchanged. Defaults to None.
        feature_extractor: Passed through unchanged. Defaults to None.
        force_zeros_for_empty_prompt: Passed through unchanged.
            Defaults to True.
        add_watermarker: Passed through unchanged. Defaults to None.
    """

    def __init__(self,
                 vae,
                 text_encoder,
                 text_encoder_2,
                 tokenizer,
                 tokenizer_2,
                 unet,
                 scheduler,
                 image_encoder=None,
                 feature_extractor=None,
                 force_zeros_for_empty_prompt=True,
                 add_watermarker=None):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker)

    def encode_image(self,
                     image,
                     device,
                     num_images_per_prompt,
                     output_hidden_states=None):
        """Encodes the image.

        Args:
            image: The input image to be encoded. It may be a ``torch.Tensor``
                (used as-is), a list/tuple of images, or a single image. In
                the last two cases each image goes through
                ``self.feature_extractor`` and the results are stacked into a
                batch.
            device: The device to be used for encoding.
            num_images_per_prompt: How many times each encoding is repeated
                along dim 0 (``repeat_interleave``).
            output_hidden_states: If truthy, the unconditional output is the
                encoding of an all-zero image. Otherwise it is a zero tensor
                shaped like the conditional output. Defaults to None.

        Returns:
            tuple: ``(image_enc_hidden_states, uncond_image_enc_hidden_states)``
            when ``output_hidden_states`` is truthy, else
            ``(image_embeds, uncond_image_embeds)``.
        """
        # Match the encoder's parameter dtype (e.g. fp16 weights).
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            if isinstance(image, (list, tuple)):
                # The transform is applied per image (the single-image path
                # below needs ``unsqueeze(0)`` to add a batch dim), so
                # transform each image and stack them into one batch.
                image = torch.stack(
                    [self.feature_extractor(img) for img in image])
            else:
                image = self.feature_extractor(image).unsqueeze(0)
        image = image.to(device=device, dtype=dtype)

        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder.forward_features(
                image)
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0)
            # The unconditional branch encodes an all-zero image.
            uncond_image_enc_hidden_states = (
                self.image_encoder.forward_features(torch.zeros_like(image)))
            uncond_image_enc_hidden_states = (
                uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0))
            return image_enc_hidden_states, uncond_image_enc_hidden_states

        # NOTE(review): this branch also returns ``forward_features`` output,
        # which for timm models is typically unpooled features, not a pooled
        # embedding. Kept as-is; confirm it matches what the image projection
        # was trained on.
        image_embeds = self.image_encoder.forward_features(image)
        image_embeds = image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)
        return image_embeds, uncond_image_embeds

    @property
    def _execution_device(self):
        r"""Returns the device on which the pipeline's models will be executed.

        After calling [`~DiffusionPipeline.enable_sequential_cpu_offload`] the
        execution device can only be inferred from Accelerate's module hooks.
        """
        for name, model in self.components.items():
            # Skip non-module components and those excluded from offloading.
            if (not isinstance(model, torch.nn.Module)
                    or name in self._exclude_from_cpu_offload):
                continue

            # The first eligible module without a hook decides the answer:
            # no offloading is active, so fall back to ``self.device``.
            if not hasattr(model, "_hf_hook"):
                return self.device
            for module in model.modules():
                if (hasattr(module, "_hf_hook")
                        and hasattr(module._hf_hook, "execution_device")
                        and module._hf_hook.execution_device is not None):
                    return torch.device(module._hf_hook.execution_device)
        return self.device