From ea3b065d8034af2c03ba70b2b839336e1ba08688 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sat, 11 Feb 2023 15:41:42 -0600
Subject: [PATCH] feat(api): add option to use any available platform

---
 api/onnx_web/device_pool.py | 12 ++++++++---
 api/onnx_web/serve.py       | 40 ++++++++++++++++++++++++-------------
 gui/src/strings.ts          |  2 ++
 3 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py
index 42faddee..0f22bdb6 100644
--- a/api/onnx_web/device_pool.py
+++ b/api/onnx_web/device_pool.py
@@ -156,7 +156,13 @@ class DevicePoolExecutor:
         logger.warn("checking status for unknown key: %s", key)
         return (None, 0)
 
-    def get_next_device(self):
+    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
+        # respect overrides if possible
+        if needs_device is not None:
+            for i in self.devices:
+                if self.devices[i].device == needs_device.device:
+                    return i
+
         # use the first/default device if there are no jobs
         if len(self.jobs) == 0:
             return 0
@@ -179,8 +185,8 @@ class DevicePoolExecutor:
     def prune(self):
         self.jobs[:] = [job for job in self.jobs if job.future.done()]
 
-    def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
-        device = self.get_next_device()
+    def submit(self, key: str, fn: Callable[..., None], /, *args, needs_device: Optional[DeviceParams] = None, **kwargs) -> None:
+        device = self.get_next_device(needs_device=needs_device)
         logger.info("assigning job %s to device %s", key, device)
         context = JobContext(key, self.devices, device_index=device)
 
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 58eeb1c1..07f60064 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -157,16 +157,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
     user = request.remote_addr
 
     # platform stuff
-    device_name = request.args.get("platform", available_platforms[0].device)
     device = None
+    device_name = request.args.get("platform")
 
-    for platform in available_platforms:
-        if platform.device == device_name:
-            device = available_platforms[0]
-
-    if device is None:
-        logger.warn("unknown platform: %s", device_name)
-        device = available_platforms[0]
+    if device_name is not None and device_name != "any":
+        for platform in available_platforms:
+            if platform.device == device_name:
+                device = available_platforms[0]
 
     # pipeline stuff
     lpw = get_not_empty(request.args, "lpw", "false") == "true"
@@ -223,7 +220,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
         steps,
         scheduler.__name__,
         model_path,
-        device.provider,
+        device or "any device",
         width,
         height,
         cfg,
@@ -368,16 +365,21 @@ def load_platforms():
         )
 
     # make sure CPU is last on the list
-    def cpu_last(a: DeviceParams, b: DeviceParams):
-        if a.device == "cpu" and b.device == "cpu":
+    def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
+        if a.device == b.device:
             return 0
 
+        # any should be first, if it's available
+        if a.device == "any":
+            return -1
+
+        # cpu should be last, if it's available
         if a.device == "cpu":
             return 1
 
         return -1
 
-    available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
+    available_platforms = sorted(available_platforms, key=cmp_to_key(any_first_cpu_last))
 
     logger.info(
         "available acceleration platforms: %s",
@@ -521,6 +523,7 @@ def img2img():
         upscale,
         source_image,
         strength,
+        needs_device=device,
     )
 
     return jsonify(json_params(output, params, size, upscale=upscale))
@@ -535,7 +538,14 @@ def txt2img():
     logger.info("txt2img job queued for: %s", output)
 
     executor.submit(
-        output, run_txt2img_pipeline, context, params, size, output, upscale
+        output,
+        run_txt2img_pipeline,
+        context,
+        params,
+        size,
+        output,
+        upscale,
+        needs_device=device,
     )
 
     return jsonify(json_params(output, params, size, upscale=upscale))
@@ -605,6 +615,7 @@ def inpaint():
         mask_filter,
         strength,
         fill_color,
+        needs_device=device,
     )
 
     return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
@@ -634,6 +645,7 @@ def upscale():
         output,
         upscale,
         source_image,
+        needs_device=device,
     )
 
     return jsonify(json_params(output, params, size, upscale=upscale))
@@ -711,7 +723,7 @@ def chain():
     # build and run chain pipeline
     empty_source = Image.new("RGB", (size.width, size.height))
     executor.submit(
-        output, pipeline, context, params, empty_source, output=output, size=size
+        output, pipeline, context, params, empty_source, output=output, size=size, needs_device=device
     )
 
     return jsonify(json_params(output, params, size))
diff --git a/gui/src/strings.ts b/gui/src/strings.ts
index d0fababb..79c7e689 100644
--- a/gui/src/strings.ts
+++ b/gui/src/strings.ts
@@ -32,6 +32,8 @@ export const MODEL_LABELS = {
 
 export const PLATFORM_LABELS: Record<string, string> = {
   amd: 'AMD GPU',
+  // eslint-disable-next-line id-blacklist
+  any: 'Any Platform',
   cpu: 'CPU',
   cuda: 'CUDA',
   directml: 'DirectML',
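
Below is a minimal, self-contained sketch of the needs_device override this patch adds to DevicePoolExecutor.get_next_device. DeviceParams and DevicePoolSketch are simplified stand-ins rather than the real onnx_web classes: the pool is assumed to hold a plain list of devices (hence the index loop), and the idle/load-balancing fallback is reduced to returning the first device.

# Sketch only: DeviceParams and DevicePoolSketch are simplified stand-ins,
# not the real onnx_web classes; the fallback logic is an assumption.
from typing import List, Optional


class DeviceParams:
    def __init__(self, device: str, provider: str):
        self.device = device
        self.provider = provider


class DevicePoolSketch:
    def __init__(self, devices: List[DeviceParams]):
        self.devices = devices

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        # respect the requested device when it exists in the pool
        if needs_device is not None:
            for i in range(len(self.devices)):
                if self.devices[i].device == needs_device.device:
                    return i

        # otherwise fall back to the first/default device; the real
        # executor also considers per-device job counts here
        return 0


# requesting "any" platform means no needs_device is passed, so the pool decides
pool = DevicePoolSketch([
    DeviceParams("cpu", "CPUExecutionProvider"),
    DeviceParams("cuda", "CUDAExecutionProvider"),
])
assert pool.get_next_device() == 0
assert pool.get_next_device(DeviceParams("cuda", "CUDAExecutionProvider")) == 1

In the patched serve.py, selecting the "any" platform leaves device as None, so the endpoints pass needs_device=None to executor.submit and the pool falls back to its own scheduling; naming a specific platform pins the job to that device when it is present in the pool.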