1
0
Fork 0

lint(api): consolidate upscale/correction logic

This commit is contained in:
Sean Sube 2023-01-25 20:31:39 -06:00
parent d1ed5c48e8
commit 483b8e3f19
4 changed files with 33 additions and 24 deletions

View File

@ -1,3 +1,10 @@
from .diffusion import (
get_latents_from_seed,
load_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
)
from .image import (
expand_image,
mask_filter_gaussian_multiply,
@ -10,15 +17,9 @@ from .image import (
noise_source_normal,
noise_source_uniform,
)
from .pipeline import (
get_latents_from_seed,
load_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
)
from .upscale import (
make_resrgan,
run_upscale_pipeline,
upscale_gfpgan,
upscale_resrgan,
UpscaleParams,
@ -26,7 +27,9 @@ from .upscale import (
from .utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
get_not_empty,
safer_join,
BaseParams,
Border,

View File

@ -61,7 +61,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
gc.collect()
torch.cuda.empty_cache()
print('loading different pipeline')
print('loading new pipeline')
pipe = pipeline.from_pretrained(
model,
provider=provider,
@ -77,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler')
print('loading new scheduler')
scheduler = scheduler.from_pretrained(
model, subfolder='scheduler')
@ -117,9 +117,7 @@ def run_txt2img_pipeline(
num_inference_steps=params.steps,
)
image = result.images[0]
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, upscale, image)
image = run_upscale_pipeline(ctx, upscale, image)
dest = safer_join(ctx.output_path, output)
image.save(dest)
@ -153,9 +151,7 @@ def run_img2img_pipeline(
strength=strength,
)
image = result.images[0]
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, upscale, image)
image = run_upscale_pipeline(ctx, upscale, image)
dest = safer_join(ctx.output_path, output)
image.save(dest)
@ -219,8 +215,7 @@ def run_inpaint_pipeline(
else:
print('output image size does not match source, skipping post-blend')
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, upscale, image)
image = run_upscale_pipeline(ctx, upscale, image)
dest = safer_join(ctx.output_path, output)
image.save(dest)

View File

@ -1,5 +1,4 @@
from diffusers import (
# schedulers
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
@ -23,6 +22,12 @@ from onnxruntime import get_available_providers
from os import makedirs, path, scandir
from typing import Tuple
from .diffusion import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .image import (
# mask filters
mask_filter_gaussian_multiply,
@ -36,12 +41,6 @@ from .image import (
noise_source_normal,
noise_source_uniform,
)
from .pipeline import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .upscale import (
UpscaleParams,
)

View File

@ -188,3 +188,15 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
return output
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
print('running upscale pipeline')
if params.scale > 1:
image = upscale_resrgan(ctx, params, image)
if params.faces:
image = upscale_gfpgan(ctx, params, image)
return image