lint(api): consolidate upscale/correction logic
This commit is contained in:
parent
d1ed5c48e8
commit
483b8e3f19
|
@ -1,3 +1,10 @@
|
|||
from .diffusion import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
)
|
||||
from .image import (
|
||||
expand_image,
|
||||
mask_filter_gaussian_multiply,
|
||||
|
@ -10,15 +17,9 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .pipeline import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
)
|
||||
from .upscale import (
|
||||
make_resrgan,
|
||||
run_upscale_pipeline,
|
||||
upscale_gfpgan,
|
||||
upscale_resrgan,
|
||||
UpscaleParams,
|
||||
|
@ -26,7 +27,9 @@ from .upscale import (
|
|||
from .utils import (
|
||||
get_and_clamp_float,
|
||||
get_and_clamp_int,
|
||||
get_from_list,
|
||||
get_from_map,
|
||||
get_not_empty,
|
||||
safer_join,
|
||||
BaseParams,
|
||||
Border,
|
||||
|
|
|
@ -61,7 +61,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
|||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print('loading different pipeline')
|
||||
print('loading new pipeline')
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
provider=provider,
|
||||
|
@ -77,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
|||
last_pipeline_scheduler = scheduler
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
print('changing pipeline scheduler')
|
||||
print('loading new scheduler')
|
||||
scheduler = scheduler.from_pretrained(
|
||||
model, subfolder='scheduler')
|
||||
|
||||
|
@ -117,9 +117,7 @@ def run_txt2img_pipeline(
|
|||
num_inference_steps=params.steps,
|
||||
)
|
||||
image = result.images[0]
|
||||
|
||||
if upscale.faces or upscale.scale > 1:
|
||||
image = upscale_resrgan(ctx, upscale, image)
|
||||
image = run_upscale_pipeline(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -153,9 +151,7 @@ def run_img2img_pipeline(
|
|||
strength=strength,
|
||||
)
|
||||
image = result.images[0]
|
||||
|
||||
if upscale.faces or upscale.scale > 1:
|
||||
image = upscale_resrgan(ctx, upscale, image)
|
||||
image = run_upscale_pipeline(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -219,8 +215,7 @@ def run_inpaint_pipeline(
|
|||
else:
|
||||
print('output image size does not match source, skipping post-blend')
|
||||
|
||||
if upscale.faces or upscale.scale > 1:
|
||||
image = upscale_resrgan(ctx, upscale, image)
|
||||
image = run_upscale_pipeline(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
|
@ -1,5 +1,4 @@
|
|||
from diffusers import (
|
||||
# schedulers
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
|
@ -23,6 +22,12 @@ from onnxruntime import get_available_providers
|
|||
from os import makedirs, path, scandir
|
||||
from typing import Tuple
|
||||
|
||||
from .diffusion import (
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
run_upscale_pipeline,
|
||||
)
|
||||
from .image import (
|
||||
# mask filters
|
||||
mask_filter_gaussian_multiply,
|
||||
|
@ -36,12 +41,6 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .pipeline import (
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
run_upscale_pipeline,
|
||||
)
|
||||
from .upscale import (
|
||||
UpscaleParams,
|
||||
)
|
||||
|
|
|
@ -188,3 +188,15 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
|
|||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||
print('running upscale pipeline')
|
||||
|
||||
if params.scale > 1:
|
||||
image = upscale_resrgan(ctx, params, image)
|
||||
|
||||
if params.faces:
|
||||
image = upscale_gfpgan(ctx, params, image)
|
||||
|
||||
return image
|
Loading…
Reference in New Issue