1
0
Fork 0

fix(api): turn alternatives back off for SDXL

This commit is contained in:
Sean Sube 2023-11-12 14:23:02 -06:00
parent 6eb014cec8
commit 3ffbc00390
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 23 additions and 19 deletions

View File

@ -81,10 +81,11 @@ class BlendImg2ImgStage(BaseStage):
) )
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt( if not params.is_xl():
pipe, prompt_pairs, params.batch, params.do_cfg() prompt_embeds = encode_prompt(
) pipe, prompt_pairs, params.batch, params.do_cfg()
pipe.unet.set_prompts(prompt_embeds) )
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(

View File

@ -130,10 +130,11 @@ class SourceTxt2ImgStage(BaseStage):
) )
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt( if not params.is_xl():
pipe, prompt_pairs, params.batch, params.do_cfg() prompt_embeds = encode_prompt(
) pipe, prompt_pairs, params.batch, params.do_cfg()
pipe.unet.set_prompts(prompt_embeds) )
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(

View File

@ -99,10 +99,11 @@ class UpscaleOutpaintStage(BaseStage):
) )
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt( if not params.is_xl():
pipe, prompt_pairs, params.batch, params.do_cfg() prompt_embeds = encode_prompt(
) pipe, prompt_pairs, params.batch, params.do_cfg()
pipe.unet.set_prompts(prompt_embeds) )
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(

View File

@ -48,13 +48,14 @@ class UpscaleStableDiffusionStage(BaseStage):
) )
generator = torch.manual_seed(params.seed) generator = torch.manual_seed(params.seed)
prompt_embeds = encode_prompt( if not params.is_xl():
pipeline, prompt_embeds = encode_prompt(
prompt_pairs, pipeline,
num_images_per_prompt=params.batch, prompt_pairs,
do_classifier_free_guidance=params.do_cfg(), num_images_per_prompt=params.batch,
) do_classifier_free_guidance=params.do_cfg(),
pipeline.unet.set_prompts(prompt_embeds) )
pipeline.unet.set_prompts(prompt_embeds)
outputs = [] outputs = []
for source in sources: for source in sources: