1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-03-07 08:02:53 -06:00
parent 6e8d51b9fa
commit 9d9bd1a639
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 34 additions and 13 deletions

View File

@ -147,7 +147,14 @@ def load_pipeline(
lpw: bool, lpw: bool,
inversion: Optional[str], inversion: Optional[str],
): ):
pipe_key = (pipeline.__name__, model, device.device, device.provider, lpw, inversion) pipe_key = (
pipeline.__name__,
model,
device.device,
device.provider,
lpw,
inversion,
)
scheduler_key = (scheduler_name, model) scheduler_key = (scheduler_name, model)
scheduler_type = get_pipeline_schedulers()[scheduler_name] scheduler_type = get_pipeline_schedulers()[scheduler_name]

View File

@ -0,0 +1,2 @@
def expand_prompt(prompt: str) -> str:
return prompt

View File

@ -13,10 +13,8 @@ class ModelCache:
self.limit = limit self.limit = limit
def drop(self, tag: str, key: Any) -> int: def drop(self, tag: str, key: Any) -> int:
logger.debug("dropping item from cache: %s", tag) logger.debug("dropping item from cache: %s %s", tag, key)
removed = [ removed = [model for model in self.cache if model[0] == tag and model[1] == key]
model for model in self.cache if model[0] == tag and model[1] == key
]
for item in removed: for item in removed:
self.cache.remove(item) self.cache.remove(item)
@ -59,4 +57,4 @@ class ModelCache:
@property @property
def size(self): def size(self):
return len(self.cache) return len(self.cache)

View File

@ -285,7 +285,9 @@ class DevicePoolExecutor:
worker.join(self.join_timeout) worker.join(self.join_timeout)
if worker.is_alive(): if worker.is_alive():
logger.warning( logger.warning(
"worker %s for device %s could not be stopped in time", worker.pid, device "worker %s for device %s could not be stopped in time",
worker.pid,
device,
) )
self.leaking.append((device, worker)) self.leaking.append((device, worker))
else: else:
@ -301,11 +303,15 @@ class DevicePoolExecutor:
if len(self.leaking) > 0: if len(self.leaking) > 0:
logger.warning("cleaning up %s leaking workers", len(self.leaking)) logger.warning("cleaning up %s leaking workers", len(self.leaking))
for device, worker in self.leaking: for device, worker in self.leaking:
logger.debug("shutting down worker %s for device %s", worker.pid, device) logger.debug(
"shutting down worker %s for device %s", worker.pid, device
)
worker.join(self.join_timeout) worker.join(self.join_timeout)
if worker.is_alive(): if worker.is_alive():
logger.error( logger.error(
"leaking worker %s for device %s could not be shut down", worker.pid, device "leaking worker %s for device %s could not be shut down",
worker.pid,
device,
) )
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()] self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
@ -328,7 +334,9 @@ class DevicePoolExecutor:
worker.join(self.join_timeout) worker.join(self.join_timeout)
if worker.is_alive(): if worker.is_alive():
logger.warning( logger.warning(
"worker %s for device %s could not be recycled in time", worker.pid, device "worker %s for device %s could not be recycled in time",
worker.pid,
device,
) )
self.leaking.append((device, worker)) self.leaking.append((device, worker))
else: else:
@ -338,7 +346,9 @@ class DevicePoolExecutor:
needs_restart.append(device) needs_restart.append(device)
else: else:
logger.debug( logger.debug(
"worker %s for device %s does not need to be recycled", worker.pid, device "worker %s for device %s does not need to be recycled",
worker.pid,
device,
) )
logger.debug("starting new workers") logger.debug("starting new workers")

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from queue import Empty
from os import getpid from os import getpid
from queue import Empty
from sys import exit from sys import exit
from traceback import format_exception from traceback import format_exception
@ -32,7 +32,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
while True: while True:
try: try:
if not context.is_current(): if not context.is_current():
logger.warning("worker %s has been replaced by %s, exiting", getpid(), context.get_current()) logger.warning(
"worker %s has been replaced by %s, exiting",
getpid(),
context.get_current(),
)
exit(EXIT_REPLACED) exit(EXIT_REPLACED)
name, fn, args, kwargs = context.pending.get(timeout=1.0) name, fn, args, kwargs = context.pending.get(timeout=1.0)