test half precision on all diffusion models
This commit is contained in:
parent
1f9efb433a
commit
565873b3ae
|
@ -21,7 +21,9 @@ python3 -m onnx_web.convert \
|
||||||
--diffusion \
|
--diffusion \
|
||||||
--upscaling \
|
--upscaling \
|
||||||
--correction \
|
--correction \
|
||||||
--token=${HF_TOKEN:-}
|
--extras=${ONNX_WEB_EXTRA_MODELS:-extras.json} \
|
||||||
|
--token=${HF_TOKEN:-} \
|
||||||
|
--half
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
||||||
|
|
|
@ -40,6 +40,7 @@ def load_stable_diffusion(
|
||||||
model_path,
|
model_path,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -50,6 +51,7 @@ def load_stable_diffusion(
|
||||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
|
|
|
@ -3,6 +3,7 @@ from os import path
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
|
@ -164,6 +165,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(scheduler, "to"):
|
if device is not None and hasattr(scheduler, "to"):
|
||||||
|
|
Loading…
Reference in New Issue