fix(api): pass job context and device to upscaling
This commit is contained in:
parent
8a81e8b810
commit
3637f642c6
|
@ -69,7 +69,7 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
server, StageParams(), params, image, upscale=upscale)
|
job, server, StageParams(), params, image, upscale=upscale)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
@ -109,7 +109,7 @@ def run_img2img_pipeline(
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
server, StageParams(), params, image, upscale=upscale)
|
job, server, StageParams(), params, image, upscale=upscale)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
size = Size(*source_image.size)
|
size = Size(*source_image.size)
|
||||||
|
@ -141,7 +141,6 @@ def run_inpaint_pipeline(
|
||||||
# progress = job.get_progress_callback()
|
# progress = job.get_progress_callback()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
|
|
||||||
# TODO: pass device, progress
|
|
||||||
image = upscale_outpaint(
|
image = upscale_outpaint(
|
||||||
server,
|
server,
|
||||||
stage,
|
stage,
|
||||||
|
@ -162,7 +161,7 @@ def run_inpaint_pipeline(
|
||||||
'output image size does not match source, skipping post-blend')
|
'output image size does not match source, skipping post-blend')
|
||||||
|
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
server, stage, params, image, upscale=upscale)
|
job, server, stage, params, image, upscale=upscale)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||||
|
@ -186,9 +185,8 @@ def run_upscale_pipeline(
|
||||||
# progress = job.get_progress_callback()
|
# progress = job.get_progress_callback()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
|
|
||||||
# TODO: pass device, progress
|
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
server, stage, params, source_image, upscale=upscale)
|
job, server, stage, params, source_image, upscale=upscale)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale)
|
save_params(server, output, params, size, upscale=upscale)
|
||||||
|
|
|
@ -41,6 +41,7 @@ from .chain import (
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
)
|
)
|
||||||
from .device_pool import (
|
from .device_pool import (
|
||||||
|
DeviceParams,
|
||||||
DevicePoolExecutor,
|
DevicePoolExecutor,
|
||||||
)
|
)
|
||||||
from .diffusion.run import (
|
from .diffusion.run import (
|
||||||
|
@ -168,14 +169,23 @@ def url_from_rule(rule) -> str:
|
||||||
return url_for(rule.endpoint, **options)
|
return url_for(rule.endpoint, **options)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
|
# platform stuff
|
||||||
|
device_name = request.args.get('platform', available_platforms[0].device)
|
||||||
|
device = None
|
||||||
|
|
||||||
|
for platform in available_platforms:
|
||||||
|
if platform.device == device_name:
|
||||||
|
device = available_platforms[0]
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
raise Exception('unknown device')
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
model = get_not_empty(request.args, 'model', get_config_value('model'))
|
model = get_not_empty(request.args, 'model', get_config_value('model'))
|
||||||
model_path = get_model_path(model)
|
model_path = get_model_path(model)
|
||||||
provider = get_from_map(request.args, 'platform',
|
|
||||||
platform_providers, get_config_value('platform'))
|
|
||||||
scheduler = get_from_map(request.args, 'scheduler',
|
scheduler = get_from_map(request.args, 'scheduler',
|
||||||
pipeline_schedulers, get_config_value('scheduler'))
|
pipeline_schedulers, get_config_value('scheduler'))
|
||||||
|
|
||||||
|
@ -213,12 +223,12 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
||||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
||||||
user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt)
|
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
|
||||||
|
|
||||||
params = ImageParams(model_path, provider, scheduler, prompt,
|
params = ImageParams(model_path, scheduler, prompt,
|
||||||
negative_prompt, cfg, steps, seed)
|
negative_prompt, cfg, steps, seed)
|
||||||
size = Size(width, height)
|
size = Size(width, height)
|
||||||
return (params, size)
|
return (device, params, size)
|
||||||
|
|
||||||
|
|
||||||
def border_from_request() -> Border:
|
def border_from_request() -> Border:
|
||||||
|
|
|
@ -7,6 +7,9 @@ from .chain import (
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
)
|
)
|
||||||
|
from .device_pool import (
|
||||||
|
JobContext,
|
||||||
|
)
|
||||||
from .params import (
|
from .params import (
|
||||||
ImageParams,
|
ImageParams,
|
||||||
SizeChart,
|
SizeChart,
|
||||||
|
@ -21,7 +24,8 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_correction(
|
def run_upscale_correction(
|
||||||
ctx: ServerContext,
|
job: JobContext,
|
||||||
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
|
@ -51,4 +55,4 @@ def run_upscale_correction(
|
||||||
outscale=1)
|
outscale=1)
|
||||||
chain.append((correct_gfpgan, stage, None))
|
chain.append((correct_gfpgan, stage, None))
|
||||||
|
|
||||||
return chain(ctx, params, image, prompt=params.prompt, upscale=upscale)
|
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
"type": "upscale-outpaint",
|
"type": "upscale-outpaint",
|
||||||
"params": {
|
"params": {
|
||||||
"border": 256,
|
"border": 256,
|
||||||
|
"model": "stable-diffusion-onnx-v1-inpainting",
|
||||||
"prompt": "a magical wizard in a robe fighting a dragon"
|
"prompt": "a magical wizard in a robe fighting a dragon"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in New Issue