1
0
Fork 0

feat: show additional networks in client

This commit is contained in:
Sean Sube 2023-03-18 19:14:24 -05:00
parent e5862d178c
commit 2d112104fb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 94 additions and 40 deletions

18
api/onnx_web/models.py Normal file
View File

@ -0,0 +1,18 @@
from typing import Literal

# The kinds of additional networks the server can expose to clients.
NetworkType = Literal["inversion", "lora"]


class NetworkModel:
    """A named additional network (Textual Inversion or LoRA) available on the server."""

    # display/lookup name of the network
    name: str
    # which kind of network this is
    type: NetworkType

    def __init__(self, name: str, type: NetworkType) -> None:
        self.name = name
        self.type = type

    def tojson(self):
        """Return a JSON-serializable dict with the network's name and type."""
        return dict(name=self.name, type=self.type)

View File

@ -39,8 +39,8 @@ from .load import (
get_correction_models, get_correction_models,
get_diffusion_models, get_diffusion_models,
get_extra_strings, get_extra_strings,
get_inversion_models,
get_mask_filters, get_mask_filters,
get_network_models,
get_noise_sources, get_noise_sources,
get_upscaling_models, get_upscaling_models,
) )
@ -111,7 +111,7 @@ def list_models(context: ServerContext):
{ {
"correction": get_correction_models(), "correction": get_correction_models(),
"diffusion": get_diffusion_models(), "diffusion": get_diffusion_models(),
"inversion": get_inversion_models(), "networks": get_network_models(),
"upscaling": get_upscaling_models(), "upscaling": get_upscaling_models(),
} }
) )

View File

