1
0
Fork 0

fix(api): enable prompt alternatives for SDXL

This commit is contained in:
Sean Sube 2023-11-12 14:12:28 -06:00
parent 09f600ab54
commit 6eb014cec8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 11 additions and 9 deletions

View File

@ -84,9 +84,7 @@ class BlendImg2ImgStage(BaseStage):
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
if not params.is_xl():
pipe.unet.set_prompts(prompt_embeds)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(

View File

@ -133,9 +133,7 @@ class SourceTxt2ImgStage(BaseStage):
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
if not params.is_xl():
pipe.unet.set_prompts(prompt_embeds)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(

View File

@ -17,14 +17,15 @@ LATENT_CHANNELS = 4
LATENT_FACTOR = 8
MAX_TOKENS_PER_GROUP = 77
ANY_TOKEN = compile(r"\<([^\>]*)\>")
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
)
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -380,10 +381,10 @@ def encode_prompt(
) -> List[np.ndarray]:
return [
pipe._encode_prompt(
prompt,
remove_tokens(prompt),
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=neg_prompt,
negative_prompt=remove_tokens(neg_prompt),
)
for prompt, neg_prompt in prompt_pairs
]
@ -504,3 +505,8 @@ def parse_reseed_group(group) -> Region:
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
def remove_tokens(prompt: str) -> str:
remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN)
return remainder