1
0
Fork 0
onnx-web/api/onnx_web/diffusion/load.py

122 lines
3.3 KiB
Python
Raw Normal View History

from logging import getLogger
from typing import Any, Optional, Tuple
import numpy as np
2023-02-05 13:53:26 +00:00
from diffusers import DiffusionPipeline
from ..params import DeviceParams, Size
from ..utils import run_gc
logger = getLogger(__name__)
last_pipeline_instance: Any = None
last_pipeline_options: Tuple[
Optional[DiffusionPipeline],
Optional[str],
Optional[str],
Optional[str],
Optional[bool],
] = (None, None, None, None, None)
last_pipeline_scheduler: Any = None
latent_channels = 4
latent_factor = 8
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
2023-02-05 13:53:26 +00:00
"""
From https://www.travelneil.com/stable-diffusion-updates.html.
This one needs to use np.random because of the return type.
2023-02-05 13:53:26 +00:00
"""
latents_shape = (
batch,
latent_channels,
size.height // latent_factor,
size.width // latent_factor,
)
rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents
2023-02-05 13:53:26 +00:00
def get_tile_latents(
full_latents: np.ndarray, dims: Tuple[int, int, int]
) -> np.ndarray:
x, y, tile = dims
t = tile // latent_factor
x = x // latent_factor
y = y // latent_factor
xt = x + t
yt = y + t
return full_latents[:, :, y:yt, x:xt]
2023-02-05 13:53:26 +00:00
def load_pipeline(
2023-02-05 23:55:04 +00:00
pipeline: DiffusionPipeline,
model: str,
scheduler: Any,
device: DeviceParams,
lpw: bool,
2023-02-05 13:53:26 +00:00
):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
2023-02-05 23:15:37 +00:00
options = (pipeline, model, device.device, device.provider, lpw)
2023-02-05 13:53:26 +00:00
if last_pipeline_instance is not None and last_pipeline_options == options:
logger.debug("reusing existing diffusion pipeline")
pipe = last_pipeline_instance
else:
2023-02-05 13:53:26 +00:00
logger.debug("unloading previous diffusion pipeline")
last_pipeline_instance = None
last_pipeline_scheduler = None
run_gc()
2023-02-05 23:15:37 +00:00
if lpw:
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
else:
custom_pipeline = None
2023-02-05 13:53:26 +00:00
logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
2023-02-05 13:53:26 +00:00
subfolder="scheduler",
)
pipe = pipeline.from_pretrained(
model,
2023-02-05 23:15:37 +00:00
custom_pipeline=custom_pipeline,
provider=device.provider,
provider_options=device.options,
revision="onnx",
safety_checker=None,
scheduler=scheduler,
)
2023-02-05 13:53:26 +00:00
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device)
last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler
if last_pipeline_scheduler != scheduler:
2023-02-05 13:53:26 +00:00
logger.debug("loading new diffusion scheduler")
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
2023-02-05 13:53:26 +00:00
subfolder="scheduler",
)
2023-02-05 13:53:26 +00:00
if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device)
pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler
run_gc()
return pipe