lint(api): consolidate upscale/correction logic
This commit is contained in:
parent
d1ed5c48e8
commit
483b8e3f19
|
@ -1,3 +1,10 @@
|
||||||
|
from .diffusion import (
|
||||||
|
get_latents_from_seed,
|
||||||
|
load_pipeline,
|
||||||
|
run_img2img_pipeline,
|
||||||
|
run_inpaint_pipeline,
|
||||||
|
run_txt2img_pipeline,
|
||||||
|
)
|
||||||
from .image import (
|
from .image import (
|
||||||
expand_image,
|
expand_image,
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
|
@ -10,15 +17,9 @@ from .image import (
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from .pipeline import (
|
|
||||||
get_latents_from_seed,
|
|
||||||
load_pipeline,
|
|
||||||
run_img2img_pipeline,
|
|
||||||
run_inpaint_pipeline,
|
|
||||||
run_txt2img_pipeline,
|
|
||||||
)
|
|
||||||
from .upscale import (
|
from .upscale import (
|
||||||
make_resrgan,
|
make_resrgan,
|
||||||
|
run_upscale_pipeline,
|
||||||
upscale_gfpgan,
|
upscale_gfpgan,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
|
@ -26,7 +27,9 @@ from .upscale import (
|
||||||
from .utils import (
|
from .utils import (
|
||||||
get_and_clamp_float,
|
get_and_clamp_float,
|
||||||
get_and_clamp_int,
|
get_and_clamp_int,
|
||||||
|
get_from_list,
|
||||||
get_from_map,
|
get_from_map,
|
||||||
|
get_not_empty,
|
||||||
safer_join,
|
safer_join,
|
||||||
BaseParams,
|
BaseParams,
|
||||||
Border,
|
Border,
|
||||||
|
|
|
@ -61,7 +61,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
print('loading different pipeline')
|
print('loading new pipeline')
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
@ -77,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
||||||
last_pipeline_scheduler = scheduler
|
last_pipeline_scheduler = scheduler
|
||||||
|
|
||||||
if last_pipeline_scheduler != scheduler:
|
if last_pipeline_scheduler != scheduler:
|
||||||
print('changing pipeline scheduler')
|
print('loading new scheduler')
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler.from_pretrained(
|
||||||
model, subfolder='scheduler')
|
model, subfolder='scheduler')
|
||||||
|
|
||||||
|
@ -117,9 +117,7 @@ def run_txt2img_pipeline(
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
|
image = run_upscale_pipeline(ctx, upscale, image)
|
||||||
if upscale.faces or upscale.scale > 1:
|
|
||||||
image = upscale_resrgan(ctx, upscale, image)
|
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = safer_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
||||||
|
@ -153,9 +151,7 @@ def run_img2img_pipeline(
|
||||||
strength=strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
|
image = run_upscale_pipeline(ctx, upscale, image)
|
||||||
if upscale.faces or upscale.scale > 1:
|
|
||||||
image = upscale_resrgan(ctx, upscale, image)
|
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = safer_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
||||||
|
@ -219,8 +215,7 @@ def run_inpaint_pipeline(
|
||||||
else:
|
else:
|
||||||
print('output image size does not match source, skipping post-blend')
|
print('output image size does not match source, skipping post-blend')
|
||||||
|
|
||||||
if upscale.faces or upscale.scale > 1:
|
image = run_upscale_pipeline(ctx, upscale, image)
|
||||||
image = upscale_resrgan(ctx, upscale, image)
|
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = safer_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
|
@ -1,5 +1,4 @@
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
# schedulers
|
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
|
@ -23,6 +22,12 @@ from onnxruntime import get_available_providers
|
||||||
from os import makedirs, path, scandir
|
from os import makedirs, path, scandir
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
from .diffusion import (
|
||||||
|
run_img2img_pipeline,
|
||||||
|
run_inpaint_pipeline,
|
||||||
|
run_txt2img_pipeline,
|
||||||
|
run_upscale_pipeline,
|
||||||
|
)
|
||||||
from .image import (
|
from .image import (
|
||||||
# mask filters
|
# mask filters
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
|
@ -36,12 +41,6 @@ from .image import (
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from .pipeline import (
|
|
||||||
run_img2img_pipeline,
|
|
||||||
run_inpaint_pipeline,
|
|
||||||
run_txt2img_pipeline,
|
|
||||||
run_upscale_pipeline,
|
|
||||||
)
|
|
||||||
from .upscale import (
|
from .upscale import (
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
|
|
|
@ -188,3 +188,15 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
|
||||||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
|
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||||
|
print('running upscale pipeline')
|
||||||
|
|
||||||
|
if params.scale > 1:
|
||||||
|
image = upscale_resrgan(ctx, params, image)
|
||||||
|
|
||||||
|
if params.faces:
|
||||||
|
image = upscale_gfpgan(ctx, params, image)
|
||||||
|
|
||||||
|
return image
|
Loading…
Reference in New Issue