1
0
Fork 0

fix(api): enable prompt alternatives for SDXL

This commit is contained in:
Sean Sube 2023-11-12 14:12:28 -06:00
parent 09f600ab54
commit 6eb014cec8
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 11 additions and 9 deletions

View File

@ -84,9 +84,7 @@ class BlendImg2ImgStage(BaseStage):
prompt_embeds = encode_prompt( prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg() pipe, prompt_pairs, params.batch, params.do_cfg()
) )
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(

View File

@ -133,9 +133,7 @@ class SourceTxt2ImgStage(BaseStage):
prompt_embeds = encode_prompt( prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg() pipe, prompt_pairs, params.batch, params.do_cfg()
) )
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(

View File

@ -17,14 +17,15 @@ LATENT_CHANNELS = 4
LATENT_FACTOR = 8 LATENT_FACTOR = 8
MAX_TOKENS_PER_GROUP = 77 MAX_TOKENS_PER_GROUP = 77
ANY_TOKEN = compile(r"\<([^\>]*)\>")
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>") CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>") INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile( REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>" r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
) )
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>") RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -380,10 +381,10 @@ def encode_prompt(
) -> List[np.ndarray]: ) -> List[np.ndarray]:
return [ return [
pipe._encode_prompt( pipe._encode_prompt(
prompt, remove_tokens(prompt),
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance, do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=neg_prompt, negative_prompt=remove_tokens(neg_prompt),
) )
for prompt, neg_prompt in prompt_pairs for prompt, neg_prompt in prompt_pairs
] ]
@ -504,3 +505,8 @@ def parse_reseed_group(group) -> Region:
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]: def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group) return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
def remove_tokens(prompt: str) -> str:
remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN)
return remainder