From 555ec5a644264ce220e756c7826248fb0bf04708 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 12 Apr 2023 20:03:00 -0500 Subject: [PATCH] turn controlnet into a select list and localize its label, fix name lookup on server --- api/onnx_web/diffusers/load.py | 4 +++- api/onnx_web/params.py | 2 +- gui/src/components/control/ModelControl.tsx | 5 +++-- gui/src/strings/de.ts | 1 + gui/src/strings/en.ts | 1 + gui/src/strings/es.ts | 1 + gui/src/strings/fr.ts | 1 + onnx-web.code-workspace | 1 + 8 files changed, 12 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index be6edf11..aec7f2fc 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -122,6 +122,7 @@ def load_pipeline( ): inversions = inversions or [] loras = loras or [] + control_key = control.name if control is not None else None torch_dtype = ( torch.float16 if "torch-fp16" in server.optimizations else torch.float32 @@ -133,6 +134,7 @@ def load_pipeline( device.device, device.provider, lpw, + control_key, inversions, loras, ) @@ -282,7 +284,7 @@ def load_pipeline( if control is not None: components["controlnet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( - path.join(server.model_path, "control", f"{control}.onnx"), + path.join(server.model_path, "control", f"{control.name}.onnx"), provider=device.ort_provider(), sess_options=device.sess_options(), ) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index a8d31571..6643d35c 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -210,7 +210,7 @@ class ImageParams: "lpw": self.lpw, "eta": self.eta, "batch": self.batch, - "control": self.control, + "control": self.control.name, } def with_args(self, **kwargs): diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index 891e7b95..1fa578dc 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -116,7 +116,7 @@ export function ModelControl() { }); }} /> - result.networks.filter((network) => network.type === 'control').map((network) => network.name), }} - onSelect={(control) => { + value={params.control} + onChange={(control) => { setModel({ control, }); diff --git a/gui/src/strings/de.ts b/gui/src/strings/de.ts index 0dcd148f..e361540b 100644 --- a/gui/src/strings/de.ts +++ b/gui/src/strings/de.ts @@ -70,6 +70,7 @@ export const I18N_STRINGS_DE = { 'none': 'Keiner', }, modelType: { + control: '', correction: 'Korrekturmodelle', diffusion: 'Diffusionsmodelle', inversion: '', diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index f21e9daf..c00584c8 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -111,6 +111,7 @@ export const I18N_STRINGS_EN = { 'diffusion-unstable-ink-dream-v6': 'Unstable Ink Dream v6', }, modelType: { + control: 'ControlNet', correction: 'Correction Model', diffusion: 'Diffusion Model', inversion: 'Textual Inversion', diff --git a/gui/src/strings/es.ts b/gui/src/strings/es.ts index 71ce3921..018b1b78 100644 --- a/gui/src/strings/es.ts +++ b/gui/src/strings/es.ts @@ -70,6 +70,7 @@ export const I18N_STRINGS_ES = { 'none': 'Ninguno', }, modelType: { + control: '', correction: 'Modelo de corrección', diffusion: 'Modelo de difusión', inversion: '', diff --git a/gui/src/strings/fr.ts b/gui/src/strings/fr.ts index 7fbd4e0c..dfc8071a 100644 --- a/gui/src/strings/fr.ts +++ b/gui/src/strings/fr.ts @@ -70,6 +70,7 @@ export const I18N_STRINGS_FR = { 'none': '', }, modelType: { + control: '', correction: 'modèle de correction', diffusion: 'modèle de diffusion', inversion: '', diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index 9cd22668..646eda0c 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -19,6 +19,7 @@ "ckpt", "codebook", "codeformer", + "controlnet", "CUDA", "ddim", "ddpm",