fix(api): pass job context and device to upscaling
This commit is contained in:
parent
8a81e8b810
commit
3637f642c6
|
@ -69,7 +69,7 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
server, StageParams(), params, image, upscale=upscale)
|
||||
job, server, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
@ -109,7 +109,7 @@ def run_img2img_pipeline(
|
|||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
server, StageParams(), params, image, upscale=upscale)
|
||||
job, server, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
size = Size(*source_image.size)
|
||||
|
@ -141,7 +141,6 @@ def run_inpaint_pipeline(
|
|||
# progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
|
||||
# TODO: pass device, progress
|
||||
image = upscale_outpaint(
|
||||
server,
|
||||
stage,
|
||||
|
@ -162,7 +161,7 @@ def run_inpaint_pipeline(
|
|||
'output image size does not match source, skipping post-blend')
|
||||
|
||||
image = run_upscale_correction(
|
||||
server, stage, params, image, upscale=upscale)
|
||||
job, server, stage, params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||
|
@ -186,9 +185,8 @@ def run_upscale_pipeline(
|
|||
# progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
|
||||
# TODO: pass device, progress
|
||||
image = run_upscale_correction(
|
||||
server, stage, params, source_image, upscale=upscale)
|
||||
job, server, stage, params, source_image, upscale=upscale)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
|
|
@ -41,6 +41,7 @@ from .chain import (
|
|||
ChainPipeline,
|
||||
)
|
||||
from .device_pool import (
|
||||
DeviceParams,
|
||||
DevicePoolExecutor,
|
||||
)
|
||||
from .diffusion.run import (
|
||||
|
@ -168,14 +169,23 @@ def url_from_rule(rule) -> str:
|
|||
return url_for(rule.endpoint, **options)
|
||||
|
||||
|
||||
def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
||||
def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
# platform stuff
|
||||
device_name = request.args.get('platform', available_platforms[0].device)
|
||||
device = None
|
||||
|
||||
for platform in available_platforms:
|
||||
if platform.device == device_name:
|
||||
device = available_platforms[0]
|
||||
|
||||
if device is None:
|
||||
raise Exception('unknown device')
|
||||
|
||||
# pipeline stuff
|
||||
model = get_not_empty(request.args, 'model', get_config_value('model'))
|
||||
model_path = get_model_path(model)
|
||||
provider = get_from_map(request.args, 'platform',
|
||||
platform_providers, get_config_value('platform'))
|
||||
scheduler = get_from_map(request.args, 'scheduler',
|
||||
pipeline_schedulers, get_config_value('scheduler'))
|
||||
|
||||
|
@ -213,12 +223,12 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
|||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
|
||||
user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt)
|
||||
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
|
||||
|
||||
params = ImageParams(model_path, provider, scheduler, prompt,
|
||||
params = ImageParams(model_path, scheduler, prompt,
|
||||
negative_prompt, cfg, steps, seed)
|
||||
size = Size(width, height)
|
||||
return (params, size)
|
||||
return (device, params, size)
|
||||
|
||||
|
||||
def border_from_request() -> Border:
|
||||
|
|
|
@ -7,6 +7,9 @@ from .chain import (
|
|||
upscale_resrgan,
|
||||
ChainPipeline,
|
||||
)
|
||||
from .device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from .params import (
|
||||
ImageParams,
|
||||
SizeChart,
|
||||
|
@ -21,7 +24,8 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def run_upscale_correction(
|
||||
ctx: ServerContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
image: Image.Image,
|
||||
|
@ -51,4 +55,4 @@ def run_upscale_correction(
|
|||
outscale=1)
|
||||
chain.append((correct_gfpgan, stage, None))
|
||||
|
||||
return chain(ctx, params, image, prompt=params.prompt, upscale=upscale)
|
||||
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
"type": "upscale-outpaint",
|
||||
"params": {
|
||||
"border": 256,
|
||||
"model": "stable-diffusion-onnx-v1-inpainting",
|
||||
"prompt": "a magical wizard in a robe fighting a dragon"
|
||||
}
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue