make SD conversion more like SDXL
This commit is contained in:
parent
75ac764d42
commit
2b8b59a39c
|
@ -24,9 +24,6 @@ from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
|
||||||
download_from_original_stable_diffusion_ckpt,
|
|
||||||
)
|
|
||||||
from onnx import load_model, save_model
|
from onnx import load_model, save_model
|
||||||
|
|
||||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
|
@ -388,6 +385,7 @@ def convert_diffusion_diffusers(
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
cache_path = fetch_model(conversion, name, source, format=format)
|
cache_path = fetch_model(conversion, name, source, format=format)
|
||||||
|
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||||
|
|
||||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||||
v2, pipe_args = get_model_version(
|
v2, pipe_args = get_model_version(
|
||||||
|
@ -417,9 +415,9 @@ def convert_diffusion_diffusers(
|
||||||
torch_source = convert_extract_checkpoint(
|
torch_source = convert_extract_checkpoint(
|
||||||
conversion,
|
conversion,
|
||||||
cache_path,
|
cache_path,
|
||||||
f"{name}-torch",
|
temp_path,
|
||||||
is_inpainting=is_inpainting,
|
is_inpainting=is_inpainting,
|
||||||
config_file=config,
|
config_file=config_path,
|
||||||
vae_file=replace_vae,
|
vae_file=replace_vae,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -434,10 +432,9 @@ def convert_diffusion_diffusers(
|
||||||
replace_vae = None
|
replace_vae = None
|
||||||
else:
|
else:
|
||||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||||
pipeline = download_from_original_stable_diffusion_ckpt(
|
pipeline = pipe_class.from_single_file(
|
||||||
cache_path,
|
cache_path,
|
||||||
original_config_file=config_path,
|
original_config_file=config_path,
|
||||||
pipeline_class=pipe_class,
|
|
||||||
**pipe_args,
|
**pipe_args,
|
||||||
).to(device, torch_dtype=dtype)
|
).to(device, torch_dtype=dtype)
|
||||||
elif source.startswith(HuggingfaceClient.protocol):
|
elif source.startswith(HuggingfaceClient.protocol):
|
||||||
|
|
Loading…
Reference in New Issue