diff --git a/api/onnx_web/chain/blend_controlnet.py b/api/onnx_web/chain/blend_controlnet.py index 37eea298..1ece0005 100644 --- a/api/onnx_web/chain/blend_controlnet.py +++ b/api/onnx_web/chain/blend_controlnet.py @@ -33,37 +33,23 @@ def blend_controlnet( pipe = load_pipeline( server, - OnnxStableDiffusionControlNetPipeline, + "controlnet", params.model, params.scheduler, job.get_device(), - params.lpw, ) - if params.lpw: - logger.debug("using LPW pipeline for img2img") - rng = torch.manual_seed(params.seed) - result = pipe.img2img( - params.prompt, - generator=rng, - guidance_scale=params.cfg, - image=source, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - strength=params.strength, - callback=callback, - ) - else: - rng = np.random.RandomState(params.seed) - result = pipe( - params.prompt, - generator=rng, - guidance_scale=params.cfg, - image=source, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - strength=params.strength, - callback=callback, - ) + + rng = np.random.RandomState(params.seed) + result = pipe( + params.prompt, + generator=rng, + guidance_scale=params.cfg, + image=source, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + strength=params.strength, # TODO: ControlNet strength + callback=callback, + ) output = result.images[0] diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index bf423e64..dc881f78 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -33,13 +33,12 @@ def blend_img2img( pipe = load_pipeline( server, - OnnxStableDiffusionImg2ImgPipeline, + params.pipeline, params.model, params.scheduler, job.get_device(), - params.lpw, ) - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) result = pipe.img2img( diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 4220364c..58b13de3 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -75,10 +75,9 @@ def blend_inpaint( params.model, params.scheduler, job.get_device(), - params.lpw, ) - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for inpaint") rng = torch.manual_seed(params.seed) result = pipe.inpaint( diff --git a/api/onnx_web/chain/blend_pix2pix.py b/api/onnx_web/chain/blend_pix2pix.py index 7a1de781..194dd906 100644 --- a/api/onnx_web/chain/blend_pix2pix.py +++ b/api/onnx_web/chain/blend_pix2pix.py @@ -35,13 +35,12 @@ def blend_pix2pix( pipe = load_pipeline( server, - OnnxStableDiffusionInstructPix2PixPipeline, + "pix2pix", params.model, params.scheduler, job.get_device(), - params.lpw, ) - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) result = pipe.img2img( diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 58220e14..b20439bf 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -39,14 +39,13 @@ def source_txt2img( latents = get_latents_from_seed(params.seed, size) pipe = load_pipeline( server, - OnnxStableDiffusionPipeline, + "txt2img", params.model, params.scheduler, job.get_device(), - params.lpw, ) - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for txt2img") rng = torch.manual_seed(params.seed) result = pipe.text2img( diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 39167bc1..0b9d29d9 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -75,13 +75,12 @@ def upscale_outpaint( latents = get_tile_latents(full_latents, dims) pipe = load_pipeline( server, - OnnxStableDiffusionInpaintPipeline, + "inpaint", params.model, params.scheduler, job.get_device(), - params.lpw, ) - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for inpaint") rng = torch.manual_seed(params.seed) result = pipe.inpaint( diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 77e9b6f3..d84cc831 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -19,15 +19,15 @@ from diffusers import ( KDPM2DiscreteScheduler, LMSDiscreteScheduler, OnnxRuntimeModel, + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionPipeline, PNDMScheduler, StableDiffusionPipeline, ) from onnx import load_model from transformers import CLIPTokenizer -from ..constants import ONNX_MODEL -from ..diffusers.utils import expand_prompt - try: from diffusers import DEISMultistepScheduler except ImportError: @@ -38,8 +38,13 @@ try: except ImportError: from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler +from ..constants import ONNX_MODEL from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors from ..convert.diffusion.textual_inversion import blend_textual_inversions +from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline +from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline +from ..diffusers.lpw_stable_diffusion_onnx import OnnxStableDiffusionLongPromptWeightingPipeline +from ..diffusers.utils import expand_prompt from ..params import DeviceParams, Size from ..server import ServerContext from ..utils import run_gc @@ -49,6 +54,15 @@ logger = getLogger(__name__) latent_channels = 4 latent_factor = 8 +available_pipelines = { + "controlnet": OnnxStableDiffusionControlNetPipeline, + "img2img": OnnxStableDiffusionImg2ImgPipeline, + "inpaint": OnnxStableDiffusionInpaintPipeline, + "lpw": OnnxStableDiffusionLongPromptWeightingPipeline, + "pix2pix": OnnxStableDiffusionInstructPix2PixPipeline, + "txt2img": OnnxStableDiffusionPipeline, +} + pipeline_schedulers = { "ddim": DDIMScheduler, "ddpm": DDPMScheduler, @@ -68,8 +82,12 @@ pipeline_schedulers = { } -def get_pipeline_schedulers(): - return pipeline_schedulers +def get_available_pipelines() -> List[str]: + return list(available_pipelines.keys()) + + +def get_pipeline_schedulers() -> List[str]: + return list(pipeline_schedulers.keys()) def get_scheduler_name(scheduler: Any) -> Optional[str]: @@ -111,11 +129,10 @@ def get_tile_latents( def load_pipeline( server: ServerContext, - pipeline: DiffusionPipeline, + pipeline: str, model: str, scheduler_name: str, device: DeviceParams, - lpw: bool, control: Optional[str] = None, inversions: Optional[List[Tuple[str, float]]] = None, loras: Optional[List[Tuple[str, float]]] = None, @@ -129,7 +146,7 @@ def load_pipeline( ) logger.debug("using Torch dtype %s for pipeline", torch_dtype) pipe_key = ( - pipeline.__name__, + pipeline, model, device.device, device.provider, @@ -170,11 +187,6 @@ def load_pipeline( logger.debug("unloading previous diffusion pipeline") run_gc([device]) - if lpw: - custom_pipeline = "./onnx_web/diffusers/lpw_stable_diffusion_onnx.py" - else: - custom_pipeline = None - logger.debug("loading new diffusion pipeline from %s", model) components = { "scheduler": scheduler_type.from_pretrained( @@ -281,6 +293,7 @@ def load_pipeline( ) ) + # ControlNet component if control is not None: components["controlnet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( @@ -290,9 +303,10 @@ def load_pipeline( ) ) - pipe = pipeline.from_pretrained( + pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) + logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__) + pipe = pipeline_class.from_pretrained( model, - custom_pipeline=custom_pipeline, provider=device.ort_provider(), sess_options=device.sess_options(), revision="onnx", @@ -306,12 +320,12 @@ def load_pipeline( optimize_pipeline(server, pipe) + # TODO: CPU VAE, etc if device is not None and hasattr(pipe, "to"): pipe = pipe.to(device.torch_str()) # monkey-patch pipeline - if not lpw: - patch_pipeline(server, pipe, pipeline) + patch_pipeline(server, pipe, pipeline) server.cache.set("diffusion", pipe_key, pipe) server.cache.set("scheduler", scheduler_key, components["scheduler"]) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index c71a04b7..891f4c34 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -47,17 +47,16 @@ def run_txt2img_pipeline( pipe = load_pipeline( server, - OnnxStableDiffusionPipeline, + "txt2img", params.model, params.scheduler, job.get_device(), - params.lpw, inversions, loras, ) progress = job.get_progress_callback() - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for txt2img") rng = torch.manual_seed(params.seed) result = pipe.text2img( @@ -117,11 +116,10 @@ def run_txt2img_pipeline( # load img2img pipeline once highres_pipe = load_pipeline( server, - OnnxStableDiffusionImg2ImgPipeline, + "img2img", params.model, params.scheduler, job.get_device(), - params.lpw, inversions, loras, ) @@ -153,7 +151,7 @@ def run_txt2img_pipeline( callback=highres_progress, ) - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for highres") rng = torch.manual_seed(params.seed) result = highres_pipe.img2img( @@ -233,18 +231,16 @@ def run_img2img_pipeline( pipe = load_pipeline( server, - # OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionControlNetPipeline, + "img2img", params.model, params.scheduler, job.get_device(), - params.lpw, control=params.control, inversions=inversions, loras=loras, ) progress = job.get_progress_callback() - if params.lpw: + if params.lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) result = pipe.img2img( diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 013882fc..5a9a7ed8 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -75,13 +75,13 @@ def make_output_name( hash_value(sha, mode) hash_value(sha, params.model) + hash_value(sha, params.pipeline) hash_value(sha, params.scheduler) hash_value(sha, params.prompt) hash_value(sha, params.negative_prompt) hash_value(sha, params.cfg) hash_value(sha, params.seed) hash_value(sha, params.steps) - hash_value(sha, params.lpw) hash_value(sha, params.eta) hash_value(sha, params.batch) hash_value(sha, size.width) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 6643d35c..5bc7d68c 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -161,6 +161,7 @@ class DeviceParams: class ImageParams: model: str + pipeline: str scheduler: str prompt: str cfg: float @@ -175,39 +176,42 @@ class ImageParams: def __init__( self, model: str, + pipeline: str, scheduler: str, prompt: str, cfg: float, steps: int, seed: int, negative_prompt: Optional[str] = None, - lpw: bool = False, eta: float = 0.0, batch: int = 1, control: Optional[str] = None, ) -> None: self.model = model + self.pipeline = pipeline self.scheduler = scheduler self.prompt = prompt self.negative_prompt = negative_prompt self.cfg = cfg self.seed = seed self.steps = steps - self.lpw = lpw or False self.eta = eta self.batch = batch self.control = control + def lpw(self): + return self.pipeline == "lpw" + def tojson(self) -> Dict[str, Optional[Param]]: return { "model": self.model, + "pipeline": self.pipeline, "scheduler": self.scheduler, "prompt": self.prompt, "negative_prompt": self.negative_prompt, "cfg": self.cfg, "seed": self.seed, "steps": self.steps, - "lpw": self.lpw, "eta": self.eta, "batch": self.batch, "control": self.control.name, @@ -216,13 +220,13 @@ class ImageParams: def with_args(self, **kwargs): return ImageParams( kwargs.get("model", self.model), + kwargs.get("pipeline", self.pipeline), kwargs.get("scheduler", self.scheduler), kwargs.get("prompt", self.prompt), + kwargs.get("negative_prompt", self.negative_prompt), kwargs.get("cfg", self.cfg), kwargs.get("steps", self.steps), kwargs.get("seed", self.seed), - kwargs.get("negative_prompt", self.negative_prompt), - kwargs.get("lpw", self.lpw), kwargs.get("eta", self.eta), kwargs.get("batch", self.batch), kwargs.get("control", self.control), diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 526dfe49..275904f1 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -4,7 +4,7 @@ from typing import Tuple import numpy as np from flask import request -from ..diffusers.load import pipeline_schedulers +from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers, pipeline_schedulers from ..params import ( Border, DeviceParams, @@ -30,6 +30,7 @@ logger = getLogger(__name__) def pipeline_from_request( server: ServerContext, + default_pipeline: str, ) -> Tuple[DeviceParams, ImageParams, Size]: user = request.remote_addr @@ -42,15 +43,16 @@ def pipeline_from_request( if platform.device == device_name: device = platform - # pipeline stuff - lpw = get_not_empty(request.args, "lpw", "false") == "true" + # diffusion model model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(server, model) - scheduler = get_from_list( - request.args, "scheduler", list(pipeline_schedulers.keys()) - ) control = get_from_list(request.args, "control", get_network_models()) + # pipeline stuff + pipeline = get_from_list(request.args, "pipeline", get_available_pipelines(), default_pipeline) + scheduler = get_from_list( + request.args, "scheduler", get_pipeline_schedulers()) + if scheduler is None: scheduler = get_config_value("scheduler") @@ -110,11 +112,12 @@ def pipeline_from_request( seed = np.random.randint(np.iinfo(np.int32).max) logger.info( - "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", + "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", user, steps, scheduler, model_path, + pipeline, device or "any device", width, height, @@ -125,6 +128,7 @@ def pipeline_from_request( params = ImageParams( model_path, + pipeline, scheduler, prompt, cfg, diff --git a/api/params.json b/api/params.json index 7eaaaacb..9212fb2a 100644 --- a/api/params.json +++ b/api/params.json @@ -124,6 +124,14 @@ "max": 4, "step": 1 }, + "pipeline": { + "default": "", + "keys": [ + "controlnet", + "lpw", + "pix2pix" + ] + }, "platform": { "default": "amd", "keys": [] diff --git a/api/pyproject.toml b/api/pyproject.toml index aabd4334..53c7d5bf 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -9,8 +9,7 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion [tool.mypy] # ignore_missing_imports = true exclude = [ - "onnx_web.diffusers.lpw_stable_diffusion_onnx", - "onnx_web.diffusers.pipeline_onnx_stable_diffusion_upscale" + "onnx_web.diffusers.lpw_stable_diffusion_onnx" ] [[tool.mypy.overrides]] diff --git a/gui/examples/config.json b/gui/examples/config.json index fc907d3d..5514d239 100644 --- a/gui/examples/config.json +++ b/gui/examples/config.json @@ -122,6 +122,14 @@ "max": 4, "step": 1 }, + "pipeline": { + "default": "", + "keys": [ + "controlnet", + "lpw", + "pix2pix" + ] + }, "platform": { "default": "amd", "keys": [] diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 56710758..4b1ab699 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -13,6 +13,11 @@ export interface ModelParams { */ model: string; + /** + * Specialized pipeline to use. + */ + pipeline: string; + /** * The hardware acceleration platform to use. */ @@ -28,11 +33,6 @@ export interface ModelParams { */ correction: string; - /** - * Use the long prompt weighting pipeline. - */ - lpw: boolean; - /** * ControlNet to be used. */ @@ -270,6 +270,11 @@ export interface ApiClient { */ params(): Promise; + /** + * Get the available pipelines. + */ + pipelines(): Promise>; + /** * Get the available hardware acceleration platforms. */ @@ -393,10 +398,10 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): */ export function appendModelToURL(url: URL, params: ModelParams) { url.searchParams.append('model', params.model); + url.searchParams.append('pipeline', params.pipeline); url.searchParams.append('platform', params.platform); url.searchParams.append('upscaling', params.upscaling); url.searchParams.append('correction', params.correction); - url.searchParams.append('lpw', String(params.lpw)); url.searchParams.append('control', params.control); } @@ -453,6 +458,11 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, + async pipelines(): Promise> { + const path = makeApiUrl(root, 'settings', 'pipelines'); + const res = await f(path); + return await res.json() as Array; + }, async platforms(): Promise> { const path = makeApiUrl(root, 'settings', 'platforms'); const res = await f(path); diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 8dc931af..452a457d 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -50,6 +50,9 @@ export const LOCAL_CLIENT = { async models() { throw new NoServerError(); }, + async pipelines() { + throw new NoServerError(); + }, async platforms() { throw new NoServerError(); }, diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index 1fa578dc..0f7e7162 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -1,9 +1,9 @@ import { mustExist } from '@apextoaster/js-utils'; -import { Checkbox, FormControlLabel, Stack } from '@mui/material'; +import { Stack } from '@mui/material'; +import { useQuery } from '@tanstack/react-query'; import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; -import { useQuery } from '@tanstack/react-query'; import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; @@ -24,6 +24,9 @@ export function ModelControl() { const models = useQuery(['models'], async () => client.models(), { staleTime: STALE_TIME, }); + const pipelines = useQuery(['pipelines'], async () => client.pipelines(), { + staleTime: STALE_TIME, + }); const platforms = useQuery(['platforms'], async () => client.platforms(), { staleTime: STALE_TIME, }); @@ -34,7 +37,6 @@ export function ModelControl() { const tab = getTab(hash); const current = state.getState(); - switch (tab) { case 'txt2img': { const { prompt } = current.txt2img; @@ -133,17 +135,20 @@ export function ModelControl() { /> - { - setModel({ - lpw: params.lpw === false, - }); - }} - />} + { + setModel({ + pipeline, + }); + }} />