fix(api): pass correct outscale to highres stages
This commit is contained in:
parent
b1328fdfdb
commit
6ecdae44a2
|
@ -57,7 +57,7 @@ def stage_highres(
|
|||
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
stage.with_args(outscale=1),
|
||||
overlap=params.vae_overlap,
|
||||
prompt_index=prompt_index + i,
|
||||
strength=highres.strength,
|
||||
|
|
|
@ -156,7 +156,6 @@ def blend_tiles(
|
|||
value = np.zeros(scaled_size)
|
||||
|
||||
for left, top, tile_image in tiles:
|
||||
# TODO: histogram equalization
|
||||
equalized = np.array(tile_image).astype(np.float32)
|
||||
mask = np.ones_like(equalized[:, :, 0])
|
||||
|
||||
|
|
|
@ -64,21 +64,20 @@ def run_txt2img_pipeline(
|
|||
|
||||
# apply upscaling and correction, before highres
|
||||
highres_size = params.unet_tile
|
||||
stage = StageParams(tile_size=highres_size)
|
||||
|
||||
if params.is_panorama():
|
||||
if server.has_feature("panorama-highres"):
|
||||
# run the whole highres pass with one panorama call
|
||||
highres_size = tile_size * highres.scale
|
||||
|
||||
chain.stage(
|
||||
BlendDenoiseStage(),
|
||||
stage,
|
||||
StageParams(tile_size=highres_size),
|
||||
)
|
||||
|
||||
if server.has_feature("panorama-highres"):
|
||||
highres_size = tile_size * highres.scale
|
||||
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
if first_upscale:
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
|
||||
params,
|
||||
chain=chain,
|
||||
upscale=first_upscale,
|
||||
|
@ -86,7 +85,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
# apply highres
|
||||
stage_highres(
|
||||
stage,
|
||||
StageParams(outscale=highres.scale, tile_size=highres_size),
|
||||
params,
|
||||
highres,
|
||||
upscale,
|
||||
|
@ -96,7 +95,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
# apply upscaling and correction, after highres
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
|
||||
params,
|
||||
chain=chain,
|
||||
upscale=after_upscale,
|
||||
|
|
|
@ -369,6 +369,17 @@ class StageParams:
|
|||
self.tile_order = tile_order
|
||||
self.tile_size = tile_size
|
||||
|
||||
def with_args(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
return StageParams(
|
||||
name=kwargs.get("name", self.name),
|
||||
outscale=kwargs.get("outscale", self.outscale),
|
||||
tile_order=kwargs.get("tile_order", self.tile_order),
|
||||
tile_size=kwargs.get("tile_size", self.tile_size),
|
||||
)
|
||||
|
||||
|
||||
class UpscaleParams:
|
||||
def __init__(
|
||||
|
|
|
@ -149,10 +149,11 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
3.0,
|
||||
1,
|
||||
1,
|
||||
unet_tile=256,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-highres.png"],
|
||||
UpscaleParams("test"),
|
||||
UpscaleParams("test", scale=2, outscale=2),
|
||||
HighresParams(True, 2, 0, 0),
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue