1
0
Fork 0

make SD conversion more like SDXL

This commit is contained in:
Sean Sube 2023-12-23 21:42:36 -06:00
parent 75ac764d42
commit 2b8b59a39c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 4 additions and 7 deletions

View File

@@ -24,9 +24,6 @@ from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionUpscalePipeline,
 )
-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-    download_from_original_stable_diffusion_ckpt,
-)
 from onnx import load_model, save_model
 
 from ...constants import ONNX_MODEL, ONNX_WEIGHTS
@@ -388,6 +385,7 @@ def convert_diffusion_diffusers(
         return (False, dest_path)
 
     cache_path = fetch_model(conversion, name, source, format=format)
+    temp_path = path.join(conversion.cache_path, f"{name}-torch")
 
     pipe_class = CONVERT_PIPELINES.get(pipe_type)
     v2, pipe_args = get_model_version(
@@ -417,9 +415,9 @@ def convert_diffusion_diffusers(
             torch_source = convert_extract_checkpoint(
                 conversion,
                 cache_path,
-                f"{name}-torch",
+                temp_path,
                 is_inpainting=is_inpainting,
-                config_file=config,
+                config_file=config_path,
                 vae_file=replace_vae,
             )
             logger.debug(
@@ -434,10 +432,9 @@ def convert_diffusion_diffusers(
                 replace_vae = None
         else:
             logger.debug("loading pipeline from SD checkpoint: %s", source)
-            pipeline = download_from_original_stable_diffusion_ckpt(
+            pipeline = pipe_class.from_single_file(
                 cache_path,
                 original_config_file=config_path,
-                pipeline_class=pipe_class,
                 **pipe_args,
             ).to(device, torch_dtype=dtype)
     elif source.startswith(HuggingfaceClient.protocol):