diff --git a/api/launch.sh b/api/launch.sh
index 55b6ff72..59798b8f 100755
--- a/api/launch.sh
+++ b/api/launch.sh
@@ -21,7 +21,9 @@ python3 -m onnx_web.convert \
   --diffusion \
   --upscaling \
   --correction \
-  --token=${HF_TOKEN:-}
+  --extras=${ONNX_WEB_EXTRA_MODELS:-extras.json} \
+  --token=${HF_TOKEN:-} \
+  --half
 
 echo "Launching API server..."
 flask --app='onnx_web.main:run' run --host=0.0.0.0
diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index 9049758e..a24fc80f 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -40,6 +40,7 @@ def load_stable_diffusion(
             model_path,
             provider=device.ort_provider(),
             sess_options=device.sess_options(),
+            torch_dtype=torch.float16,
         )
     else:
         logger.debug(
@@ -50,6 +51,7 @@ def load_stable_diffusion(
         pipe = StableDiffusionUpscalePipeline.from_pretrained(
             model_path,
             provider=device.provider,
+            torch_dtype=torch.float16,
         )
 
     if not server.show_progress:
diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py
index e3b54510..e5c8297c 100644
--- a/api/onnx_web/diffusion/load.py
+++ b/api/onnx_web/diffusion/load.py
@@ -3,6 +3,7 @@ from os import path
 from typing import Any, Optional, Tuple
 
 import numpy as np
+import torch
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
@@ -164,6 +165,7 @@ def load_pipeline(
         provider=device.ort_provider(),
         sess_options=device.sess_options(),
         subfolder="scheduler",
+        torch_dtype=torch.float16,
     )
 
     if device is not None and hasattr(scheduler, "to"):