1
0
Fork 0

feat(api): add option to use any available platform

This commit is contained in:
Sean Sube 2023-02-11 15:41:42 -06:00
parent 917f5bedd6
commit ea3b065d80
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 37 additions and 17 deletions

View File

@@ -156,7 +156,13 @@ class DevicePoolExecutor:
logger.warn("checking status for unknown key: %s", key) logger.warn("checking status for unknown key: %s", key)
return (None, 0) return (None, 0)
def get_next_device(self): def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible
if needs_device is not None:
for i in self.devices:
if self.devices[i].device == needs_device.device:
return i
# use the first/default device if there are no jobs # use the first/default device if there are no jobs
if len(self.jobs) == 0: if len(self.jobs) == 0:
return 0 return 0
@@ -179,8 +185,8 @@ class DevicePoolExecutor:
def prune(self): def prune(self):
self.jobs[:] = [job for job in self.jobs if job.future.done()] self.jobs[:] = [job for job in self.jobs if job.future.done()]
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None: def submit(self, key: str, fn: Callable[..., None], /, *args, needs_device: Optional[DeviceParams] = None, **kwargs) -> None:
device = self.get_next_device() device = self.get_next_device(needs_device=needs_device)
logger.info("assigning job %s to device %s", key, device) logger.info("assigning job %s to device %s", key, device)
context = JobContext(key, self.devices, device_index=device) context = JobContext(key, self.devices, device_index=device)

View File

@@ -157,16 +157,13 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr user = request.remote_addr
# platform stuff # platform stuff
device_name = request.args.get("platform", available_platforms[0].device)
device = None device = None
device_name = request.args.get("platform")
for platform in available_platforms: if device_name is not None and device_name != "any":
if platform.device == device_name: for platform in available_platforms:
device = available_platforms[0] if platform.device == device_name:
device = available_platforms[0]
if device is None:
logger.warn("unknown platform: %s", device_name)
device = available_platforms[0]
# pipeline stuff # pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true" lpw = get_not_empty(request.args, "lpw", "false") == "true"
@@ -223,7 +220,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
steps, steps,
scheduler.__name__, scheduler.__name__,
model_path, model_path,
device.provider, device or "any device",
width, width,
height, height,
cfg, cfg,
@@ -368,16 +365,21 @@ def load_platforms():
) )
# make sure CPU is last on the list # make sure CPU is last on the list
def cpu_last(a: DeviceParams, b: DeviceParams): def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
if a.device == "cpu" and b.device == "cpu": if a.device == b.device:
return 0 return 0
# any should be first, if it's available
if a.device == "any":
return -1
# cpu should be last, if it's available
if a.device == "cpu": if a.device == "cpu":
return 1 return 1
return -1 return -1
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last)) available_platforms = sorted(available_platforms, key=cmp_to_key(any_first_cpu_last))
logger.info( logger.info(
"available acceleration platforms: %s", "available acceleration platforms: %s",
@@ -521,6 +523,7 @@ def img2img():
upscale, upscale,
source_image, source_image,
strength, strength,
needs_device=device,
) )
return jsonify(json_params(output, params, size, upscale=upscale)) return jsonify(json_params(output, params, size, upscale=upscale))
@@ -535,7 +538,14 @@ def txt2img():
logger.info("txt2img job queued for: %s", output) logger.info("txt2img job queued for: %s", output)
executor.submit( executor.submit(
output, run_txt2img_pipeline, context, params, size, output, upscale output,
run_txt2img_pipeline,
context,
params,
size,
output,
upscale,
needs_device=device,
) )
return jsonify(json_params(output, params, size, upscale=upscale)) return jsonify(json_params(output, params, size, upscale=upscale))
@@ -605,6 +615,7 @@ def inpaint():
mask_filter, mask_filter,
strength, strength,
fill_color, fill_color,
needs_device=device,
) )
return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
@@ -634,6 +645,7 @@ def upscale():
output, output,
upscale, upscale,
source_image, source_image,
needs_device=device,
) )
return jsonify(json_params(output, params, size, upscale=upscale)) return jsonify(json_params(output, params, size, upscale=upscale))
@@ -711,7 +723,7 @@ def chain():
# build and run chain pipeline # build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height)) empty_source = Image.new("RGB", (size.width, size.height))
executor.submit( executor.submit(
output, pipeline, context, params, empty_source, output=output, size=size output, pipeline, context, params, empty_source, output=output, size=size, needs_device=device
) )
return jsonify(json_params(output, params, size)) return jsonify(json_params(output, params, size))

View File

@@ -32,6 +32,8 @@ export const MODEL_LABELS = {
export const PLATFORM_LABELS: Record<string, string> = { export const PLATFORM_LABELS: Record<string, string> = {
amd: 'AMD GPU', amd: 'AMD GPU',
// eslint-disable-next-line id-blacklist
any: 'Any Platform',
cpu: 'CPU', cpu: 'CPU',
cuda: 'CUDA', cuda: 'CUDA',
directml: 'DirectML', directml: 'DirectML',