1
0
Fork 0

lint(api): consolidate upscale/correction logic

This commit is contained in:
Sean Sube 2023-01-25 20:31:39 -06:00
parent d1ed5c48e8
commit 483b8e3f19
4 changed files with 33 additions and 24 deletions

View File

@ -1,3 +1,10 @@
from .diffusion import (
get_latents_from_seed,
load_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
)
from .image import ( from .image import (
expand_image, expand_image,
mask_filter_gaussian_multiply, mask_filter_gaussian_multiply,
@ -10,15 +17,9 @@ from .image import (
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
) )
from .pipeline import (
get_latents_from_seed,
load_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
)
from .upscale import ( from .upscale import (
make_resrgan, make_resrgan,
run_upscale_pipeline,
upscale_gfpgan, upscale_gfpgan,
upscale_resrgan, upscale_resrgan,
UpscaleParams, UpscaleParams,
@ -26,7 +27,9 @@ from .upscale import (
from .utils import ( from .utils import (
get_and_clamp_float, get_and_clamp_float,
get_and_clamp_int, get_and_clamp_int,
get_from_list,
get_from_map, get_from_map,
get_not_empty,
safer_join, safer_join,
BaseParams, BaseParams,
Border, Border,

View File

@ -61,7 +61,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
print('loading different pipeline') print('loading new pipeline')
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
provider=provider, provider=provider,
@ -77,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler: if last_pipeline_scheduler != scheduler:
print('changing pipeline scheduler') print('loading new scheduler')
scheduler = scheduler.from_pretrained( scheduler = scheduler.from_pretrained(
model, subfolder='scheduler') model, subfolder='scheduler')
@ -117,9 +117,7 @@ def run_txt2img_pipeline(
num_inference_steps=params.steps, num_inference_steps=params.steps,
) )
image = result.images[0] image = result.images[0]
image = run_upscale_pipeline(ctx, upscale, image)
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, upscale, image)
dest = safer_join(ctx.output_path, output) dest = safer_join(ctx.output_path, output)
image.save(dest) image.save(dest)
@ -153,9 +151,7 @@ def run_img2img_pipeline(
strength=strength, strength=strength,
) )
image = result.images[0] image = result.images[0]
image = run_upscale_pipeline(ctx, upscale, image)
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, upscale, image)
dest = safer_join(ctx.output_path, output) dest = safer_join(ctx.output_path, output)
image.save(dest) image.save(dest)
@ -219,8 +215,7 @@ def run_inpaint_pipeline(
else: else:
print('output image size does not match source, skipping post-blend') print('output image size does not match source, skipping post-blend')
if upscale.faces or upscale.scale > 1: image = run_upscale_pipeline(ctx, upscale, image)
image = upscale_resrgan(ctx, upscale, image)
dest = safer_join(ctx.output_path, output) dest = safer_join(ctx.output_path, output)
image.save(dest) image.save(dest)

View File

@ -1,5 +1,4 @@
from diffusers import ( from diffusers import (
# schedulers
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
@ -23,6 +22,12 @@ from onnxruntime import get_available_providers
from os import makedirs, path, scandir from os import makedirs, path, scandir
from typing import Tuple from typing import Tuple
from .diffusion import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .image import ( from .image import (
# mask filters # mask filters
mask_filter_gaussian_multiply, mask_filter_gaussian_multiply,
@ -36,12 +41,6 @@ from .image import (
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
) )
from .pipeline import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .upscale import ( from .upscale import (
UpscaleParams, UpscaleParams,
) )

View File

@ -188,3 +188,15 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength) image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
return output return output
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
print('running upscale pipeline')
if params.scale > 1:
image = upscale_resrgan(ctx, params, image)
if params.faces:
image = upscale_gfpgan(ctx, params, image)
return image