1
0
Fork 0

lint(api): split out pipeline code

This commit is contained in:
Sean Sube 2023-01-15 18:46:00 -06:00
parent 806503c709
commit c7c3645466
2 changed files with 54 additions and 46 deletions

50
api/onnx_web/pipeline.py Normal file
View File

@@ -0,0 +1,50 @@
from diffusers import (
DiffusionPipeline,
)
import numpy as np
# Module-level cache of the most recently constructed pipeline, so repeated
# requests with the same (pipeline class, model, provider) options can reuse
# it instead of reloading model weights from disk.
last_pipeline_instance = None
# The (pipeline class, model name, provider) tuple the cached instance was built with.
last_pipeline_options = (None, None, None)
# The scheduler class currently attached to the cached pipeline instance.
last_pipeline_scheduler = None
# from https://www.travelneil.com/stable-diffusion-updates.html
def get_latents_from_seed(seed: int, width: int, height: int, batch_size: int = 1) -> np.ndarray:
    """Build a deterministic initial latent tensor for the given seed.

    The latent volume has 4 channels and is 1/8 of the requested image size
    in each spatial dimension (the SD VAE downsampling factor).

    Args:
        seed: RNG seed; the same seed always produces the same latents.
        width: output image width in pixels (expected multiple of 8).
        height: output image height in pixels (expected multiple of 8).
        batch_size: number of images in the batch. Defaults to 1, which
            matches the previous hard-coded behavior.

    Returns:
        float32 ndarray of shape (batch_size, 4, height // 8, width // 8).
    """
    latents_shape = (batch_size, 4, height // 8, width // 8)
    # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
    rng = np.random.default_rng(seed)
    return rng.standard_normal(latents_shape).astype(np.float32)
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
    """Return a diffusers pipeline, reusing the module-level cached instance
    when the (pipeline class, model, provider) options match the last call.

    Args:
        pipeline: pipeline class to instantiate via ``from_pretrained``.
        model: model name or path passed to ``from_pretrained``.
        provider: ONNX execution provider name.
        scheduler: scheduler class; loaded from the model's ``scheduler``
            subfolder and swapped onto a reused pipeline when it changes.

    Returns:
        The (possibly cached) pipeline instance.
    """
    global last_pipeline_instance
    global last_pipeline_scheduler
    global last_pipeline_options

    options = (pipeline, model, provider)
    # PEP 8: compare against None with `is not`, not `!=`
    if last_pipeline_instance is not None and last_pipeline_options == options:
        print('reusing existing pipeline')
        pipe = last_pipeline_instance
    else:
        print('loading different pipeline')
        pipe = pipeline.from_pretrained(
            model,
            provider=provider,
            safety_checker=None,
            scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
        )
        last_pipeline_instance = pipe
        last_pipeline_options = options
        last_pipeline_scheduler = scheduler

    # A cached pipeline may have been built with a different scheduler:
    # swap it in place rather than reloading the whole model.
    if last_pipeline_scheduler != scheduler:
        print('changing pipeline scheduler')
        pipe.scheduler = scheduler.from_pretrained(
            model, subfolder='scheduler')
        last_pipeline_scheduler = scheduler

    return pipe

View File

@@ -16,8 +16,6 @@ from diffusers import (
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
# types
DiffusionPipeline,
)
from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_cors import CORS
@@ -43,7 +41,10 @@ from .image import (
noise_source_normal,
noise_source_uniform,
)
from .pipeline import (
get_latents_from_seed,
load_pipeline,
)
from .upscale import (
upscale_gfpgan,
upscale_resrgan,
@@ -67,9 +68,6 @@ num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1))
# pipeline caching
available_models = []
config_params = {}
last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None
# pipeline params
platform_providers = {
@@ -126,46 +124,6 @@ def get_model_path(model: str):
return safer_join(model_path, model)
# from https://www.travelneil.com/stable-diffusion-updates.html
def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
    """Produce deterministic float32 latents of shape (1, 4, h // 8, w // 8)."""
    # Batch size is fixed at 1; latent space is 1/8 of the pixel resolution.
    shape = (1, 4, height // 8, width // 8)
    # numpy rather than torch, because torch's randn() doesn't support DML
    generator = np.random.default_rng(seed)
    return generator.standard_normal(shape).astype(np.float32)
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
    """Load (or reuse) a diffusers pipeline.

    Reuses the module-level cached pipeline when ``(pipeline, model,
    provider)`` matches the previous call; otherwise builds a new instance
    via ``from_pretrained``. If only the scheduler changed, it is swapped
    onto the cached pipeline instead of reloading the model.

    Returns the (possibly cached) pipeline instance.
    """
    global last_pipeline_instance
    global last_pipeline_scheduler
    global last_pipeline_options

    options = (pipeline, model, provider)
    # Fixed: identity comparison with None (`is not None`), not `!= None`.
    if last_pipeline_instance is not None and last_pipeline_options == options:
        print('reusing existing pipeline')
        pipe = last_pipeline_instance
    else:
        print('loading different pipeline')
        pipe = pipeline.from_pretrained(
            model,
            provider=provider,
            safety_checker=None,
            scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
        )
        last_pipeline_instance = pipe
        last_pipeline_options = options
        last_pipeline_scheduler = scheduler

    if last_pipeline_scheduler != scheduler:
        print('changing pipeline scheduler')
        pipe.scheduler = scheduler.from_pretrained(
            model, subfolder='scheduler')
        last_pipeline_scheduler = scheduler

    return pipe
def serve_bundle_file(filename='index.html'):
    """Serve a static file from the web client bundle directory."""
    bundle_dir = path.join('..', bundle_path)
    return send_from_directory(bundle_dir, filename)