add env var for torch extraction
This commit is contained in:
parent
9c28154dee
commit
9fa1f0387b
|
@ -320,6 +320,7 @@ def convert_diffusion_diffusers(
|
||||||
if format == "safetensors":
|
if format == "safetensors":
|
||||||
pipe_args["from_safetensors"] = True
|
pipe_args["from_safetensors"] = True
|
||||||
|
|
||||||
|
torch_source = None
|
||||||
if path.exists(source) and path.isdir(source):
|
if path.exists(source) and path.isdir(source):
|
||||||
logger.debug("loading pipeline from diffusers directory: %s", source)
|
logger.debug("loading pipeline from diffusers directory: %s", source)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
|
@ -470,10 +471,11 @@ def convert_diffusion_diffusers(
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
if conversion.control and not single_vae and not conversion.share_unet:
|
if conversion.control and not single_vae and not conversion.share_unet:
|
||||||
logger.info("loading and converting CNet")
|
cnet_source = torch_source or source
|
||||||
|
logger.info("loading and converting CNet from %s", cnet_source)
|
||||||
cnet_path = convert_diffusion_diffusers_cnet(
|
cnet_path = convert_diffusion_diffusers_cnet(
|
||||||
conversion,
|
conversion,
|
||||||
source,
|
cnet_source,
|
||||||
device,
|
device,
|
||||||
output_path,
|
output_path,
|
||||||
dtype,
|
dtype,
|
||||||
|
@ -485,7 +487,6 @@ def convert_diffusion_diffusers(
|
||||||
v2=v2,
|
v2=v2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if cnet_path is not None:
|
if cnet_path is not None:
|
||||||
collate_cnet(cnet_path)
|
collate_cnet(cnet_path)
|
||||||
|
|
||||||
|
|
|
@ -45,11 +45,13 @@ class ConversionContext(ServerContext):
|
||||||
control: bool = True,
|
control: bool = True,
|
||||||
reload: bool = True,
|
reload: bool = True,
|
||||||
share_unet: bool = True,
|
share_unet: bool = True,
|
||||||
|
extract: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
|
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
|
||||||
|
|
||||||
self.control = control
|
self.control = control
|
||||||
|
self.extract = extract
|
||||||
self.half = half
|
self.half = half
|
||||||
self.opset = opset
|
self.opset = opset
|
||||||
self.prune = prune or []
|
self.prune = prune or []
|
||||||
|
@ -68,6 +70,7 @@ class ConversionContext(ServerContext):
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
context = super().from_environ()
|
context = super().from_environ()
|
||||||
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
|
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
|
||||||
|
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False)
|
||||||
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
|
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
|
||||||
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
|
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
|
||||||
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
|
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
|
||||||
|
|
Loading…
Reference in New Issue