feat(api): add option to reload CNet for conversion
This commit is contained in:
parent
20107f559a
commit
9c28154dee
|
@ -450,8 +450,8 @@ def convert_diffusion_diffusers(
|
|||
)
|
||||
|
||||
cnet_path = None
|
||||
if conversion.control and not single_vae:
|
||||
# if converting only the CNet, the rest of the model has already been converted
|
||||
if conversion.control and not single_vae and conversion.share_unet:
|
||||
logger.debug("converting CNet from loaded UNet")
|
||||
cnet_path = convert_diffusion_diffusers_cnet(
|
||||
conversion,
|
||||
source,
|
||||
|
@ -465,12 +465,27 @@ def convert_diffusion_diffusers(
|
|||
unet=pipeline.unet,
|
||||
v2=v2,
|
||||
)
|
||||
else:
|
||||
logger.debug("skipping CNet for single-VAE model")
|
||||
|
||||
del pipeline.unet
|
||||
run_gc()
|
||||
|
||||
if conversion.control and not single_vae and not conversion.share_unet:
|
||||
logger.info("loading and converting CNet")
|
||||
cnet_path = convert_diffusion_diffusers_cnet(
|
||||
conversion,
|
||||
source,
|
||||
device,
|
||||
output_path,
|
||||
dtype,
|
||||
unet_in_channels,
|
||||
unet_sample_size,
|
||||
num_tokens,
|
||||
text_hidden_size,
|
||||
unet=None,
|
||||
v2=v2,
|
||||
)
|
||||
|
||||
|
||||
if cnet_path is not None:
|
||||
collate_cnet(cnet_path)
|
||||
|
||||
|
|
|
@ -44,6 +44,7 @@ class ConversionContext(ServerContext):
|
|||
prune: Optional[List[str]] = None,
|
||||
control: bool = True,
|
||||
reload: bool = True,
|
||||
share_unet: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
|
||||
|
@ -53,6 +54,7 @@ class ConversionContext(ServerContext):
|
|||
self.opset = opset
|
||||
self.prune = prune or []
|
||||
self.reload = reload
|
||||
self.share_unet = share_unet
|
||||
self.token = token
|
||||
|
||||
if device is not None:
|
||||
|
@ -66,8 +68,9 @@ class ConversionContext(ServerContext):
|
|||
def from_environ(cls):
|
||||
context = super().from_environ()
|
||||
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
|
||||
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
|
||||
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
|
||||
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
|
||||
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
|
||||
return context
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue