fix SDXL patch signature
This commit is contained in:
parent
5da846e41d
commit
a1657a6b09
|
@ -1,5 +1,5 @@
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -60,7 +60,7 @@ def encode_prompt_compel(
|
||||||
prompt_embeds = compel(prompt)
|
prompt_embeds = compel(prompt)
|
||||||
|
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
negative_prompt_embeds = compel(self, negative_prompt)
|
negative_prompt_embeds = compel(negative_prompt)
|
||||||
|
|
||||||
if negative_prompt_embeds is not None:
|
if negative_prompt_embeds is not None:
|
||||||
[prompt_embeds, negative_prompt_embeds] = (
|
[prompt_embeds, negative_prompt_embeds] = (
|
||||||
|
@ -78,12 +78,14 @@ def encode_prompt_compel(
|
||||||
|
|
||||||
def encode_prompt_compel_sdxl(
|
def encode_prompt_compel_sdxl(
|
||||||
self: OnnxStableDiffusionPipeline,
|
self: OnnxStableDiffusionPipeline,
|
||||||
prompt: str,
|
prompt: Union[str, List[str]],
|
||||||
num_images_per_prompt: int,
|
num_images_per_prompt: int,
|
||||||
do_classifier_free_guidance: bool,
|
do_classifier_free_guidance: bool,
|
||||||
negative_prompt: Optional[str] = None,
|
negative_prompt: Optional[Union[str, list]],
|
||||||
prompt_embeds: Optional[np.ndarray] = None,
|
prompt_embeds: Optional[np.ndarray] = None,
|
||||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
skip_clip_states: int = 0,
|
skip_clip_states: int = 0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
wrapped_encoder = wrap_encoder(self.text_encoder)
|
wrapped_encoder = wrap_encoder(self.text_encoder)
|
||||||
|
@ -98,7 +100,7 @@ def encode_prompt_compel_sdxl(
|
||||||
prompt_embeds, prompt_pooled = compel(prompt)
|
prompt_embeds, prompt_pooled = compel(prompt)
|
||||||
|
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
negative_prompt_embeds, negative_pooled = compel(self, negative_prompt)
|
negative_prompt_embeds, negative_pooled = compel(negative_prompt)
|
||||||
|
|
||||||
if negative_prompt_embeds is not None:
|
if negative_prompt_embeds is not None:
|
||||||
[prompt_embeds, negative_prompt_embeds] = (
|
[prompt_embeds, negative_prompt_embeds] = (
|
||||||
|
|
Loading…
Reference in New Issue