lint(api): split out pipeline code
This commit is contained in:
parent
806503c709
commit
c7c3645466
|
@ -0,0 +1,50 @@
|
|||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_options = (None, None, None)
|
||||
last_pipeline_scheduler = None
|
||||
|
||||
# from https://www.travelneil.com/stable-diffusion-updates.html
|
||||
|
||||
|
||||
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||
# 1 is batch size
|
||||
latents_shape = (1, 4, height // 8, width // 8)
|
||||
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
|
||||
rng = np.random.default_rng(seed)
|
||||
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
||||
return image_latents
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
||||
options = (pipeline, model, provider)
|
||||
if last_pipeline_instance != None and last_pipeline_options == options:
|
||||
print('reusing existing pipeline')
|
||||
pipe = last_pipeline_instance
|
||||
else:
|
||||
print('loading different pipeline')
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
provider=provider,
|
||||
safety_checker=None,
|
||||
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
|
||||
)
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
print('changing pipeline scheduler')
|
||||
pipe.scheduler = scheduler.from_pretrained(
|
||||
model, subfolder='scheduler')
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
return pipe
|
|
@ -16,8 +16,6 @@ from diffusers import (
|
|||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
# types
|
||||
DiffusionPipeline,
|
||||
)
|
||||
from flask import Flask, jsonify, request, send_from_directory, url_for
|
||||
from flask_cors import CORS
|
||||
|
@ -43,7 +41,10 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
|
||||
from .pipeline import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
from .upscale import (
|
||||
upscale_gfpgan,
|
||||
upscale_resrgan,
|
||||
|
@ -67,9 +68,6 @@ num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1))
|
|||
# pipeline caching
|
||||
available_models = []
|
||||
config_params = {}
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_options = (None, None, None)
|
||||
last_pipeline_scheduler = None
|
||||
|
||||
# pipeline params
|
||||
platform_providers = {
|
||||
|
@ -126,46 +124,6 @@ def get_model_path(model: str):
|
|||
return safer_join(model_path, model)
|
||||
|
||||
|
||||
# from https://www.travelneil.com/stable-diffusion-updates.html
|
||||
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||
# 1 is batch size
|
||||
latents_shape = (1, 4, height // 8, width // 8)
|
||||
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
|
||||
rng = np.random.default_rng(seed)
|
||||
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
|
||||
return image_latents
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
||||
options = (pipeline, model, provider)
|
||||
if last_pipeline_instance != None and last_pipeline_options == options:
|
||||
print('reusing existing pipeline')
|
||||
pipe = last_pipeline_instance
|
||||
else:
|
||||
print('loading different pipeline')
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
provider=provider,
|
||||
safety_checker=None,
|
||||
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
|
||||
)
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
print('changing pipeline scheduler')
|
||||
pipe.scheduler = scheduler.from_pretrained(
|
||||
model, subfolder='scheduler')
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
def serve_bundle_file(filename='index.html'):
|
||||
return send_from_directory(path.join('..', bundle_path), filename)
|
||||
|
||||
|
|
Loading…
Reference in New Issue