diff --git a/api/onnx_web/models.py b/api/onnx_web/models.py new file mode 100644 index 00000000..dcd43c25 --- /dev/null +++ b/api/onnx_web/models.py @@ -0,0 +1,18 @@ +from typing import Literal + +NetworkType = Literal["inversion", "lora"] + + +class NetworkModel: + name: str + type: NetworkType + + def __init__(self, name: str, type: NetworkType) -> None: + self.name = name + self.type = type + + def tojson(self): + return { + "name": self.name, + "type": self.type, + } diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index b98d952c..d3978ea6 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -39,8 +39,8 @@ from .load import ( get_correction_models, get_diffusion_models, get_extra_strings, - get_inversion_models, get_mask_filters, + get_network_models, get_noise_sources, get_upscaling_models, ) @@ -111,7 +111,7 @@ def list_models(context: ServerContext): { "correction": get_correction_models(), "diffusion": get_diffusion_models(), - "inversion": get_inversion_models(), + "networks": get_network_models(), "upscaling": get_upscaling_models(), } ) diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 6a0ce060..4803ffa2 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -2,7 +2,7 @@ from functools import cmp_to_key from glob import glob from logging import getLogger from os import path -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch import yaml @@ -20,6 +20,7 @@ from ..image import ( # mask filters; noise sources noise_source_normal, noise_source_uniform, ) +from ..models import NetworkModel from ..params import DeviceParams from ..torch_before_ort import get_available_providers from ..utils import merge @@ -58,7 +59,7 @@ available_platforms: List[DeviceParams] = [] # loaded from model_path correction_models: List[str] = [] diffusion_models: List[str] = [] -inversion_models: List[str] = [] +network_models: List[NetworkModel] = [] upscaling_models: List[str] = [] # Loaded from extra_models @@ -81,8 +82,8 @@ def get_diffusion_models(): return diffusion_models -def get_inversion_models(): - return inversion_models +def get_network_models(): + return network_models def get_upscaling_models(): @@ -184,10 +185,12 @@ def load_extras(context: ServerContext): extra_strings = strings -def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]: +def list_model_globs( + context: ServerContext, globs: List[str], base_path: Optional[str] = None +) -> List[str]: models = [] for pattern in globs: - pattern_path = path.join(context.model_path, pattern) + pattern_path = path.join(base_path or context.model_path, pattern) logger.debug("loading models from %s", pattern_path) models.extend([get_model_name(f) for f in glob(pattern_path)]) @@ -200,9 +203,10 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]: def load_models(context: ServerContext) -> None: global correction_models global diffusion_models - global inversion_models + global network_models global upscaling_models + # main categories diffusion_models = list_model_globs( context, [ @@ -220,14 +224,6 @@ def load_models(context: ServerContext) -> None: ) logger.debug("loaded correction models from disk: %s", correction_models) - inversion_models = list_model_globs( - context, - [ - "inversion-*", - ], - ) - logger.debug("loaded inversion models from disk: %s", inversion_models) - upscaling_models = list_model_globs( context, [ @@ -236,6 +232,29 @@ def load_models(context: ServerContext) -> None: ) logger.debug("loaded upscaling models from disk: %s", upscaling_models) + # additional networks + inversion_models = list_model_globs( + context, + [ + "*", + ], + base_path=path.join(context.model_path, "inversion"), + ) + logger.debug("loaded Textual Inversion models from disk: %s", inversion_models) + network_models.extend( + [NetworkModel(model, "inversion") for model in inversion_models] + ) + + lora_models = list_model_globs( + context, + [ + "*", + ], + base_path=path.join(context.model_path, "lora"), + ) + logger.debug("loaded LoRA models from disk: %s", lora_models) + network_models.extend([NetworkModel(model, "lora") for model in lora_models]) + def load_params(context: ServerContext) -> None: global config_params diff --git a/api/params.json b/api/params.json index 957b4ed6..d2ecccdb 100644 --- a/api/params.json +++ b/api/params.json @@ -1,5 +1,5 @@ { - "version": "0.8.1", + "version": "0.9.0", "batch": { "default": 1, "min": 1, diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 087f54b6..c220c2db 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -181,13 +181,20 @@ export interface ReadyResponse { ready: boolean; } +export interface NetworkModel { + name: string; + type: 'inversion' | 'lora'; + // TODO: add token + // TODO: add layer/token count +} + /** * List of available models. */ export interface ModelsResponse { - diffusion: Array; correction: Array; - inversion: Array; + diffusion: Array; + networks: Array; upscaling: Array; } diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index a5e40a35..25a6e661 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -42,7 +42,7 @@ export function ModelControl() { /> - result.inversion, - }} - showEmpty={true} - value={params.inversion} - onChange={(inversion) => { - setModel({ - inversion, - }); - }} - /> } /> + result.networks.filter((network) => network.type === 'inversion').map((network) => network.name), + }} + value={params.correction} + onChange={(correction) => { + // noop + }} + /> + result.networks.filter((network) => network.type === 'lora').map((network) => network.name), + }} + value={params.correction} + onChange={(correction) => { + // noop + }} + /> ; } diff --git a/gui/src/config.ts b/gui/src/config.ts index cadd6910..9c26a754 100644 --- a/gui/src/config.ts +++ b/gui/src/config.ts @@ -78,7 +78,7 @@ export interface Config { } export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png'; -export const PARAM_VERSION = '>=0.4.0'; +export const PARAM_VERSION = '>=0.9.0'; export const STALE_TIME = 300_000; // 5 minutes export const POLL_TIME = 5_000; // 5 seconds