# Source code for the wuerstchen_prior model config.

from diffusers import DDPMWuerstchenScheduler
from diffusers.pipelines.wuerstchen import WuerstchenPrior
from transformers import CLIPTextModel, PreTrainedTokenizerFast

from diffengine.models.editors import WuerstchenPriorModel
from diffengine.models.editors.wuerstchen.efficient_net_encoder import (
    EfficientNetEncoder,
)

# Hugging Face Hub repository ids used by this config.
# ``decoder_model`` is handed to WuerstchenPriorModel unchanged;
# ``prior_model`` is the repo the tokenizer, text encoder and prior
# weights are loaded from (via their ``subfolder`` entries below).
decoder_model = "warp-ai/wuerstchen"
prior_model = "warp-ai/wuerstchen-prior"

# Model config. Each nested ``dict(type=..., **kwargs)`` is presumably
# instantiated by the config/registry machinery as ``type(**kwargs)``
# (mmengine-style lazy config) — NOTE(review): confirm against the builder.
model = dict(
    type=WuerstchenPriorModel,
    decoder_model=decoder_model,
    # ``from_pretrained`` requires the repo id; without it the factory
    # would be called with only ``subfolder`` and could not resolve weights.
    tokenizer=dict(
        type=PreTrainedTokenizerFast.from_pretrained,
        pretrained_model_name_or_path=prior_model,
        subfolder="tokenizer",
    ),
    scheduler=dict(type=DDPMWuerstchenScheduler),
    text_encoder=dict(
        type=CLIPTextModel.from_pretrained,
        pretrained_model_name_or_path=prior_model,
        subfolder="text_encoder",
    ),
    image_encoder=dict(type=EfficientNetEncoder, pretrained=True),
    prior=dict(
        type=WuerstchenPrior.from_pretrained,
        pretrained_model_name_or_path=prior_model,
        subfolder="prior",
    ),
)