1
0
Fork 0

add env var for torch extraction

This commit is contained in:
Sean Sube 2023-05-20 20:06:54 -05:00
parent 9c28154dee
commit 9fa1f0387b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 7 additions and 3 deletions

View File

@ -320,6 +320,7 @@ def convert_diffusion_diffusers(
if format == "safetensors": if format == "safetensors":
pipe_args["from_safetensors"] = True pipe_args["from_safetensors"] = True
torch_source = None
if path.exists(source) and path.isdir(source): if path.exists(source) and path.isdir(source):
logger.debug("loading pipeline from diffusers directory: %s", source) logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained( pipeline = pipe_class.from_pretrained(
@ -470,10 +471,11 @@ def convert_diffusion_diffusers(
run_gc() run_gc()
if conversion.control and not single_vae and not conversion.share_unet: if conversion.control and not single_vae and not conversion.share_unet:
logger.info("loading and converting CNet") cnet_source = torch_source or source
logger.info("loading and converting CNet from %s", cnet_source)
cnet_path = convert_diffusion_diffusers_cnet( cnet_path = convert_diffusion_diffusers_cnet(
conversion, conversion,
source, cnet_source,
device, device,
output_path, output_path,
dtype, dtype,
@ -485,7 +487,6 @@ def convert_diffusion_diffusers(
v2=v2, v2=v2,
) )
if cnet_path is not None: if cnet_path is not None:
collate_cnet(cnet_path) collate_cnet(cnet_path)

View File

@ -45,11 +45,13 @@ class ConversionContext(ServerContext):
control: bool = True, control: bool = True,
reload: bool = True, reload: bool = True,
share_unet: bool = True, share_unet: bool = True,
extract: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs) super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
self.control = control self.control = control
self.extract = extract
self.half = half self.half = half
self.opset = opset self.opset = opset
self.prune = prune or [] self.prune = prune or []
@ -68,6 +70,7 @@ class ConversionContext(ServerContext):
def from_environ(cls): def from_environ(cls):
context = super().from_environ() context = super().from_environ()
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True) context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False)
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True) context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True) context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET)) context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))