fix(api): run GC during diffusers conversion, add flag to skip ControlNet (#369)
This commit is contained in:
parent
e96bd0fca4
commit
e2035c3fbf
|
@ -35,7 +35,7 @@ from ...diffusers.load import optimize_pipeline
|
||||||
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||||
from ...models.cnet import UNet2DConditionModel_CNet
|
from ...models.cnet import UNet2DConditionModel_CNet
|
||||||
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
|
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export, run_gc
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -221,6 +221,9 @@ def convert_diffusion_diffusers_cnet(
|
||||||
external_data=True, # UNet is > 2GB, so the weights need to be split
|
external_data=True, # UNet is > 2GB, so the weights need to be split
|
||||||
v2=v2,
|
v2=v2,
|
||||||
)
|
)
|
||||||
|
del pipe_cnet
|
||||||
|
run_gc()
|
||||||
|
|
||||||
cnet_model_path = str(cnet_path.absolute().as_posix())
|
cnet_model_path = str(cnet_path.absolute().as_posix())
|
||||||
cnet_dir = path.dirname(cnet_model_path)
|
cnet_dir = path.dirname(cnet_model_path)
|
||||||
cnet = load_model(cnet_model_path)
|
cnet = load_model(cnet_model_path)
|
||||||
|
@ -238,7 +241,8 @@ def convert_diffusion_diffusers_cnet(
|
||||||
location=ONNX_WEIGHTS,
|
location=ONNX_WEIGHTS,
|
||||||
convert_attribute=False,
|
convert_attribute=False,
|
||||||
)
|
)
|
||||||
del pipe_cnet
|
del cnet
|
||||||
|
run_gc()
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -365,6 +369,7 @@ def convert_diffusion_diffusers(
|
||||||
)
|
)
|
||||||
|
|
||||||
del pipeline.text_encoder
|
del pipeline.text_encoder
|
||||||
|
run_gc()
|
||||||
|
|
||||||
# UNET
|
# UNET
|
||||||
logger.debug("UNET config: %s", pipeline.unet.config)
|
logger.debug("UNET config: %s", pipeline.unet.config)
|
||||||
|
@ -427,7 +432,7 @@ def convert_diffusion_diffusers(
|
||||||
convert_attribute=False,
|
convert_attribute=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not single_vae:
|
if not single_vae or not conversion.control:
|
||||||
# if converting only the CNet, the rest of the model has already been converted
|
# if converting only the CNet, the rest of the model has already been converted
|
||||||
convert_diffusion_diffusers_cnet(
|
convert_diffusion_diffusers_cnet(
|
||||||
conversion,
|
conversion,
|
||||||
|
@ -446,6 +451,7 @@ def convert_diffusion_diffusers(
|
||||||
logger.debug("skipping CNet for single-VAE model")
|
logger.debug("skipping CNet for single-VAE model")
|
||||||
|
|
||||||
del pipeline.unet
|
del pipeline.unet
|
||||||
|
run_gc()
|
||||||
|
|
||||||
if cnet_only:
|
if cnet_only:
|
||||||
logger.info("done converting CNet")
|
logger.info("done converting CNet")
|
||||||
|
@ -533,6 +539,7 @@ def convert_diffusion_diffusers(
|
||||||
)
|
)
|
||||||
|
|
||||||
del pipeline.vae
|
del pipeline.vae
|
||||||
|
run_gc()
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
|
onnx_pipeline = OnnxStableDiffusionUpscalePipeline(
|
||||||
|
@ -563,6 +570,7 @@ def convert_diffusion_diffusers(
|
||||||
|
|
||||||
del pipeline
|
del pipeline
|
||||||
del onnx_pipeline
|
del onnx_pipeline
|
||||||
|
run_gc()
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
_ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
_ = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
|
|
@ -35,14 +35,16 @@ class ConversionContext(ServerContext):
|
||||||
model_path: Optional[str] = None,
|
model_path: Optional[str] = None,
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
half: Optional[bool] = False,
|
half: bool = False,
|
||||||
opset: Optional[int] = None,
|
opset: Optional[int] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
prune: Optional[List[str]] = None,
|
prune: Optional[List[str]] = None,
|
||||||
|
control: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
|
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
|
||||||
|
|
||||||
|
self.control = control
|
||||||
self.half = half
|
self.half = half
|
||||||
self.opset = opset
|
self.opset = opset
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
|
@ -75,6 +75,7 @@ def update_extra_models(server: ServerContext):
|
||||||
conversion_lock = True
|
conversion_lock = True
|
||||||
|
|
||||||
from onnx_web.convert.__main__ import main as convert
|
from onnx_web.convert.__main__ import main as convert
|
||||||
|
|
||||||
convert(
|
convert(
|
||||||
args=[
|
args=[
|
||||||
"--correction",
|
"--correction",
|
||||||
|
|
Loading…
Reference in New Issue