1
0
Fork 0

fix(api): turn alternatives back off for SDXL

This commit is contained in:
Sean Sube 2023-11-12 14:23:02 -06:00
parent 6eb014cec8
commit 3ffbc00390
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 23 additions and 19 deletions

View File

@ -81,10 +81,11 @@ class BlendImg2ImgStage(BaseStage):
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(

View File

@ -130,10 +130,11 @@ class SourceTxt2ImgStage(BaseStage):
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(

View File

@ -99,10 +99,11 @@ class UpscaleOutpaintStage(BaseStage):
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(

View File

@ -48,13 +48,14 @@ class UpscaleStableDiffusionStage(BaseStage):
)
generator = torch.manual_seed(params.seed)
prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipeline.unet.set_prompts(prompt_embeds)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipeline.unet.set_prompts(prompt_embeds)
outputs = []
for source in sources: