From 9d9bd1a639b20ad20b31b086a5828241ba2f16d8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 7 Mar 2023 08:02:53 -0600 Subject: [PATCH] apply lint --- api/onnx_web/diffusers/load.py | 9 ++++++++- api/onnx_web/diffusers/utils.py | 2 ++ api/onnx_web/server/model_cache.py | 8 +++----- api/onnx_web/worker/pool.py | 20 +++++++++++++++----- api/onnx_web/worker/worker.py | 8 ++++++-- 5 files changed, 34 insertions(+), 13 deletions(-) create mode 100644 api/onnx_web/diffusers/utils.py diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 2c8db970..29ff91bb 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -147,7 +147,14 @@ def load_pipeline( lpw: bool, inversion: Optional[str], ): - pipe_key = (pipeline.__name__, model, device.device, device.provider, lpw, inversion) + pipe_key = ( + pipeline.__name__, + model, + device.device, + device.provider, + lpw, + inversion, + ) scheduler_key = (scheduler_name, model) scheduler_type = get_pipeline_schedulers()[scheduler_name] diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py new file mode 100644 index 00000000..d5d782cd --- /dev/null +++ b/api/onnx_web/diffusers/utils.py @@ -0,0 +1,2 @@ +def expand_prompt(prompt: str) -> str: + return prompt diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 00717223..d757a94f 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -13,10 +13,8 @@ class ModelCache: self.limit = limit def drop(self, tag: str, key: Any) -> int: - logger.debug("dropping item from cache: %s", tag) - removed = [ - model for model in self.cache if model[0] == tag and model[1] == key - ] + logger.debug("dropping item from cache: %s %s", tag, key) + removed = [model for model in self.cache if model[0] == tag and model[1] == key] for item in removed: self.cache.remove(item) @@ -59,4 +57,4 @@ class ModelCache: @property def size(self): - return len(self.cache) \ No newline at end of file + return len(self.cache) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 3a7cf0b0..c982f6df 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -285,7 +285,9 @@ class DevicePoolExecutor: worker.join(self.join_timeout) if worker.is_alive(): logger.warning( - "worker %s for device %s could not be stopped in time", worker.pid, device + "worker %s for device %s could not be stopped in time", + worker.pid, + device, ) self.leaking.append((device, worker)) else: @@ -301,11 +303,15 @@ class DevicePoolExecutor: if len(self.leaking) > 0: logger.warning("cleaning up %s leaking workers", len(self.leaking)) for device, worker in self.leaking: - logger.debug("shutting down worker %s for device %s", worker.pid, device) + logger.debug( + "shutting down worker %s for device %s", worker.pid, device + ) worker.join(self.join_timeout) if worker.is_alive(): logger.error( - "leaking worker %s for device %s could not be shut down", worker.pid, device + "leaking worker %s for device %s could not be shut down", + worker.pid, + device, ) self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()] @@ -328,7 +334,9 @@ class DevicePoolExecutor: worker.join(self.join_timeout) if worker.is_alive(): logger.warning( - "worker %s for device %s could not be recycled in time", worker.pid, device + "worker %s for device %s could not be recycled in time", + worker.pid, + device, ) self.leaking.append((device, worker)) else: @@ -338,7 +346,9 @@ class DevicePoolExecutor: needs_restart.append(device) else: logger.debug( - "worker %s for device %s does not need to be recycled", worker.pid, device + "worker %s for device %s does not need to be recycled", + worker.pid, + device, ) logger.debug("starting new workers") diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 69d9a120..87c33f67 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,6 +1,6 @@ from logging import getLogger -from queue import Empty from os import getpid +from queue import Empty from sys import exit from traceback import format_exception @@ -32,7 +32,11 @@ def worker_main(context: WorkerContext, server: ServerContext): while True: try: if not context.is_current(): - logger.warning("worker %s has been replaced by %s, exiting", getpid(), context.get_current()) + logger.warning( + "worker %s has been replaced by %s, exiting", + getpid(), + context.get_current(), + ) exit(EXIT_REPLACED) name, fn, args, kwargs = context.pending.get(timeout=1.0)