diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 8bd5d1a1..b7515b98 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -269,7 +269,12 @@ def convert_model_diffusion(conversion: ConversionContext, model): model_format = source_format(model) pipeline = model.get("pipeline", "txt2img") + logger.trace("converting diffusion model using pipeline %s", pipeline) + converter = model_converters.get(pipeline) + if converter is None: + raise ValueError("cannot find converter for pipeline") + converted, dest = converter( conversion, model, @@ -500,7 +505,7 @@ def register_plugins(conversion: ConversionContext): logger.info("loading conversion plugins") exports = load_plugins(conversion) - for proto, client in exports.clients: + for proto, client in exports.clients.items(): try: add_model_source(proto, client) except Exception: diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 1c6f8ed0..d83d0f98 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -198,27 +198,29 @@ def expand_prompt( negative_prompt_embeds = self.text_encoder( input_ids=uncond_input.input_ids.astype(np.int32) )[0] - negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1] - logger.trace( - "padding negative prompt to match input: %s, %s, %s extra tokens", - tokens.input_ids.shape, - negative_prompt_embeds.shape, - negative_padding, - ) - negative_prompt_embeds = np.pad( - negative_prompt_embeds, - [(0, 0), (0, negative_padding), (0, 0)], - mode="constant", - constant_values=0, - ) - negative_prompt_embeds = np.repeat( - negative_prompt_embeds, num_images_per_prompt, axis=0 - ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + if negative_prompt_embeds is not None: + negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1] + logger.trace( + "padding negative prompt to match input: %s, %s, %s extra tokens", + tokens.input_ids.shape, + negative_prompt_embeds.shape, + negative_padding, + ) + negative_prompt_embeds = np.pad( + negative_prompt_embeds, + [(0, 0), (0, negative_padding), (0, 0)], + mode="constant", + constant_values=0, + ) + negative_prompt_embeds = np.repeat( + negative_prompt_embeds, num_images_per_prompt, axis=0 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) logger.trace("expanded prompt shape: %s", prompt_embeds.shape) return prompt_embeds @@ -427,7 +429,7 @@ def parse_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> wildcard = "" if name in wildcards: - wildcard = pop_random(wildcards.get(name)) + wildcard = pop_random(wildcards[name]) else: logger.warning("unknown wildcard: %s", name) @@ -488,11 +490,11 @@ def parse_region_group(group: Tuple[str, ...]) -> Region: top, left, bottom, right, weight, feather, prompt = group # break down the feather section - feather_radius, *feather_edges = feather.split("_") - if len(feather_edges) == 0: + feather_radius, *feather_rest = feather.split("_") + if len(feather_rest) == 0: feather_edges = "TLBR" else: - feather_edges = "".join(feather_edges) + feather_edges = "".join(feather_rest) return ( int(top), diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 3b0d32a8..d210421e 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -536,7 +536,7 @@ def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"): while True: try: - msg = logs.get(pool.join_timeout / 2) + msg = logs.get(timeout=(pool.join_timeout / 2)) logger.debug("received logs from worker: %s", msg) except Empty: # logger worker should not generate more logs if it doesn't have any logs