fix(api): assume inversion tokens are embeddings for now
This commit is contained in:
parent
17a7cdae56
commit
e5862d178c
|
@ -217,8 +217,8 @@ def load_pipeline(
|
||||||
|
|
||||||
text_encoder = None
|
text_encoder = None
|
||||||
if inversions is not None and len(inversions) > 0:
|
if inversions is not None and len(inversions) > 0:
|
||||||
|
logger.debug("blending Textual Inversions from %s", inversions)
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
inversion_names, inversion_weights = zip(*inversions)
|
||||||
logger.debug("blending Textual Inversions from %s", inversion_names)
|
|
||||||
|
|
||||||
inversion_models = [
|
inversion_models = [
|
||||||
path.join(server.model_path, "inversion", f"{name}.ckpt")
|
path.join(server.model_path, "inversion", f"{name}.ckpt")
|
||||||
|
@ -233,7 +233,14 @@ def load_pipeline(
|
||||||
server,
|
server,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
list(zip(inversion_models, inversion_weights, inversion_names)),
|
list(
|
||||||
|
zip(
|
||||||
|
inversion_models,
|
||||||
|
inversion_weights,
|
||||||
|
inversion_names,
|
||||||
|
[None] * len(inversion_models),
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
components["tokenizer"] = tokenizer
|
components["tokenizer"] = tokenizer
|
||||||
|
|
|
@ -23,8 +23,8 @@ class DevicePoolExecutor:
|
||||||
join_timeout: float
|
join_timeout: float
|
||||||
|
|
||||||
leaking: List[Tuple[str, Process]]
|
leaking: List[Tuple[str, Process]]
|
||||||
context: Dict[str, WorkerContext] # Device -> Context
|
context: Dict[str, WorkerContext] # Device -> Context
|
||||||
current: Dict[str, "Value[int]"] # Device -> pid
|
current: Dict[str, "Value[int]"] # Device -> pid
|
||||||
pending: Dict[str, "Queue[JobCommand]"]
|
pending: Dict[str, "Queue[JobCommand]"]
|
||||||
threads: Dict[str, Thread]
|
threads: Dict[str, Thread]
|
||||||
workers: Dict[str, Process]
|
workers: Dict[str, Process]
|
||||||
|
@ -202,7 +202,9 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
for job in self.pending_jobs:
|
for job in self.pending_jobs:
|
||||||
if job.name == key:
|
if job.name == key:
|
||||||
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key]
|
self.pending_jobs[:] = [
|
||||||
|
job for job in self.pending_jobs if job.name != key
|
||||||
|
]
|
||||||
logger.info("cancelled pending job: %s", key)
|
logger.info("cancelled pending job: %s", key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -387,7 +389,8 @@ class DevicePoolExecutor:
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
False,
|
False,
|
||||||
) for job in self.pending_jobs
|
)
|
||||||
|
for job in self.pending_jobs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
history.extend(
|
history.extend(
|
||||||
|
@ -420,7 +423,9 @@ class DevicePoolExecutor:
|
||||||
"progress update for job: %s to %s", progress.job, progress.progress
|
"progress update for job: %s to %s", progress.job, progress.progress
|
||||||
)
|
)
|
||||||
self.running_jobs[progress.job] = progress
|
self.running_jobs[progress.job] = progress
|
||||||
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != progress.job]
|
self.pending_jobs[:] = [
|
||||||
|
job for job in self.pending_jobs if job.name != progress.job
|
||||||
|
]
|
||||||
|
|
||||||
if progress.job in self.cancelled_jobs:
|
if progress.job in self.cancelled_jobs:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
Loading…
Reference in New Issue