1
0
Fork 0

fix(api): get upscale params from request

This commit is contained in:
Sean Sube 2023-01-16 13:12:08 -06:00
parent 120056f878
commit 1f0c19af04
4 changed files with 39 additions and 6 deletions

View File

@ -16,6 +16,7 @@ from .image import (
) )
from .upscale import ( from .upscale import (
upscale_resrgan, upscale_resrgan,
UpscaleParams,
) )
from .utils import ( from .utils import (
safer_join, safer_join,
@ -75,7 +76,8 @@ def run_txt2img_pipeline(
ctx: ServerContext, ctx: ServerContext,
params: BaseParams, params: BaseParams,
size: Size, size: Size,
output: str output: str,
upscale: UpscaleParams
): ):
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler) params.model, params.provider, params.scheduler)
@ -93,7 +95,9 @@ def run_txt2img_pipeline(
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
).images[0] ).images[0]
image = upscale_resrgan(image, ctx.model_path)
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, image, upscale)
dest = safer_join(ctx.output_path, output) dest = safer_join(ctx.output_path, output)
image.save(dest) image.save(dest)
@ -105,8 +109,9 @@ def run_img2img_pipeline(
ctx: ServerContext, ctx: ServerContext,
params: BaseParams, params: BaseParams,
output: str, output: str,
upscale: UpscaleParams,
source_image: Image, source_image: Image,
strength: float strength: float,
): ):
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler) params.model, params.provider, params.scheduler)
@ -122,7 +127,9 @@ def run_img2img_pipeline(
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength, strength=strength,
).images[0] ).images[0]
image = upscale_resrgan(image, ctx.model_path)
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, image, upscale)
dest = safer_join(ctx.output_path, output) dest = safer_join(ctx.output_path, output)
image.save(dest) image.save(dest)
@ -135,6 +142,7 @@ def run_inpaint_pipeline(
params: BaseParams, params: BaseParams,
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams,
source_image: Image, source_image: Image,
mask_image: Image, mask_image: Image,
expand: Border, expand: Border,
@ -173,6 +181,9 @@ def run_inpaint_pipeline(
width=size.width, width=size.width,
).images[0] ).images[0]
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, image, upscale)
dest = safer_join(ctx.output_path, output) dest = safer_join(ctx.output_path, output)
image.save(dest) image.save(dest)

View File

@ -39,6 +39,9 @@ from .pipeline import (
run_inpaint_pipeline, run_inpaint_pipeline,
run_txt2img_pipeline, run_txt2img_pipeline,
) )
from .upscale import (
UpscaleParams,
)
from .utils import ( from .utils import (
get_and_clamp_float, get_and_clamp_float,
get_and_clamp_int, get_and_clamp_int,
@ -186,6 +189,13 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom) return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
faces = request.args.get('faces', 'false') == 'true'
platform = 'onnx'
return UpscaleParams(scale=scale, faces=faces, platform=platform, denoise=denoise)
def check_paths(): def check_paths():
if not path.exists(model_path): if not path.exists(model_path):
raise RuntimeError('model path must exist') raise RuntimeError('model path must exist')
@ -278,6 +288,7 @@ def img2img():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB') source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
params, size = pipeline_from_request() params, size = pipeline_from_request()
upscale = upscale_from_request()
strength = get_and_clamp_float( strength = get_and_clamp_float(
request.args, request.args,
@ -294,7 +305,7 @@ def img2img():
source_image.thumbnail((size.width, size.height)) source_image.thumbnail((size.width, size.height))
executor.submit_stored(output, run_img2img_pipeline, executor.submit_stored(output, run_img2img_pipeline,
context, params, output, source_image, strength) context, params, output, upscale, source_image, strength)
return jsonify({ return jsonify({
'output': output, 'output': output,
@ -306,6 +317,7 @@ def img2img():
@app.route('/api/txt2img', methods=['POST']) @app.route('/api/txt2img', methods=['POST'])
def txt2img(): def txt2img():
params, size = pipeline_from_request() params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name( output = make_output_name(
'txt2img', 'txt2img',
@ -314,7 +326,7 @@ def txt2img():
print("txt2img output: %s" % (output)) print("txt2img output: %s" % (output))
executor.submit_stored( executor.submit_stored(
output, run_txt2img_pipeline, context, params, size, output) output, run_txt2img_pipeline, context, params, size, output, upscale)
return jsonify({ return jsonify({
'output': output, 'output': output,
@ -333,6 +345,7 @@ def inpaint():
params, size = pipeline_from_request() params, size = pipeline_from_request()
expand = border_from_request() expand = border_from_request()
upscale = upscale_from_request()
mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none') mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none')
noise_source = get_from_map( noise_source = get_from_map(
@ -362,6 +375,7 @@ def inpaint():
params, params,
size, size,
output, output,
upscale,
source_image, source_image,
mask_image, mask_image,
expand, expand,

View File

@ -129,6 +129,8 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image: def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image:
print('upscaling image with Real ESRGAN', params)
image = np.array(source_image) image = np.array(source_image)
upsampler = make_resrgan(ctx.model_path) upsampler = make_resrgan(ctx.model_path)
@ -142,6 +144,8 @@ def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscalePara
def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image: def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN')
if upsampler is None: if upsampler is None:
upsampler = make_resrgan(ctx.model_path, 512) upsampler = make_resrgan(ctx.model_path, 512)

View File

@ -17,8 +17,10 @@
"CUDA", "CUDA",
"ddim", "ddim",
"ddpm", "ddpm",
"denoise",
"directml", "directml",
"ftfy", "ftfy",
"gfpgan",
"Heun", "Heun",
"huggingface", "huggingface",
"Inpaint", "Inpaint",
@ -35,6 +37,7 @@
"pndm", "pndm",
"pretrained", "pretrained",
"protobuf", "protobuf",
"resrgan",
"runwayml", "runwayml",
"scandir", "scandir",
"scipy", "scipy",
@ -42,6 +45,7 @@
"spacy", "spacy",
"spinalcase", "spinalcase",
"stringcase", "stringcase",
"upsampler",
"venv", "venv",
"virtualenv", "virtualenv",
"zustand" "zustand"