diff --git a/api/onnx_web/prompt/compel.py b/api/onnx_web/prompt/compel.py index c64a907e..e1dca077 100644 --- a/api/onnx_web/prompt/compel.py +++ b/api/onnx_web/prompt/compel.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from typing import Optional +from typing import List, Optional, Union import numpy as np import torch @@ -60,7 +60,7 @@ def encode_prompt_compel( prompt_embeds = compel(prompt) if negative_prompt is not None: - negative_prompt_embeds = compel(self, negative_prompt) + negative_prompt_embeds = compel(negative_prompt) if negative_prompt_embeds is not None: [prompt_embeds, negative_prompt_embeds] = ( @@ -78,12 +78,14 @@ def encode_prompt_compel( def encode_prompt_compel_sdxl( self: OnnxStableDiffusionPipeline, - prompt: str, + prompt: Union[str, List[str]], num_images_per_prompt: int, do_classifier_free_guidance: bool, - negative_prompt: Optional[str] = None, + negative_prompt: Optional[Union[str, list]], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, skip_clip_states: int = 0, ) -> np.ndarray: wrapped_encoder = wrap_encoder(self.text_encoder) @@ -98,7 +100,7 @@ def encode_prompt_compel_sdxl( prompt_embeds, prompt_pooled = compel(prompt) if negative_prompt is not None: - negative_prompt_embeds, negative_pooled = compel(self, negative_prompt) + negative_prompt_embeds, negative_pooled = compel(negative_prompt) if negative_prompt_embeds is not None: [prompt_embeds, negative_prompt_embeds] = (