diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 89f1301b..274ab407 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -81,10 +81,11 @@ class BlendImg2ImgStage(BaseStage): ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 3840fdd2..ce1f04df 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -130,10 +130,11 @@ class SourceTxt2ImgStage(BaseStage): ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 85ddc079..cdc3a067 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -99,10 +99,11 @@ class UpscaleOutpaintStage(BaseStage): ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 9d5a7b32..cf784b05 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -48,13 +48,14 @@ class UpscaleStableDiffusionStage(BaseStage): ) generator = torch.manual_seed(params.seed) - prompt_embeds = encode_prompt( - pipeline, - prompt_pairs, - num_images_per_prompt=params.batch, - do_classifier_free_guidance=params.do_cfg(), - ) - pipeline.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipeline, + prompt_pairs, + num_images_per_prompt=params.batch, + do_classifier_free_guidance=params.do_cfg(), + ) + pipeline.unet.set_prompts(prompt_embeds) outputs = [] for source in sources: