1
0
Fork 0

fix(api): pass job context and device to upscaling

This commit is contained in:
Sean Sube 2023-02-04 14:52:23 -06:00
parent 8a81e8b810
commit 3637f642c6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 27 additions and 14 deletions

View File

@ -69,7 +69,7 @@ def run_txt2img_pipeline(
)
image = result.images[0]
image = run_upscale_correction(
server, StageParams(), params, image, upscale=upscale)
job, server, StageParams(), params, image, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
@ -109,7 +109,7 @@ def run_img2img_pipeline(
)
image = result.images[0]
image = run_upscale_correction(
server, StageParams(), params, image, upscale=upscale)
job, server, StageParams(), params, image, upscale=upscale)
dest = save_image(server, output, image)
size = Size(*source_image.size)
@ -141,7 +141,6 @@ def run_inpaint_pipeline(
# progress = job.get_progress_callback()
stage = StageParams()
# TODO: pass device, progress
image = upscale_outpaint(
server,
stage,
@ -162,7 +161,7 @@ def run_inpaint_pipeline(
'output image size does not match source, skipping post-blend')
image = run_upscale_correction(
server, stage, params, image, upscale=upscale)
job, server, stage, params, image, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border)
@ -186,9 +185,8 @@ def run_upscale_pipeline(
# progress = job.get_progress_callback()
stage = StageParams()
# TODO: pass device, progress
image = run_upscale_correction(
server, stage, params, source_image, upscale=upscale)
job, server, stage, params, source_image, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)

View File

@ -41,6 +41,7 @@ from .chain import (
ChainPipeline,
)
from .device_pool import (
DeviceParams,
DevicePoolExecutor,
)
from .diffusion.run import (
@ -168,14 +169,23 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)
def pipeline_from_request() -> Tuple[ImageParams, Size]:
def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
# platform stuff
device_name = request.args.get('platform', available_platforms[0].device)
device = None
for platform in available_platforms:
if platform.device == device_name:
device = available_platforms[0]
if device is None:
raise Exception('unknown device')
# pipeline stuff
model = get_not_empty(request.args, 'model', get_config_value('model'))
model_path = get_model_path(model)
provider = get_from_map(request.args, 'platform',
platform_providers, get_config_value('platform'))
scheduler = get_from_map(request.args, 'scheduler',
pipeline_schedulers, get_config_value('scheduler'))
@ -213,12 +223,12 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]:
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt)
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
params = ImageParams(model_path, provider, scheduler, prompt,
params = ImageParams(model_path, scheduler, prompt,
negative_prompt, cfg, steps, seed)
size = Size(width, height)
return (params, size)
return (device, params, size)
def border_from_request() -> Border:

View File

@ -7,6 +7,9 @@ from .chain import (
upscale_resrgan,
ChainPipeline,
)
from .device_pool import (
JobContext,
)
from .params import (
ImageParams,
SizeChart,
@ -21,7 +24,8 @@ logger = getLogger(__name__)
def run_upscale_correction(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
image: Image.Image,
@ -51,4 +55,4 @@ def run_upscale_correction(
outscale=1)
chain.append((correct_gfpgan, stage, None))
return chain(ctx, params, image, prompt=params.prompt, upscale=upscale)
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)

View File

@ -12,6 +12,7 @@
"type": "upscale-outpaint",
"params": {
"border": 256,
"model": "stable-diffusion-onnx-v1-inpainting",
"prompt": "a magical wizard in a robe fighting a dragon"
}
},