fix(api): enable prompt alternatives for SDXL
This commit is contained in:
parent
09f600ab54
commit
6eb014cec8
|
@ -84,9 +84,7 @@ class BlendImg2ImgStage(BaseStage):
|
|||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
|
||||
if not params.is_xl():
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
|
|
|
@ -133,9 +133,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
|
||||
if not params.is_xl():
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
|
|
|
@ -17,14 +17,15 @@ LATENT_CHANNELS = 4
|
|||
LATENT_FACTOR = 8
|
||||
MAX_TOKENS_PER_GROUP = 77
|
||||
|
||||
ANY_TOKEN = compile(r"\<([^\>]*)\>")
|
||||
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||
REGION_TOKEN = compile(
|
||||
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
|
||||
)
|
||||
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
|
||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||
|
||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||
|
@ -380,10 +381,10 @@ def encode_prompt(
|
|||
) -> List[np.ndarray]:
|
||||
return [
|
||||
pipe._encode_prompt(
|
||||
prompt,
|
||||
remove_tokens(prompt),
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=neg_prompt,
|
||||
negative_prompt=remove_tokens(neg_prompt),
|
||||
)
|
||||
for prompt, neg_prompt in prompt_pairs
|
||||
]
|
||||
|
@ -504,3 +505,8 @@ def parse_reseed_group(group) -> Region:
|
|||
|
||||
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
|
||||
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
|
||||
|
||||
|
||||
def remove_tokens(prompt: str) -> str:
|
||||
remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN)
|
||||
return remainder
|
||||
|
|
Loading…
Reference in New Issue