1
0
Fork 0

fix SDXL patch signature

This commit is contained in:
Sean Sube 2024-03-02 22:35:04 -06:00
parent 5da846e41d
commit a1657a6b09
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 7 additions and 5 deletions

View File

@ -1,5 +1,5 @@
from types import SimpleNamespace
from typing import Optional
from typing import List, Optional, Union
import numpy as np
import torch
@ -60,7 +60,7 @@ def encode_prompt_compel(
prompt_embeds = compel(prompt)
if negative_prompt is not None:
negative_prompt_embeds = compel(self, negative_prompt)
negative_prompt_embeds = compel(negative_prompt)
if negative_prompt_embeds is not None:
[prompt_embeds, negative_prompt_embeds] = (
@ -78,12 +78,14 @@ def encode_prompt_compel(
def encode_prompt_compel_sdxl(
self: OnnxStableDiffusionPipeline,
prompt: str,
prompt: Union[str, List[str]],
num_images_per_prompt: int,
do_classifier_free_guidance: bool,
negative_prompt: Optional[str] = None,
negative_prompt: Optional[Union[str, list]],
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
skip_clip_states: int = 0,
) -> np.ndarray:
wrapped_encoder = wrap_encoder(self.text_encoder)
@ -98,7 +100,7 @@ def encode_prompt_compel_sdxl(
prompt_embeds, prompt_pooled = compel(prompt)
if negative_prompt is not None:
negative_prompt_embeds, negative_pooled = compel(self, negative_prompt)
negative_prompt_embeds, negative_pooled = compel(negative_prompt)
if negative_prompt_embeds is not None:
[prompt_embeds, negative_prompt_embeds] = (