2023-01-29 02:15:39 +00:00
|
|
|
from logging import getLogger
|
2023-02-06 04:48:07 +00:00
|
|
|
from typing import Any, Optional, Tuple
|
2023-01-29 02:15:39 +00:00
|
|
|
|
|
|
|
import numpy as np
|
2023-02-05 13:53:26 +00:00
|
|
|
from diffusers import DiffusionPipeline
|
|
|
|
|
|
|
|
from ..params import DeviceParams, Size
|
|
|
|
from ..utils import run_gc
|
2023-01-29 02:15:39 +00:00
|
|
|
|
|
|
|
logger = getLogger(__name__)


# Module-level cache of the most recently loaded diffusion pipeline, so that
# repeated requests with identical options can reuse it instead of reloading
# the model from disk (see load_pipeline below).
last_pipeline_instance: Any = None

# The option tuple that produced last_pipeline_instance:
# (pipeline class, model path, device name, provider name, lpw flag).
# load_pipeline compares a freshly built tuple against this to decide reuse.
last_pipeline_options: Tuple[
    Optional[DiffusionPipeline],
    Optional[str],
    Optional[str],
    Optional[str],
    Optional[bool],
] = (None, None, None, None, None)

# Scheduler instance currently attached to the cached pipeline.
last_pipeline_scheduler: Any = None
|
2023-01-29 02:15:39 +00:00
|
|
|
|
2023-02-02 04:21:22 +00:00
|
|
|
latent_channels = 4
|
|
|
|
latent_factor = 8
|
2023-01-29 02:15:39 +00:00
|
|
|
|
2023-02-02 04:21:22 +00:00
|
|
|
|
|
|
|
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
2023-02-05 13:53:26 +00:00
|
|
|
"""
|
2023-02-05 19:43:33 +00:00
|
|
|
From https://www.travelneil.com/stable-diffusion-updates.html.
|
|
|
|
This one needs to use np.random because of the return type.
|
2023-02-05 13:53:26 +00:00
|
|
|
"""
|
|
|
|
latents_shape = (
|
|
|
|
batch,
|
|
|
|
latent_channels,
|
|
|
|
size.height // latent_factor,
|
|
|
|
size.width // latent_factor,
|
|
|
|
)
|
2023-01-29 02:15:39 +00:00
|
|
|
rng = np.random.default_rng(seed)
|
|
|
|
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
|
|
|
return image_latents
|
|
|
|
|
|
|
|
|
2023-02-05 13:53:26 +00:00
|
|
|
def get_tile_latents(
    full_latents: np.ndarray, dims: Tuple[int, int, int]
) -> np.ndarray:
    """
    Slice the latent region for a single image tile out of the full latents.

    dims is (x, y, tile) in pixel space; coordinates are converted to latent
    space by dividing through latent_factor before indexing.
    """
    left, top, tile = dims
    span = tile // latent_factor
    col = left // latent_factor
    row = top // latent_factor
    return full_latents[:, :, row : row + span, col : col + span]
|
2023-01-29 16:31:22 +00:00
|
|
|
|
|
|
|
|
2023-02-05 13:53:26 +00:00
|
|
|
def load_pipeline(
    pipeline: DiffusionPipeline,
    model: str,
    scheduler: Any,
    device: DeviceParams,
    lpw: bool,
):
    """
    Load a diffusion pipeline for the given options, reusing the cached one
    when the (pipeline, model, device, provider, lpw) options are unchanged.

    Mutates the module-level cache globals. `scheduler` is a scheduler class
    with a `from_pretrained` constructor; the loaded instance is attached to
    the returned pipeline.
    """
    global last_pipeline_instance
    global last_pipeline_scheduler
    global last_pipeline_options

    # Cache key: any change in these options forces a full reload.
    options = (pipeline, model, device.device, device.provider, lpw)
    if last_pipeline_instance is not None and last_pipeline_options == options:
        logger.debug("reusing existing diffusion pipeline")
        pipe = last_pipeline_instance
    else:
        logger.debug("unloading previous diffusion pipeline")
        # Drop cached references before loading so the previous pipeline can
        # be garbage-collected and its memory reclaimed first.
        last_pipeline_instance = None
        last_pipeline_scheduler = None
        run_gc()

        # Long-prompt-weighting support is implemented by a custom pipeline
        # script rather than the stock ONNX pipeline.
        if lpw:
            custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
        else:
            custom_pipeline = None

        logger.debug("loading new diffusion pipeline from %s", model)
        # Rebinds the `scheduler` parameter from the class to a loaded instance.
        scheduler = scheduler.from_pretrained(
            model,
            provider=device.provider,
            provider_options=device.options,
            subfolder="scheduler",
        )
        pipe = pipeline.from_pretrained(
            model,
            custom_pipeline=custom_pipeline,
            provider=device.provider,
            provider_options=device.options,
            revision="onnx",
            safety_checker=None,
            scheduler=scheduler,
        )

        if device is not None and hasattr(pipe, "to"):
            pipe = pipe.to(device.torch_device())

        last_pipeline_instance = pipe
        last_pipeline_options = options
        last_pipeline_scheduler = scheduler

    # NOTE(review): on the reuse path `scheduler` is still the incoming class
    # (never rebound above), while last_pipeline_scheduler is an instance, so
    # this comparison appears to always trigger a scheduler reload when the
    # pipeline is reused — confirm whether that is intended.
    if last_pipeline_scheduler != scheduler:
        logger.debug("loading new diffusion scheduler")
        scheduler = scheduler.from_pretrained(
            model,
            provider=device.provider,
            provider_options=device.options,
            subfolder="scheduler",
        )

        # NOTE(review): passes the DeviceParams object itself here, whereas
        # the pipeline above is moved with device.torch_device() — verify
        # which form the scheduler's `to` expects.
        if device is not None and hasattr(scheduler, "to"):
            scheduler = scheduler.to(device)

        pipe.scheduler = scheduler
        last_pipeline_scheduler = scheduler
        run_gc()

    return pipe
|