fix(api): enable prompt alternatives for SDXL
This commit is contained in:
parent
09f600ab54
commit
6eb014cec8
|
@ -84,9 +84,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
prompt_embeds = encode_prompt(
|
prompt_embeds = encode_prompt(
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
)
|
)
|
||||||
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
if not params.is_xl():
|
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = pipe(
|
result = pipe(
|
||||||
|
|
|
@ -133,9 +133,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
prompt_embeds = encode_prompt(
|
prompt_embeds = encode_prompt(
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
)
|
)
|
||||||
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
if not params.is_xl():
|
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = pipe(
|
result = pipe(
|
||||||
|
|
|
@ -17,14 +17,15 @@ LATENT_CHANNELS = 4
|
||||||
LATENT_FACTOR = 8
|
LATENT_FACTOR = 8
|
||||||
MAX_TOKENS_PER_GROUP = 77
|
MAX_TOKENS_PER_GROUP = 77
|
||||||
|
|
||||||
|
ANY_TOKEN = compile(r"\<([^\>]*)\>")
|
||||||
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
|
||||||
REGION_TOKEN = compile(
|
REGION_TOKEN = compile(
|
||||||
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
|
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
|
||||||
)
|
)
|
||||||
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
|
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
|
||||||
|
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||||
|
@ -380,10 +381,10 @@ def encode_prompt(
|
||||||
) -> List[np.ndarray]:
|
) -> List[np.ndarray]:
|
||||||
return [
|
return [
|
||||||
pipe._encode_prompt(
|
pipe._encode_prompt(
|
||||||
prompt,
|
remove_tokens(prompt),
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
negative_prompt=neg_prompt,
|
negative_prompt=remove_tokens(neg_prompt),
|
||||||
)
|
)
|
||||||
for prompt, neg_prompt in prompt_pairs
|
for prompt, neg_prompt in prompt_pairs
|
||||||
]
|
]
|
||||||
|
@ -504,3 +505,8 @@ def parse_reseed_group(group) -> Region:
|
||||||
|
|
||||||
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
|
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
|
||||||
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
|
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tokens(prompt: str) -> str:
|
||||||
|
remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN)
|
||||||
|
return remainder
|
||||||
|
|
Loading…
Reference in New Issue