
feat: make pipeline type a request parameter

Sean Sube 2023-04-12 22:58:48 -05:00
parent b2556809e9
commit 2af1530a7e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
18 changed files with 136 additions and 104 deletions

View File

@@ -33,37 +33,23 @@ def blend_controlnet(
pipe = load_pipeline(
server,
OnnxStableDiffusionControlNetPipeline,
"controlnet",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength, # TODO: ControlNet strength
callback=callback,
)
output = result.images[0]
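
After this refactor, stages request a pipeline by string key instead of importing the class directly. A minimal sketch of the new call shape, reusing the server, params, and job objects already in scope in the stage above:

pipe = load_pipeline(
    server,
    "controlnet",  # string key, resolved through available_pipelines
    params.model,
    params.scheduler,
    job.get_device(),
)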

View File

@@ -33,13 +33,12 @@ def blend_img2img(
pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
params.pipeline,
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(

View File

@@ -75,10 +75,9 @@ def blend_inpaint(
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(

View File

@@ -35,13 +35,12 @@ def blend_pix2pix(
pipe = load_pipeline(
server,
OnnxStableDiffusionInstructPix2PixPipeline,
"pix2pix",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(

View File

@@ -39,14 +39,13 @@ def source_txt2img(
latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
"txt2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(

View File

@@ -75,13 +75,12 @@ def upscale_outpaint(
latents = get_tile_latents(full_latents, dims)
pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline,
"inpaint",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(

View File

@@ -19,15 +19,15 @@ from diffusers import (
KDPM2DiscreteScheduler,
LMSDiscreteScheduler,
OnnxRuntimeModel,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionPipeline,
PNDMScheduler,
StableDiffusionPipeline,
)
from onnx import load_model
from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL
from ..diffusers.utils import expand_prompt
try:
from diffusers import DEISMultistepScheduler
except ImportError:
@@ -38,8 +38,13 @@ try:
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
from ..constants import ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from ..diffusers.lpw_stable_diffusion_onnx import OnnxStableDiffusionLongPromptWeightingPipeline
from ..diffusers.utils import expand_prompt
from ..params import DeviceParams, Size
from ..server import ServerContext
from ..utils import run_gc
@@ -49,6 +54,15 @@ logger = getLogger(__name__)
latent_channels = 4
latent_factor = 8
available_pipelines = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": OnnxStableDiffusionImg2ImgPipeline,
"inpaint": OnnxStableDiffusionInpaintPipeline,
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
"txt2img": OnnxStableDiffusionPipeline,
}
pipeline_schedulers = {
"ddim": DDIMScheduler,
"ddpm": DDPMScheduler,
@@ -68,8 +82,12 @@ pipeline_schedulers = {
}
def get_pipeline_schedulers():
return pipeline_schedulers
def get_available_pipelines() -> List[str]:
return list(available_pipelines.keys())
def get_pipeline_schedulers() -> List[str]:
return list(pipeline_schedulers.keys())
def get_scheduler_name(scheduler: Any) -> Optional[str]:
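
Both helpers now return plain lists of keys, which is what the request validation and the new settings endpoint consume. A quick illustration, assuming the onnx_web.diffusers.load module path used by the imports elsewhere in this commit:

from onnx_web.diffusers.load import get_available_pipelines, get_pipeline_schedulers

print(get_available_pipelines())
# ['controlnet', 'img2img', 'inpaint', 'lpw', 'pix2pix', 'txt2img']
print(get_pipeline_schedulers())
# ['ddim', 'ddpm', ...] plus the other registered scheduler keys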
@@ -111,11 +129,10 @@ def get_tile_latents(
def load_pipeline(
server: ServerContext,
pipeline: DiffusionPipeline,
pipeline: str,
model: str,
scheduler_name: str,
device: DeviceParams,
lpw: bool,
control: Optional[str] = None,
inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
@@ -129,7 +146,7 @@ def load_pipeline(
)
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
pipe_key = (
pipeline.__name__,
pipeline,
model,
device.device,
device.provider,
@@ -170,11 +187,6 @@ def load_pipeline(
logger.debug("unloading previous diffusion pipeline")
run_gc([device])
if lpw:
custom_pipeline = "./onnx_web/diffusers/lpw_stable_diffusion_onnx.py"
else:
custom_pipeline = None
logger.debug("loading new diffusion pipeline from %s", model)
components = {
"scheduler": scheduler_type.from_pretrained(
@@ -281,6 +293,7 @@ def load_pipeline(
)
)
# ControlNet component
if control is not None:
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
@@ -290,9 +303,10 @@ def load_pipeline(
)
)
pipe = pipeline.from_pretrained(
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
model,
custom_pipeline=custom_pipeline,
provider=device.ort_provider(),
sess_options=device.sess_options(),
revision="onnx",
@@ -306,12 +320,12 @@ def load_pipeline(
optimize_pipeline(server, pipe)
# TODO: CPU VAE, etc
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_str())
# monkey-patch pipeline
if not lpw:
patch_pipeline(server, pipe, pipeline)
patch_pipeline(server, pipe, pipeline)
server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, components["scheduler"])

View File

@@ -47,17 +47,16 @@ def run_txt2img_pipeline(
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
"txt2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
inversions,
loras,
)
progress = job.get_progress_callback()
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
@@ -117,11 +116,10 @@ def run_txt2img_pipeline(
# load img2img pipeline once
highres_pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
"img2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
inversions,
loras,
)
@@ -153,7 +151,7 @@ def run_txt2img_pipeline(
callback=highres_progress,
)
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for highres")
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
@@ -233,18 +231,16 @@ def run_img2img_pipeline(
pipe = load_pipeline(
server,
# OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionControlNetPipeline,
"img2img",
params.model,
params.scheduler,
job.get_device(),
params.lpw,
control=params.control,
inversions=inversions,
loras=loras,
)
progress = job.get_progress_callback()
if params.lpw:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(

View File

@@ -75,13 +75,13 @@ def make_output_name(
hash_value(sha, mode)
hash_value(sha, params.model)
hash_value(sha, params.pipeline)
hash_value(sha, params.scheduler)
hash_value(sha, params.prompt)
hash_value(sha, params.negative_prompt)
hash_value(sha, params.cfg)
hash_value(sha, params.seed)
hash_value(sha, params.steps)
hash_value(sha, params.lpw)
hash_value(sha, params.eta)
hash_value(sha, params.batch)
hash_value(sha, size.width)
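
Hashing params.pipeline in place of the old lpw flag means requests that differ only in pipeline now produce distinct output names. A minimal sketch of the idea, using a simplified stand-in for the project's hash_value helper:

import hashlib

def hash_value(sha, value):  # simplified stand-in, not the actual helper
    if value is not None:
        sha.update(str(value).encode("utf-8"))

sha = hashlib.sha256()
hash_value(sha, "lpw")   # the pipeline key now contributes to the digest
hash_value(sha, "ddim")  # scheduler, prompt, and the rest follow as above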

View File

@@ -161,6 +161,7 @@ class DeviceParams:
class ImageParams:
model: str
pipeline: str
scheduler: str
prompt: str
cfg: float
@@ -175,39 +176,42 @@ class ImageParams:
def __init__(
self,
model: str,
pipeline: str,
scheduler: str,
prompt: str,
cfg: float,
steps: int,
seed: int,
negative_prompt: Optional[str] = None,
lpw: bool = False,
eta: float = 0.0,
batch: int = 1,
control: Optional[str] = None,
) -> None:
self.model = model
self.pipeline = pipeline
self.scheduler = scheduler
self.prompt = prompt
self.negative_prompt = negative_prompt
self.cfg = cfg
self.seed = seed
self.steps = steps
self.lpw = lpw or False
self.eta = eta
self.batch = batch
self.control = control
def lpw(self):
return self.pipeline == "lpw"
def tojson(self) -> Dict[str, Optional[Param]]:
return {
"model": self.model,
"pipeline": self.pipeline,
"scheduler": self.scheduler,
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"cfg": self.cfg,
"seed": self.seed,
"steps": self.steps,
"lpw": self.lpw,
"eta": self.eta,
"batch": self.batch,
"control": self.control.name,
@@ -216,13 +220,13 @@ class ImageParams:
def with_args(self, **kwargs):
return ImageParams(
kwargs.get("model", self.model),
kwargs.get("pipeline", self.pipeline),
kwargs.get("scheduler", self.scheduler),
kwargs.get("prompt", self.prompt),
kwargs.get("negative_prompt", self.negative_prompt),
kwargs.get("cfg", self.cfg),
kwargs.get("steps", self.steps),
kwargs.get("seed", self.seed),
kwargs.get("negative_prompt", self.negative_prompt),
kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch),
kwargs.get("control", self.control),

View File

@@ -4,7 +4,7 @@ from typing import Tuple
import numpy as np
from flask import request
from ..diffusers.load import pipeline_schedulers
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers, pipeline_schedulers
from ..params import (
Border,
DeviceParams,
@@ -30,6 +30,7 @@ logger = getLogger(__name__)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str,
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
@@ -42,15 +43,16 @@ def pipeline_from_request(
if platform.device == device_name:
device = platform
# pipeline stuff
lpw = get_not_empty(request.args, "lpw", "false") == "true"
# diffusion model
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(server, model)
scheduler = get_from_list(
request.args, "scheduler", list(pipeline_schedulers.keys())
)
control = get_from_list(request.args, "control", get_network_models())
# pipeline stuff
pipeline = get_from_list(request.args, "pipeline", get_available_pipelines(), default_pipeline)
scheduler = get_from_list(
request.args, "scheduler", get_pipeline_schedulers())
if scheduler is None:
scheduler = get_config_value("scheduler")
@@ -110,11 +112,12 @@ def pipeline_from_request(
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
@@ -125,6 +128,7 @@ def pipeline_from_request(
params = ImageParams(
model_path,
pipeline,
scheduler,
prompt,
cfg,
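
On the wire, the pipeline is an ordinary query parameter, validated against get_available_pipelines() and falling back to default_pipeline when missing or unrecognized. A hedged example request; the host, port, and endpoint path here are assumptions, not confirmed by this diff:

import requests

resp = requests.post(
    "http://localhost:5000/api/txt2img",  # assumed endpoint
    params={
        "pipeline": "lpw",  # select the long-prompt-weighting pipeline
        "prompt": "an example prompt",
        "scheduler": "ddim",
    },
)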

View File

@@ -124,6 +124,14 @@
"max": 4,
"step": 1
},
"pipeline": {
"default": "",
"keys": [
"controlnet",
"lpw",
"pix2pix"
]
},
"platform": {
"default": "amd",
"keys": []

View File

@@ -9,8 +9,7 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
[tool.mypy]
# ignore_missing_imports = true
exclude = [
"onnx_web.diffusers.lpw_stable_diffusion_onnx",
"onnx_web.diffusers.pipeline_onnx_stable_diffusion_upscale"
"onnx_web.diffusers.lpw_stable_diffusion_onnx"
]
[[tool.mypy.overrides]]

View File

@@ -122,6 +122,14 @@
"max": 4,
"step": 1
},
"pipeline": {
"default": "",
"keys": [
"controlnet",
"lpw",
"pix2pix"
]
},
"platform": {
"default": "amd",
"keys": []

View File

@@ -13,6 +13,11 @@ export interface ModelParams {
*/
model: string;
/**
* Specialized pipeline to use.
*/
pipeline: string;
/**
* The hardware acceleration platform to use.
*/
@@ -28,11 +33,6 @@ export interface ModelParams {
*/
correction: string;
/**
* Use the long prompt weighting pipeline.
*/
lpw: boolean;
/**
* ControlNet to be used.
*/
@@ -270,6 +270,11 @@ export interface ApiClient {
*/
params(): Promise<ServerParams>;
/**
* Get the available pipelines.
*/
pipelines(): Promise<Array<string>>;
/**
* Get the available hardware acceleration platforms.
*/
@@ -393,10 +398,10 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
*/
export function appendModelToURL(url: URL, params: ModelParams) {
url.searchParams.append('model', params.model);
url.searchParams.append('pipeline', params.pipeline);
url.searchParams.append('platform', params.platform);
url.searchParams.append('upscaling', params.upscaling);
url.searchParams.append('correction', params.correction);
url.searchParams.append('lpw', String(params.lpw));
url.searchParams.append('control', params.control);
}
@@ -453,6 +458,11 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async pipelines(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'pipelines');
const res = await f(path);
return await res.json() as Array<string>;
},
async platforms(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'platforms');
const res = await f(path);
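
The web client now reads the pipeline list from a new settings route. The equivalent fetch from Python, assuming the settings/pipelines path built above lives under an assumed server root:

import requests

pipelines = requests.get("http://localhost:5000/api/settings/pipelines").json()
# e.g. ['controlnet', 'img2img', 'inpaint', 'lpw', 'pix2pix', 'txt2img']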

View File

@@ -50,6 +50,9 @@ export const LOCAL_CLIENT = {
async models() {
throw new NoServerError();
},
async pipelines() {
throw new NoServerError();
},
async platforms() {
throw new NoServerError();
},

View File

@@ -1,9 +1,9 @@
import { mustExist } from '@apextoaster/js-utils';
import { Checkbox, FormControlLabel, Stack } from '@mui/material';
import { Stack } from '@mui/material';
import { useQuery } from '@tanstack/react-query';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useQuery } from '@tanstack/react-query';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
@@ -24,6 +24,9 @@ export function ModelControl() {
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const pipelines = useQuery(['pipelines'], async () => client.pipelines(), {
staleTime: STALE_TIME,
});
const platforms = useQuery(['platforms'], async () => client.platforms(), {
staleTime: STALE_TIME,
});
@@ -34,7 +37,6 @@ export function ModelControl() {
const tab = getTab(hash);
const current = state.getState();
switch (tab) {
case 'txt2img': {
const { prompt } = current.txt2img;
@@ -133,17 +135,20 @@
/>
</Stack>
<Stack direction='row' spacing={2}>
<FormControlLabel
label={t('parameter.lpw')}
control={<Checkbox
checked={params.lpw}
value='check'
onChange={(event) => {
setModel({
lpw: params.lpw === false,
});
}}
/>}
<QueryList
id='pipeline'
labelKey='pipeline'
name={t('parameter.pipeline')}
query={{
result: pipelines,
}}
showEmpty
value={params.pipeline}
onChange={(pipeline) => {
setModel({
pipeline,
});
}}
/>
<QueryMenu
id='inversion'

View File

@@ -502,8 +502,8 @@ export function createStateSlices(server: ServerParams) {
model: {
control: server.control.default,
correction: server.correction.default,
lpw: false,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
},