1
0
Fork 0

feat(api): allow a different prompt for each highres stage

This commit is contained in:
Sean Sube 2023-08-29 20:53:16 -05:00
parent 6b31075616
commit 8ce09d307b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 20 additions and 7 deletions

View File

@ -29,12 +29,14 @@ class BlendImg2ImgStage(BaseStage):
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
prompt_index: Optional[int] = None,
**kwargs,
) -> List[Image.Image]:
params = params.with_args(**kwargs)
# highres hax
params = params.with_args(prompt=slice_prompt(params.prompt, 1))
# multi-stage prompting
if prompt_index is not None:
params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt

View File

@ -16,6 +16,7 @@ def stage_highres(
highres: HighresParams,
upscale: UpscaleParams,
chain: Optional[ChainPipeline] = None,
prompt_index: Optional[int] = None,
) -> ChainPipeline:
logger.info("staging highres pipeline at %s", highres.scale)
@ -30,7 +31,7 @@ def stage_highres(
logger.debug("no highres iterations, skipping")
return chain
for _i in range(highres.iterations):
for i in range(highres.iterations):
if highres.method == "upscale":
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
@ -58,6 +59,7 @@ def stage_highres(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
prompt_index=prompt_index + i,
strength=highres.strength,
)

View File

@ -36,13 +36,15 @@ class SourceTxt2ImgStage(BaseStage):
size: Size,
callback: Optional[ProgressCallback] = None,
latents: Optional[np.ndarray] = None,
prompt_index: Optional[int] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
size = size.with_args(**kwargs)
# highres hax
params = params.with_args(prompt=slice_prompt(params.prompt, 0))
# multi-stage prompting
if prompt_index is not None:
params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))
logger.info(
"generating image using txt2img, %s steps: %s", params.steps, params.prompt

View File

@ -56,6 +56,7 @@ def run_txt2img_pipeline(
tile_size=tile_size,
),
size=size,
prompt_index=0,
overlap=params.overlap,
)
@ -66,8 +67,8 @@ def run_txt2img_pipeline(
stage_upscale_correction(
stage,
params,
upscale=first_upscale,
chain=chain,
upscale=first_upscale,
)
# apply highres
@ -77,14 +78,15 @@ def run_txt2img_pipeline(
highres,
upscale,
chain=chain,
prompt_index=1,
)
# apply upscaling and correction, after highres
stage_upscale_correction(
stage,
params,
upscale=after_upscale,
chain=chain,
upscale=after_upscale,
)
# run and save
@ -141,6 +143,7 @@ def run_img2img_pipeline(
chain.stage(
BlendImg2ImgStage(),
stage,
prompt_index=0,
strength=strength,
overlap=params.overlap,
)
@ -170,6 +173,7 @@ def run_img2img_pipeline(
highres,
upscale,
chain=chain,
prompt_index=1,
)
# apply upscaling and correction, after highres
@ -328,6 +332,7 @@ def run_inpaint_pipeline(
mask_filter=mask_filter,
noise_source=noise_source,
overlap=params.overlap,
prompt_index=0,
)
# apply upscaling and correction, before highres
@ -347,6 +352,7 @@ def run_inpaint_pipeline(
highres,
upscale,
chain=chain,
prompt_index=1,
)
# apply upscaling and correction
@ -422,6 +428,7 @@ def run_upscale_pipeline(
highres,
upscale,
chain=chain,
prompt_index=0,
)
# apply upscaling and correction, after highres