1
0
Fork 0

fix(api): include SD upscaling in diffusion prefixes

This commit is contained in:
Sean Sube 2023-12-08 22:26:19 -06:00
parent 46d9fc0dd4
commit 293a1bb184
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 15 additions and 14 deletions

View File

@ -26,6 +26,7 @@ from .utils import (
DEFAULT_OPSET, DEFAULT_OPSET,
ConversionContext, ConversionContext,
download_progress, download_progress,
fix_diffusion_name,
remove_prefix, remove_prefix,
source_format, source_format,
tuple_to_correction, tuple_to_correction,
@ -265,20 +266,6 @@ def fetch_model(
return source, False return source, False
DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-"]
def fix_diffusion_name(name: str):
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
logger.warning(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
return f"diffusion-{name}"
return name
def convert_models(conversion: ConversionContext, args, models: Models): def convert_models(conversion: ConversionContext, args, models: Models):
model_errors = [] model_errors = []

View File

@ -345,3 +345,17 @@ def onnx_export(
all_tensors_to_one_file=True, all_tensors_to_one_file=True,
location=ONNX_WEIGHTS, location=ONNX_WEIGHTS,
) )
DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"]
def fix_diffusion_name(name: str):
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
logger.warning(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
return f"diffusion-{name}"
return name