@ -2,7 +2,7 @@ from functools import cmp_to_key
from glob import glob from glob import glob
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
import torch import torch
import yaml import yaml
@ -20,6 +20,7 @@ from ..image import ( # mask filters; noise sources
noise_source_normal, noise_source_normal,
noise_source_uniform, noise_source_uniform,
) )
from ..models import NetworkModel
from ..params import DeviceParams from ..params import DeviceParams
from ..torch_before_ort import get_available_providers from ..torch_before_ort import get_available_providers
from ..utils import merge from ..utils import merge
@ -58,7 +59,7 @@ available_platforms: List[DeviceParams] = []
# loaded from model_path # loaded from model_path
correction_models: List[str] = [] correction_models: List[str] = []
diffusion_models: List[str] = [] diffusion_models: List[str] = []
inversion_models: List[str] = [] network_models: List[NetworkModel] = []
upscaling_models: List[str] = [] upscaling_models: List[str] = []
# Loaded from extra_models # Loaded from extra_models
@ -81,8 +82,8 @@ def get_diffusion_models():
return diffusion_models return diffusion_models
def get_inversion_models(): def get_network_models():
return inversion_models return network_models
def get_upscaling_models(): def get_upscaling_models():
@ -184,10 +185,12 @@ def load_extras(context: ServerContext):
extra_strings = strings extra_strings = strings
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]: def list_model_globs(
context: ServerContext, globs: List[str], base_path: Optional[str] = None
) -> List[str]:
models = [] models = []
for pattern in globs: for pattern in globs:
pattern_path = path.join(context.model_path, pattern) pattern_path = path.join(base_path or context.model_path, pattern)
logger.debug("loading models from %s", pattern_path) logger.debug("loading models from %s", pattern_path)
models.extend([get_model_name(f) for f in glob(pattern_path)]) models.extend([get_model_name(f) for f in glob(pattern_path)])
@ -200,9 +203,10 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
def load_models(context: ServerContext) -> None: def load_models(context: ServerContext) -> None:
global correction_models global correction_models
global diffusion_models global diffusion_models
global inversion_models global network_models
global upscaling_models global upscaling_models
# main categories
diffusion_models = list_model_globs( diffusion_models = list_model_globs(
context, context,
[ [
@ -220,14 +224,6 @@ def load_models(context: ServerContext) -> None:
) )
logger.debug("loaded correction models from disk: %s", correction_models) logger.debug("loaded correction models from disk: %s", correction_models)
inversion_models = list_model_globs(
context,
[
"inversion-*",
],
)
logger.debug("loaded inversion models from disk: %s", inversion_models)
upscaling_models = list_model_globs( upscaling_models = list_model_globs(
context, context,
[ [
@ -236,6 +232,29 @@ def load_models(context: ServerContext) -> None:
) )
logger.debug("loaded upscaling models from disk: %s", upscaling_models) logger.debug("loaded upscaling models from disk: %s", upscaling_models)
# additional networks
inversion_models = list_model_globs(
context,
[
"*",
],
base_path=path.join(context.model_path, "inversion"),
)
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend(
[NetworkModel(model, "inversion") for model in inversion_models]
)
lora_models = list_model_globs(
context,
[
"*",
],
base_path=path.join(context.model_path, "lora"),
)
logger.debug("loaded LoRA models from disk: %s", lora_models)
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
def load_params(context: ServerContext) -> None: def load_params(context: ServerContext) -> None:
global config_params global config_params

View File

@ -1,5 +1,5 @@
{ {
"version": "0.8.1", "version": "0.9.0",
"batch": { "batch": {
"default": 1, "default": 1,
"min": 1, "min": 1,

View File

@ -181,13 +181,20 @@ export interface ReadyResponse {
ready: boolean; ready: boolean;
} }
/**
 * An additional network (Textual Inversion or LoRA) reported by the server's
 * models endpoint, alongside the main diffusion/correction/upscaling models.
 */
export interface NetworkModel {
name: string;
type: 'inversion' | 'lora';
// TODO: add token
// TODO: add layer/token count
}
/** /**
* List of available models. * List of available models.
*/ */
export interface ModelsResponse { export interface ModelsResponse {
diffusion: Array<string>;
correction: Array<string>; correction: Array<string>;
inversion: Array<string>; diffusion: Array<string>;
networks: Array<NetworkModel>;
upscaling: Array<string>; upscaling: Array<string>;
} }

View File

@ -42,7 +42,7 @@ export function ModelControl() {
/> />
<QueryList <QueryList
id='diffusion' id='diffusion'
labelKey='model' labelKey='diffusion'
name={t('modelType.diffusion')} name={t('modelType.diffusion')}
query={{ query={{
result: models, result: models,
@ -55,25 +55,9 @@ export function ModelControl() {
}); });
}} }}
/> />
<QueryList
id='inversion'
labelKey='model'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.inversion,
}}
showEmpty={true}
value={params.inversion}
onChange={(inversion) => {
setModel({
inversion,
});
}}
/>
<QueryList <QueryList
id='upscaling' id='upscaling'
labelKey='model' labelKey='upscaling'
name={t('modelType.upscaling')} name={t('modelType.upscaling')}
query={{ query={{
result: models, result: models,
@ -88,7 +72,7 @@ export function ModelControl() {
/> />
<QueryList <QueryList
id='correction' id='correction'
labelKey='model' labelKey='correction'
name={t('modelType.correction')} name={t('modelType.correction')}
query={{ query={{
result: models, result: models,
@ -113,5 +97,31 @@ export function ModelControl() {
}} }}
/>} />}
/> />
<QueryList
id='inversion'
labelKey='inversion'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
value={params.correction}
onChange={(correction) => {
// noop
}}
/>
<QueryList
id='lora'
labelKey='lora'
name={t('modelType.lora')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
value={params.correction}
onChange={(correction) => {
// noop
}}
/>
</Stack>; </Stack>;
} }

View File

@ -78,7 +78,7 @@ export interface Config<T = ClientParams> {
} }
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png'; export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const PARAM_VERSION = '>=0.4.0'; export const PARAM_VERSION = '>=0.9.0';
export const STALE_TIME = 300_000; // 5 minutes export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds export const POLL_TIME = 5_000; // 5 seconds