1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-03-07 08:02:53 -06:00
parent 6e8d51b9fa
commit 9d9bd1a639
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 34 additions and 13 deletions

View File

@ -147,7 +147,14 @@ def load_pipeline(
lpw: bool,
inversion: Optional[str],
):
pipe_key = (pipeline.__name__, model, device.device, device.provider, lpw, inversion)
pipe_key = (
pipeline.__name__,
model,
device.device,
device.provider,
lpw,
inversion,
)
scheduler_key = (scheduler_name, model)
scheduler_type = get_pipeline_schedulers()[scheduler_name]

View File

@ -0,0 +1,2 @@
def expand_prompt(prompt: str) -> str:
return prompt

View File

@ -13,10 +13,8 @@ class ModelCache:
self.limit = limit
def drop(self, tag: str, key: Any) -> int:
logger.debug("dropping item from cache: %s", tag)
removed = [
model for model in self.cache if model[0] == tag and model[1] == key
]
logger.debug("dropping item from cache: %s %s", tag, key)
removed = [model for model in self.cache if model[0] == tag and model[1] == key]
for item in removed:
self.cache.remove(item)

View File

@ -285,7 +285,9 @@ class DevicePoolExecutor:
worker.join(self.join_timeout)
if worker.is_alive():
logger.warning(
"worker %s for device %s could not be stopped in time", worker.pid, device
"worker %s for device %s could not be stopped in time",
worker.pid,
device,
)
self.leaking.append((device, worker))
else:
@ -301,11 +303,15 @@ class DevicePoolExecutor:
if len(self.leaking) > 0:
logger.warning("cleaning up %s leaking workers", len(self.leaking))
for device, worker in self.leaking:
logger.debug("shutting down worker %s for device %s", worker.pid, device)
logger.debug(
"shutting down worker %s for device %s", worker.pid, device
)
worker.join(self.join_timeout)
if worker.is_alive():
logger.error(
"leaking worker %s for device %s could not be shut down", worker.pid, device
"leaking worker %s for device %s could not be shut down",
worker.pid,
device,
)
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
@ -328,7 +334,9 @@ class DevicePoolExecutor:
worker.join(self.join_timeout)
if worker.is_alive():
logger.warning(
"worker %s for device %s could not be recycled in time", worker.pid, device
"worker %s for device %s could not be recycled in time",
worker.pid,
device,
)
self.leaking.append((device, worker))
else:
@ -338,7 +346,9 @@ class DevicePoolExecutor:
needs_restart.append(device)
else:
logger.debug(
"worker %s for device %s does not need to be recycled", worker.pid, device
"worker %s for device %s does not need to be recycled",
worker.pid,
device,
)
logger.debug("starting new workers")

View File

@ -1,6 +1,6 @@
from logging import getLogger
from queue import Empty
from os import getpid
from queue import Empty
from sys import exit
from traceback import format_exception
@ -32,7 +32,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
while True:
try:
if not context.is_current():
logger.warning("worker %s has been replaced by %s, exiting", getpid(), context.get_current())
logger.warning(
"worker %s has been replaced by %s, exiting",
getpid(),
context.get_current(),
)
exit(EXIT_REPLACED)
name, fn, args, kwargs = context.pending.get(timeout=1.0)