1
0
Fork 0

fix(api): pass correct outscale to highres stages

This commit is contained in:
Sean Sube 2023-11-25 12:25:16 -06:00
parent b1328fdfdb
commit 6ecdae44a2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 22 additions and 12 deletions

View File

@ -57,7 +57,7 @@ def stage_highres(
chain.stage(
BlendImg2ImgStage(),
stage,
stage.with_args(outscale=1),
overlap=params.vae_overlap,
prompt_index=prompt_index + i,
strength=highres.strength,

View File

@ -156,7 +156,6 @@ def blend_tiles(
value = np.zeros(scaled_size)
for left, top, tile_image in tiles:
# TODO: histogram equalization
equalized = np.array(tile_image).astype(np.float32)
mask = np.ones_like(equalized[:, :, 0])

View File

@ -64,21 +64,20 @@ def run_txt2img_pipeline(
# apply upscaling and correction, before highres
highres_size = params.unet_tile
stage = StageParams(tile_size=highres_size)
if params.is_panorama():
if server.has_feature("panorama-highres"):
# run the whole highres pass with one panorama call
highres_size = tile_size * highres.scale
chain.stage(
BlendDenoiseStage(),
stage,
StageParams(tile_size=highres_size),
)
if server.has_feature("panorama-highres"):
highres_size = tile_size * highres.scale
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
params,
chain=chain,
upscale=first_upscale,
@ -86,7 +85,7 @@ def run_txt2img_pipeline(
# apply highres
stage_highres(
stage,
StageParams(outscale=highres.scale, tile_size=highres_size),
params,
highres,
upscale,
@ -96,7 +95,7 @@ def run_txt2img_pipeline(
# apply upscaling and correction, after highres
stage_upscale_correction(
stage,
StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
params,
chain=chain,
upscale=after_upscale,

View File

@ -369,6 +369,17 @@ class StageParams:
self.tile_order = tile_order
self.tile_size = tile_size
def with_args(
self,
**kwargs,
):
return StageParams(
name=kwargs.get("name", self.name),
outscale=kwargs.get("outscale", self.outscale),
tile_order=kwargs.get("tile_order", self.tile_order),
tile_size=kwargs.get("tile_size", self.tile_size),
)
class UpscaleParams:
def __init__(

View File

@ -149,10 +149,11 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3.0,
1,
1,
unet_tile=256,
),
Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test"),
UpscaleParams("test", scale=2, outscale=2),
HighresParams(True, 2, 0, 0),
)