test half precision on all diffusion models
This commit is contained in:
parent
1f9efb433a
commit
565873b3ae
|
@ -21,7 +21,9 @@ python3 -m onnx_web.convert \
|
|||
--diffusion \
|
||||
--upscaling \
|
||||
--correction \
|
||||
--token=${HF_TOKEN:-}
|
||||
--extras=${ONNX_WEB_EXTRA_MODELS:-extras.json} \
|
||||
--token=${HF_TOKEN:-} \
|
||||
--half
|
||||
|
||||
echo "Launching API server..."
|
||||
flask --app='onnx_web.main:run' run --host=0.0.0.0
|
||||
|
|
|
@ -40,6 +40,7 @@ def load_stable_diffusion(
|
|||
model_path,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
|
@ -50,6 +51,7 @@ def load_stable_diffusion(
|
|||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path,
|
||||
provider=device.provider,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
if not server.show_progress:
|
||||
|
|
|
@ -3,6 +3,7 @@ from os import path
|
|||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
|
@ -164,6 +165,7 @@ def load_pipeline(
|
|||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
|
|
Loading…
Reference in New Issue