1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-02-11 16:50:57 -06:00
parent 6d243e8f4f
commit 44c9524cd2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 29 additions and 9 deletions

View File

@ -185,9 +185,19 @@ class DevicePoolExecutor:
def prune(self): def prune(self):
self.jobs[:] = [job for job in self.jobs if job.future.done()] self.jobs[:] = [job for job in self.jobs if job.future.done()]
def submit(self, key: str, fn: Callable[..., None], /, *args, needs_device: Optional[DeviceParams] = None, **kwargs) -> None: def submit(
self,
key: str,
fn: Callable[..., None],
/,
*args,
needs_device: Optional[DeviceParams] = None,
**kwargs,
) -> None:
device = self.get_next_device(needs_device=needs_device) device = self.get_next_device(needs_device=needs_device)
logger.info("assigning job %s to device %s: %s", key, device, self.devices[device]) logger.info(
"assigning job %s to device %s: %s", key, device, self.devices[device]
)
context = JobContext(key, self.devices, device_index=device) context = JobContext(key, self.devices, device_index=device)
future = self.pool.submit(fn, context, *args, **kwargs) future = self.pool.submit(fn, context, *args, **kwargs)

View File

@ -333,7 +333,10 @@ def load_params(context: ServerContext):
config_params = yaml.safe_load(f) config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None: if "platform" in config_params and context.default_platform is not None:
logger.info("Overriding default platform from environment: %s", context.default_platform) logger.info(
"Overriding default platform from environment: %s",
context.default_platform,
)
config_platform = config_params.get("platform", {}) config_platform = config_params.get("platform", {})
config_platform["default"] = context.default_platform config_platform["default"] = context.default_platform
@ -383,7 +386,9 @@ def load_platforms(context: ServerContext):
return -1 return -1
available_platforms = sorted(available_platforms, key=cmp_to_key(any_first_cpu_last)) available_platforms = sorted(
available_platforms, key=cmp_to_key(any_first_cpu_last)
)
logger.info( logger.info(
"available acceleration platforms: %s", "available acceleration platforms: %s",
@ -728,7 +733,14 @@ def chain():
# build and run chain pipeline # build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height)) empty_source = Image.new("RGB", (size.width, size.height))
executor.submit( executor.submit(
output, pipeline, context, params, empty_source, output=output, size=size, needs_device=device output,
pipeline,
context,
params,
empty_source,
output=output,
size=size,
needs_device=device,
) )
return jsonify(json_params(output, params, size)) return jsonify(json_params(output, params, size))

View File

@ -63,10 +63,8 @@ def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False) return get_boolean(environ, "DEBUG", False)
def get_boolean( def get_boolean(args: Any, key: str, default_value: bool) -> bool:
args: Any, key: str, default_value: bool return args.get(key, str(default_value)).lower() in ("1", "t", "true", "y", "yes")
) -> bool:
return (args.get(key, str(default_value)).lower() in ('1', 't', 'true', 'y', 'yes'))
def get_and_clamp_float( def get_and_clamp_float(