feat(api): add option to use any available platform
This commit is contained in:
parent
917f5bedd6
commit
ea3b065d80
|
@ -156,7 +156,13 @@ class DevicePoolExecutor:
|
|||
logger.warn("checking status for unknown key: %s", key)
|
||||
return (None, 0)
|
||||
|
||||
def get_next_device(self):
|
||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||
# respect overrides if possible
|
||||
if needs_device is not None:
|
||||
for i in self.devices:
|
||||
if self.devices[i].device == needs_device.device:
|
||||
return i
|
||||
|
||||
# use the first/default device if there are no jobs
|
||||
if len(self.jobs) == 0:
|
||||
return 0
|
||||
|
@ -179,8 +185,8 @@ class DevicePoolExecutor:
|
|||
def prune(self):
|
||||
self.jobs[:] = [job for job in self.jobs if job.future.done()]
|
||||
|
||||
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
|
||||
device = self.get_next_device()
|
||||
def submit(self, key: str, fn: Callable[..., None], /, *args, needs_device: Optional[DeviceParams] = None, **kwargs) -> None:
|
||||
device = self.get_next_device(needs_device=needs_device)
|
||||
logger.info("assigning job %s to device %s", key, device)
|
||||
|
||||
context = JobContext(key, self.devices, device_index=device)
|
||||
|
|
|
@ -157,17 +157,14 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
user = request.remote_addr
|
||||
|
||||
# platform stuff
|
||||
device_name = request.args.get("platform", available_platforms[0].device)
|
||||
device = None
|
||||
device_name = request.args.get("platform")
|
||||
|
||||
if device_name is not None and device_name != "any":
|
||||
for platform in available_platforms:
|
||||
if platform.device == device_name:
|
||||
device = available_platforms[0]
|
||||
|
||||
if device is None:
|
||||
logger.warn("unknown platform: %s", device_name)
|
||||
device = available_platforms[0]
|
||||
|
||||
# pipeline stuff
|
||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||
|
@ -223,7 +220,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
|
|||
steps,
|
||||
scheduler.__name__,
|
||||
model_path,
|
||||
device.provider,
|
||||
device or "any device",
|
||||
width,
|
||||
height,
|
||||
cfg,
|
||||
|
@ -368,16 +365,21 @@ def load_platforms():
|
|||
)
|
||||
|
||||
# make sure CPU is last on the list
|
||||
def cpu_last(a: DeviceParams, b: DeviceParams):
|
||||
if a.device == "cpu" and b.device == "cpu":
|
||||
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
|
||||
if a.device == b.device:
|
||||
return 0
|
||||
|
||||
# any should be first, if it's available
|
||||
if a.device == "any":
|
||||
return -1
|
||||
|
||||
# cpu should be last, if it's available
|
||||
if a.device == "cpu":
|
||||
return 1
|
||||
|
||||
return -1
|
||||
|
||||
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
|
||||
available_platforms = sorted(available_platforms, key=cmp_to_key(any_first_cpu_last))
|
||||
|
||||
logger.info(
|
||||
"available acceleration platforms: %s",
|
||||
|
@ -521,6 +523,7 @@ def img2img():
|
|||
upscale,
|
||||
source_image,
|
||||
strength,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
@ -535,7 +538,14 @@ def txt2img():
|
|||
logger.info("txt2img job queued for: %s", output)
|
||||
|
||||
executor.submit(
|
||||
output, run_txt2img_pipeline, context, params, size, output, upscale
|
||||
output,
|
||||
run_txt2img_pipeline,
|
||||
context,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
@ -605,6 +615,7 @@ def inpaint():
|
|||
mask_filter,
|
||||
strength,
|
||||
fill_color,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
|
||||
|
@ -634,6 +645,7 @@ def upscale():
|
|||
output,
|
||||
upscale,
|
||||
source_image,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
@ -711,7 +723,7 @@ def chain():
|
|||
# build and run chain pipeline
|
||||
empty_source = Image.new("RGB", (size.width, size.height))
|
||||
executor.submit(
|
||||
output, pipeline, context, params, empty_source, output=output, size=size
|
||||
output, pipeline, context, params, empty_source, output=output, size=size, needs_device=device
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size))
|
||||
|
|
|
@ -32,6 +32,8 @@ export const MODEL_LABELS = {
|
|||
|
||||
export const PLATFORM_LABELS: Record<string, string> = {
|
||||
amd: 'AMD GPU',
|
||||
// eslint-disable-next-line id-blacklist
|
||||
any: 'Any Platform',
|
||||
cpu: 'CPU',
|
||||
cuda: 'CUDA',
|
||||
directml: 'DirectML',
|
||||
|
|
Loading…
Reference in New Issue