1
0
Fork 0

null checks

This commit is contained in:
Sean Sube 2023-12-26 08:07:05 -06:00
parent 2b65077d82
commit eb8bd145c9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 33 additions and 26 deletions

View File

@ -269,7 +269,12 @@ def convert_model_diffusion(conversion: ConversionContext, model):
model_format = source_format(model)
pipeline = model.get("pipeline", "txt2img")
logger.trace("converting diffusion model using pipeline %s", pipeline)
converter = model_converters.get(pipeline)
if converter is None:
raise ValueError("cannot find converter for pipeline")
converted, dest = converter(
conversion,
model,
@ -500,7 +505,7 @@ def register_plugins(conversion: ConversionContext):
logger.info("loading conversion plugins")
exports = load_plugins(conversion)
for proto, client in exports.clients:
for proto, client in exports.clients.items():
try:
add_model_source(proto, client)
except Exception:

View File

@ -198,6 +198,8 @@ def expand_prompt(
negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input.input_ids.astype(np.int32)
)[0]
if negative_prompt_embeds is not None:
negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
logger.trace(
"padding negative prompt to match input: %s, %s, %s extra tokens",
@ -427,7 +429,7 @@ def parse_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) ->
wildcard = ""
if name in wildcards:
wildcard = pop_random(wildcards.get(name))
wildcard = pop_random(wildcards[name])
else:
logger.warning("unknown wildcard: %s", name)
@ -488,11 +490,11 @@ def parse_region_group(group: Tuple[str, ...]) -> Region:
top, left, bottom, right, weight, feather, prompt = group
# break down the feather section
feather_radius, *feather_edges = feather.split("_")
if len(feather_edges) == 0:
feather_radius, *feather_rest = feather.split("_")
if len(feather_rest) == 0:
feather_edges = "TLBR"
else:
feather_edges = "".join(feather_edges)
feather_edges = "".join(feather_rest)
return (
int(top),

View File

@ -536,7 +536,7 @@ def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
while True:
try:
msg = logs.get(pool.join_timeout / 2)
msg = logs.get(timeout=(pool.join_timeout / 2))
logger.debug("received logs from worker: %s", msg)
except Empty:
# logger worker should not generate more logs if it doesn't have any logs