diff --git a/api/onnx_web/chain/blend_controlnet.py b/api/onnx_web/chain/blend_controlnet.py
index b5553b5e..37eea298 100644
--- a/api/onnx_web/chain/blend_controlnet.py
+++ b/api/onnx_web/chain/blend_controlnet.py
@@ -28,7 +28,7 @@ def blend_controlnet(
     params = params.with_args(**kwargs)
     source = stage_source or source
     logger.info(
-        "blending image using controlnet, %s steps: %s", params.steps, params.prompt
+        "blending image using ControlNet, %s steps: %s", params.steps, params.prompt
     )
 
     pipe = load_pipeline(
diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index ae5696f3..e4c18a6e 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -174,7 +174,6 @@ def convert_diffusion_diffusers(
 
     if is_torch_2_0:
         pipe_cnet.set_attn_processor(CrossAttnProcessor())
 
-    cnet_path = output_path / "cnet" / ONNX_MODEL
     onnx_export(
         pipe_cnet,
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index f453b723..bf5e456b 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -116,14 +116,13 @@ def load_pipeline(
     scheduler_name: str,
     device: DeviceParams,
     lpw: bool,
+    control: Optional[str] = None,
     inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
     inversions = inversions or []
     loras = loras or []
 
-    controlnet = "canny"  # TODO; from params
-
     torch_dtype = (
         torch.float16 if "torch-fp16" in server.optimizations else torch.float32
     )
@@ -186,6 +185,8 @@ def load_pipeline(
     }
 
     text_encoder = None
+
+    # Textual Inversion blending
     if inversions is not None and len(inversions) > 0:
         logger.debug("blending Textual Inversions from %s", inversions)
         inversion_names, inversion_weights = zip(*inversions)
@@ -225,7 +226,7 @@ def load_pipeline(
             )
         )
 
-    # test LoRA blending
+    # LoRA blending
     if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [
@@ -278,8 +279,12 @@ def load_pipeline(
             )
         )
 
-    if controlnet is not None:
-        components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
+    if control is not None:
+        components["controlnet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(
+            control,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
+        ))
 
     pipe = pipeline.from_pretrained(
         model,
@@ -360,7 +365,7 @@ class UNetWrapper(object):
         self.server = server
         self.wrapped = wrapped
 
-    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
+    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None, **kwargs):
         global timestep_dtype
         timestep_dtype = timestep.dtype
 
@@ -382,6 +387,7 @@ class UNetWrapper(object):
             sample=sample,
             timestep=timestep,
             encoder_hidden_states=encoder_hidden_states,
+            **kwargs,
         )
 
     def __getattr__(self, attr):
@@ -393,7 +399,7 @@ class VAEWrapper(object):
         self.server = server
         self.wrapped = wrapped
 
-    def __call__(self, latent_sample=None):
+    def __call__(self, latent_sample=None, **kwargs):
         global timestep_dtype
 
         logger.trace("VAE parameter types: %s", latent_sample.dtype)
@@ -401,7 +407,7 @@ class VAEWrapper(object):
             logger.info("converting VAE sample dtype")
             latent_sample = latent_sample.astype(timestep_dtype)
 
-        return self.wrapped(latent_sample=latent_sample)
+        return self.wrapped(latent_sample=latent_sample, **kwargs)
 
     def __getattr__(self, attr):
         return getattr(self.wrapped, attr)
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index ef365f0d..c71a04b7 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -239,8 +239,9 @@ def run_img2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        inversions,
-        loras,
+        control=params.control,
+        inversions=inversions,
+        loras=loras,
     )
     progress = job.get_progress_callback()
     if params.lpw:
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index 9357aca1..a8d31571 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -160,6 +160,18 @@ class DeviceParams:
 
 
 class ImageParams:
+    model: str
+    scheduler: str
+    prompt: str
+    cfg: float
+    steps: int
+    seed: int
+    negative_prompt: Optional[str]
+    lpw: bool
+    eta: float
+    batch: int
+    control: Optional[str]
+
     def __init__(
         self,
         model: str,
@@ -172,6 +184,7 @@ class ImageParams:
         lpw: bool = False,
         eta: float = 0.0,
         batch: int = 1,
+        control: Optional[str] = None,
     ) -> None:
         self.model = model
         self.scheduler = scheduler
@@ -183,6 +196,7 @@ class ImageParams:
         self.lpw = lpw or False
         self.eta = eta
         self.batch = batch
+        self.control = control
 
     def tojson(self) -> Dict[str, Optional[Param]]:
         return {
@@ -196,6 +210,7 @@ class ImageParams:
             "lpw": self.lpw,
             "eta": self.eta,
             "batch": self.batch,
+            "control": self.control,
         }
 
     def with_args(self, **kwargs):
@@ -210,6 +225,7 @@ class ImageParams:
             kwargs.get("lpw", self.lpw),
             kwargs.get("eta", self.eta),
             kwargs.get("batch", self.batch),
+            kwargs.get("control", self.control),
         )
 
 
diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py
index d21de826..e730a459 100644
--- a/api/onnx_web/server/params.py
+++ b/api/onnx_web/server/params.py
@@ -42,6 +42,7 @@ def pipeline_from_request(
                 device = platform
 
     # pipeline stuff
+    control = get_not_empty(request.args, "control", get_config_value("control"))
     lpw = get_not_empty(request.args, "lpw", "false") == "true"
     model = get_not_empty(request.args, "model", get_config_value("model"))
     model_path = get_model_path(server, model)
@@ -132,6 +133,7 @@ def pipeline_from_request(
         lpw=lpw,
         negative_prompt=negative_prompt,
         batch=batch,
+        control=control,
     )
     size = Size(width, height)
     return (device, params, size)
diff --git a/api/params.json b/api/params.json
index dc986242..7eaaaacb 100644
--- a/api/params.json
+++ b/api/params.json
@@ -18,6 +18,10 @@
     "max": 30,
     "step": 0.1
   },
+  "control": {
+    "default": "",
+    "keys": []
+  },
   "correction": {
     "default": "",
     "keys": []
diff --git a/gui/examples/config.json b/gui/examples/config.json
index c6d3e936..fc907d3d 100644
--- a/gui/examples/config.json
+++ b/gui/examples/config.json
@@ -22,6 +22,10 @@
     "max": 30,
     "step": 0.1
   },
+  "control": {
+    "default": "",
+    "keys": []
+  },
   "correction": {
     "default": "",
     "keys": []
diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts
index 5ddde731..56710758 100644
--- a/gui/src/client/api.ts
+++ b/gui/src/client/api.ts
@@ -32,6 +32,11 @@ export interface ModelParams {
    * Use the long prompt weighting pipeline.
    */
   lpw: boolean;
+
+  /**
+   * ControlNet to be used.
+   */
+  control: string;
 }
 
 /**
@@ -191,7 +196,7 @@ export interface ReadyResponse {
 
 export interface NetworkModel {
   name: string;
-  type: 'inversion' | 'lora';
+  type: 'control' | 'inversion' | 'lora';
   // TODO: add token
   // TODO: add layer/token count
 }
@@ -392,6 +397,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
   url.searchParams.append('upscaling', params.upscaling);
   url.searchParams.append('correction', params.correction);
   url.searchParams.append('lpw', String(params.lpw));
+  url.searchParams.append('control', params.control);
 }
 
 /**
diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx
index 70a4f9d5..891e7b95 100644
--- a/gui/src/components/control/ModelControl.tsx
+++ b/gui/src/components/control/ModelControl.tsx
@@ -116,6 +116,20 @@ export function ModelControl() {
           });
         }}
       />
+          result.networks.filter((network) => network.type === 'control').map((network) => network.name),
+        }}
+        onSelect={(control) => {
+          setModel({
+            control,
+          });
+        }}
+      />
 = (set) => ({
   model: {
+    control: server.control.default,
+    correction: server.correction.default,
+    lpw: false,
     model: server.model.default,
     platform: server.platform.default,
     upscaling: server.upscaling.default,
-    correction: server.correction.default,
-    lpw: false,
   },
   setModel(params) {
     set((prev) => ({
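
Taken together, these changes thread a new control parameter from the request handler through ImageParams and into load_pipeline. Below is a minimal sketch exercising the new ImageParams.control field; the model, scheduler, prompt, and the ControlNet name "canny" are placeholder assumptions, not values taken from the diff.

# sketch only: the argument values below are assumed for illustration
from onnx_web.params import ImageParams

params = ImageParams(
    "stable-diffusion-onnx-v1-5",   # assumed diffusion model name
    "ddim",                         # assumed scheduler name
    "an astronaut riding a horse",  # prompt
    6.0,                            # cfg
    25,                             # steps
    42,                             # seed
    negative_prompt=None,
    control="canny",                # assumed name of a converted ControlNet model
)

# the new field round-trips through tojson() and with_args()
print(params.tojson()["control"])              # "canny"
print(params.with_args(control=None).control)  # None

Downstream, run_img2img_pipeline forwards params.control to load_pipeline as the control keyword, where it is loaded as the pipeline's controlnet component with the device's ONNX Runtime provider and session options.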