From 6eb014cec88b33a98c5f307a0f9b76d3a01e670e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Nov 2023 14:12:28 -0600 Subject: [PATCH] fix(api): enable prompt alternatives for SDXL --- api/onnx_web/chain/blend_img2img.py | 4 +--- api/onnx_web/chain/source_txt2img.py | 4 +--- api/onnx_web/diffusers/utils.py | 12 +++++++++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index af181c10..89f1301b 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -84,9 +84,7 @@ class BlendImg2ImgStage(BaseStage): prompt_embeds = encode_prompt( pipe, prompt_pairs, params.batch, params.do_cfg() ) - - if not params.is_xl(): - pipe.unet.set_prompts(prompt_embeds) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 40377fe0..3840fdd2 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -133,9 +133,7 @@ class SourceTxt2ImgStage(BaseStage): prompt_embeds = encode_prompt( pipe, prompt_pairs, params.batch, params.do_cfg() ) - - if not params.is_xl(): - pipe.unet.set_prompts(prompt_embeds) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index b720e711..9686a48e 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -17,14 +17,15 @@ LATENT_CHANNELS = 4 LATENT_FACTOR = 8 MAX_TOKENS_PER_GROUP = 77 +ANY_TOKEN = compile(r"\<([^\>]*)\>") CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") -WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") REGION_TOKEN = compile( r"\]+)\>" ) RESEED_TOKEN = compile(r"\") +WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") @@ -380,10 +381,10 @@ def encode_prompt( ) -> List[np.ndarray]: return [ pipe._encode_prompt( - prompt, + remove_tokens(prompt), num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=neg_prompt, + negative_prompt=remove_tokens(neg_prompt), ) for prompt, neg_prompt in prompt_pairs ] @@ -504,3 +505,8 @@ def parse_reseed_group(group) -> Region: def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]: return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group) + + +def remove_tokens(prompt: str) -> str: + remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN) + return remainder