1
0
Fork 0

fix(api): restore separate upscale and correction stages

This commit is contained in:
Sean Sube 2023-02-18 11:59:39 -06:00
parent 118695d68c
commit f534fbb92c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 12 additions and 12 deletions

View File

@ -97,21 +97,21 @@ def optimize_pipeline(
try:
pipe.enable_attention_slicing()
except Exception as e:
logger.warning("error enabling attention slicing: %s", e)
logger.warning("error while enabling attention slicing: %s", e)
if "vae-slicing" in server.optimizations:
logger.debug("enabling VAE slicing on SD pipeline")
try:
pipe.enable_vae_slicing()
except Exception as e:
logger.warning("error enabling VAE slicing: %s", e)
logger.warning("error while enabling VAE slicing: %s", e)
if "sequential-cpu-offload" in server.optimizations:
logger.debug("enabling sequential CPU offload on SD pipeline")
try:
pipe.enable_sequential_cpu_offload()
except Exception as e:
logger.warning("error enabling sequential CPU offload: %s", e)
logger.warning("error while enabling sequential CPU offload: %s", e)
elif "model-cpu-offload" in server.optimizations:
# TODO: check for accelerate
@ -119,7 +119,7 @@ def optimize_pipeline(
try:
pipe.enable_model_cpu_offload()
except Exception as e:
logger.warning("error enabling model CPU offload: %s", e)
logger.warning("error while enabling model CPU offload: %s", e)
if "memory-efficient-attention" in server.optimizations:
@ -128,7 +128,7 @@ def optimize_pipeline(
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
logger.warning("error enabling memory efficient attention: %s", e)
logger.warning("error while enabling memory efficient attention: %s", e)
def load_pipeline(

View File

@ -34,6 +34,7 @@ def run_upscale_correction(
chain = ChainPipeline()
upscale_stage = None
if upscale.scale > 1:
if "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
@ -42,23 +43,22 @@ def run_upscale_correction(
upscale_stage = (upscale_resrgan, esrgan_params, None)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_stage, None)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_params, None)
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
upscale_stage = None
correct_stage = None
if upscale.faces:
face_stage = StageParams(
face_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model:
correct_stage = (correct_codeformer, face_stage, None)
correct_stage = (correct_codeformer, face_params, None)
elif "gfpgan" in upscale.correction_model:
correct_stage = (correct_gfpgan, face_stage, None)
correct_stage = (correct_gfpgan, face_params, None)
else:
logger.warn("unknown correction model: %s", upscale.correction_model)
correct_stage = None
if upscale.upscale_order == "correction-both":
chain.append(correct_stage)