reuse config from loaded UNet
This commit is contained in:
parent
0175d7edcf
commit
acde899559
|
@ -13,7 +13,7 @@ from logging import getLogger
|
|||
from os import mkdir, path
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
|
@ -21,7 +21,6 @@ from diffusers import (
|
|||
OnnxRuntimeModel,
|
||||
OnnxStableDiffusionPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionPipeline,
|
||||
|
@ -60,11 +59,15 @@ def convert_diffusion_diffusers_cnet(
|
|||
unet_sample_size,
|
||||
num_tokens,
|
||||
text_hidden_size,
|
||||
unet: Optional[Any] = None,
|
||||
):
|
||||
# CNet
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
|
||||
device=device, dtype=dtype
|
||||
)
|
||||
if unet is not None:
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
|
||||
else:
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
|
||||
|
||||
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
|
||||
|
||||
if is_torch_2_0:
|
||||
pipe_cnet.set_attn_processor(AttnProcessor())
|
||||
|
@ -350,9 +353,8 @@ def convert_diffusion_diffusers(
|
|||
convert_attribute=False,
|
||||
)
|
||||
|
||||
del pipeline.unet
|
||||
|
||||
if not single_vae:
|
||||
# if converting only the CNet, the rest of the model has already been converted
|
||||
convert_diffusion_diffusers_cnet(
|
||||
conversion,
|
||||
source,
|
||||
|
@ -363,10 +365,13 @@ def convert_diffusion_diffusers(
|
|||
unet_sample_size,
|
||||
num_tokens,
|
||||
text_hidden_size,
|
||||
unet=pipeline.unet,
|
||||
)
|
||||
else:
|
||||
logger.debug("skipping CNet for single-VAE model")
|
||||
|
||||
del pipeline.unet
|
||||
|
||||
if cnet_only:
|
||||
logger.info("done converting CNet")
|
||||
return (True, dest_path)
|
||||
|
|
|
@ -2,12 +2,13 @@
|
|||
numpy>=1.20,<1.24
|
||||
protobuf<4,>=3.20.2
|
||||
|
||||
### AI packages ###
|
||||
### SD packages ###
|
||||
accelerate
|
||||
coloredlogs
|
||||
controlnet_aux
|
||||
diffusers
|
||||
mediapipe
|
||||
omegaconf
|
||||
onnx
|
||||
# onnxruntime has many platform-specific packages
|
||||
safetensors
|
||||
|
|
Loading…
Reference in New Issue