From e5862d178cdeafa15d108adea162a668fdea0693 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 18:35:11 -0500 Subject: [PATCH] fix(api): assume inversion tokens are embeddings for now --- api/onnx_web/diffusers/load.py | 11 +++++++++-- api/onnx_web/worker/pool.py | 15 ++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 34b71672..0775f3e7 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -217,8 +217,8 @@ def load_pipeline( text_encoder = None if inversions is not None and len(inversions) > 0: + logger.debug("blending Textual Inversions from %s", inversions) inversion_names, inversion_weights = zip(*inversions) - logger.debug("blending Textual Inversions from %s", inversion_names) inversion_models = [ path.join(server.model_path, "inversion", f"{name}.ckpt") @@ -233,7 +233,14 @@ def load_pipeline( server, text_encoder, tokenizer, - list(zip(inversion_models, inversion_weights, inversion_names)), + list( + zip( + inversion_models, + inversion_weights, + inversion_names, + [None] * len(inversion_models), + ) + ), ) components["tokenizer"] = tokenizer diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index d50fede4..31bd0df1 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -23,8 +23,8 @@ class DevicePoolExecutor: join_timeout: float leaking: List[Tuple[str, Process]] - context: Dict[str, WorkerContext] # Device -> Context - current: Dict[str, "Value[int]"] # Device -> pid + context: Dict[str, WorkerContext] # Device -> Context + current: Dict[str, "Value[int]"] # Device -> pid pending: Dict[str, "Queue[JobCommand]"] threads: Dict[str, Thread] workers: Dict[str, Process] @@ -202,7 +202,9 @@ class DevicePoolExecutor: for job in self.pending_jobs: if job.name == key: - self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key] + self.pending_jobs[:] = [ + job for job in self.pending_jobs if job.name != key + ] logger.info("cancelled pending job: %s", key) return True @@ -387,7 +389,8 @@ class DevicePoolExecutor: False, False, False, - ) for job in self.pending_jobs + ) + for job in self.pending_jobs ] ) history.extend( @@ -420,7 +423,9 @@ class DevicePoolExecutor: "progress update for job: %s to %s", progress.job, progress.progress ) self.running_jobs[progress.job] = progress - self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != progress.job] + self.pending_jobs[:] = [ + job for job in self.pending_jobs if job.name != progress.job + ] if progress.job in self.cancelled_jobs: logger.debug(