feat: show additional networks in client
This commit is contained in:
parent
e5862d178c
commit
2d112104fb
|
@ -0,0 +1,18 @@
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
NetworkType = Literal["inversion", "lora"]
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModel:
|
||||||
|
name: str
|
||||||
|
type: NetworkType
|
||||||
|
|
||||||
|
def __init__(self, name: str, type: NetworkType) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def tojson(self):
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"type": self.type,
|
||||||
|
}
|
|
@ -39,8 +39,8 @@ from .load import (
|
||||||
get_correction_models,
|
get_correction_models,
|
||||||
get_diffusion_models,
|
get_diffusion_models,
|
||||||
get_extra_strings,
|
get_extra_strings,
|
||||||
get_inversion_models,
|
|
||||||
get_mask_filters,
|
get_mask_filters,
|
||||||
|
get_network_models,
|
||||||
get_noise_sources,
|
get_noise_sources,
|
||||||
get_upscaling_models,
|
get_upscaling_models,
|
||||||
)
|
)
|
||||||
|
@ -111,7 +111,7 @@ def list_models(context: ServerContext):
|
||||||
{
|
{
|
||||||
"correction": get_correction_models(),
|
"correction": get_correction_models(),
|
||||||
"diffusion": get_diffusion_models(),
|
"diffusion": get_diffusion_models(),
|
||||||
"inversion": get_inversion_models(),
|
"networks": get_network_models(),
|
||||||
"upscaling": get_upscaling_models(),
|
"upscaling": get_upscaling_models(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,7 +2,7 @@ from functools import cmp_to_key
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -20,6 +20,7 @@ from ..image import ( # mask filters; noise sources
|
||||||
noise_source_normal,
|
noise_source_normal,
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
|
from ..models import NetworkModel
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..torch_before_ort import get_available_providers
|
from ..torch_before_ort import get_available_providers
|
||||||
from ..utils import merge
|
from ..utils import merge
|
||||||
|
@ -58,7 +59,7 @@ available_platforms: List[DeviceParams] = []
|
||||||
# loaded from model_path
|
# loaded from model_path
|
||||||
correction_models: List[str] = []
|
correction_models: List[str] = []
|
||||||
diffusion_models: List[str] = []
|
diffusion_models: List[str] = []
|
||||||
inversion_models: List[str] = []
|
network_models: List[NetworkModel] = []
|
||||||
upscaling_models: List[str] = []
|
upscaling_models: List[str] = []
|
||||||
|
|
||||||
# Loaded from extra_models
|
# Loaded from extra_models
|
||||||
|
@ -81,8 +82,8 @@ def get_diffusion_models():
|
||||||
return diffusion_models
|
return diffusion_models
|
||||||
|
|
||||||
|
|
||||||
def get_inversion_models():
|
def get_network_models():
|
||||||
return inversion_models
|
return network_models
|
||||||
|
|
||||||
|
|
||||||
def get_upscaling_models():
|
def get_upscaling_models():
|
||||||
|
@ -184,10 +185,12 @@ def load_extras(context: ServerContext):
|
||||||
extra_strings = strings
|
extra_strings = strings
|
||||||
|
|
||||||
|
|
||||||
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
def list_model_globs(
|
||||||
|
context: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
models = []
|
models = []
|
||||||
for pattern in globs:
|
for pattern in globs:
|
||||||
pattern_path = path.join(context.model_path, pattern)
|
pattern_path = path.join(base_path or context.model_path, pattern)
|
||||||
logger.debug("loading models from %s", pattern_path)
|
logger.debug("loading models from %s", pattern_path)
|
||||||
|
|
||||||
models.extend([get_model_name(f) for f in glob(pattern_path)])
|
models.extend([get_model_name(f) for f in glob(pattern_path)])
|
||||||
|
@ -200,9 +203,10 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
||||||
def load_models(context: ServerContext) -> None:
|
def load_models(context: ServerContext) -> None:
|
||||||
global correction_models
|
global correction_models
|
||||||
global diffusion_models
|
global diffusion_models
|
||||||
global inversion_models
|
global network_models
|
||||||
global upscaling_models
|
global upscaling_models
|
||||||
|
|
||||||
|
# main categories
|
||||||
diffusion_models = list_model_globs(
|
diffusion_models = list_model_globs(
|
||||||
context,
|
context,
|
||||||
[
|
[
|
||||||
|
@ -220,14 +224,6 @@ def load_models(context: ServerContext) -> None:
|
||||||
)
|
)
|
||||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||||
|
|
||||||
inversion_models = list_model_globs(
|
|
||||||
context,
|
|
||||||
[
|
|
||||||
"inversion-*",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
|
||||||
|
|
||||||
upscaling_models = list_model_globs(
|
upscaling_models = list_model_globs(
|
||||||
context,
|
context,
|
||||||
[
|
[
|
||||||
|
@ -236,6 +232,29 @@ def load_models(context: ServerContext) -> None:
|
||||||
)
|
)
|
||||||
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||||
|
|
||||||
|
# additional networks
|
||||||
|
inversion_models = list_model_globs(
|
||||||
|
context,
|
||||||
|
[
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
base_path=path.join(context.model_path, "inversion"),
|
||||||
|
)
|
||||||
|
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||||
|
network_models.extend(
|
||||||
|
[NetworkModel(model, "inversion") for model in inversion_models]
|
||||||
|
)
|
||||||
|
|
||||||
|
lora_models = list_model_globs(
|
||||||
|
context,
|
||||||
|
[
|
||||||
|
"*",
|
||||||
|
],
|
||||||
|
base_path=path.join(context.model_path, "lora"),
|
||||||
|
)
|
||||||
|
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||||
|
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
||||||
|
|
||||||
|
|
||||||
def load_params(context: ServerContext) -> None:
|
def load_params(context: ServerContext) -> None:
|
||||||
global config_params
|
global config_params
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
{
|
{
|
||||||
"version": "0.8.1",
|
"version": "0.9.0",
|
||||||
"batch": {
|
"batch": {
|
||||||
"default": 1,
|
"default": 1,
|
||||||
"min": 1,
|
"min": 1,
|
||||||
|
|
|
@ -181,13 +181,20 @@ export interface ReadyResponse {
|
||||||
ready: boolean;
|
ready: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface NetworkModel {
|
||||||
|
name: string;
|
||||||
|
type: 'inversion' | 'lora';
|
||||||
|
// TODO: add token
|
||||||
|
// TODO: add layer/token count
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List of available models.
|
* List of available models.
|
||||||
*/
|
*/
|
||||||
export interface ModelsResponse {
|
export interface ModelsResponse {
|
||||||
diffusion: Array<string>;
|
|
||||||
correction: Array<string>;
|
correction: Array<string>;
|
||||||
inversion: Array<string>;
|
diffusion: Array<string>;
|
||||||
|
networks: Array<NetworkModel>;
|
||||||
upscaling: Array<string>;
|
upscaling: Array<string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@ export function ModelControl() {
|
||||||
/>
|
/>
|
||||||
<QueryList
|
<QueryList
|
||||||
id='diffusion'
|
id='diffusion'
|
||||||
labelKey='model'
|
labelKey='diffusion'
|
||||||
name={t('modelType.diffusion')}
|
name={t('modelType.diffusion')}
|
||||||
query={{
|
query={{
|
||||||
result: models,
|
result: models,
|
||||||
|
@ -55,25 +55,9 @@ export function ModelControl() {
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<QueryList
|
|
||||||
id='inversion'
|
|
||||||
labelKey='model'
|
|
||||||
name={t('modelType.inversion')}
|
|
||||||
query={{
|
|
||||||
result: models,
|
|
||||||
selector: (result) => result.inversion,
|
|
||||||
}}
|
|
||||||
showEmpty={true}
|
|
||||||
value={params.inversion}
|
|
||||||
onChange={(inversion) => {
|
|
||||||
setModel({
|
|
||||||
inversion,
|
|
||||||
});
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<QueryList
|
<QueryList
|
||||||
id='upscaling'
|
id='upscaling'
|
||||||
labelKey='model'
|
labelKey='upscaling'
|
||||||
name={t('modelType.upscaling')}
|
name={t('modelType.upscaling')}
|
||||||
query={{
|
query={{
|
||||||
result: models,
|
result: models,
|
||||||
|
@ -88,7 +72,7 @@ export function ModelControl() {
|
||||||
/>
|
/>
|
||||||
<QueryList
|
<QueryList
|
||||||
id='correction'
|
id='correction'
|
||||||
labelKey='model'
|
labelKey='correction'
|
||||||
name={t('modelType.correction')}
|
name={t('modelType.correction')}
|
||||||
query={{
|
query={{
|
||||||
result: models,
|
result: models,
|
||||||
|
@ -113,5 +97,31 @@ export function ModelControl() {
|
||||||
}}
|
}}
|
||||||
/>}
|
/>}
|
||||||
/>
|
/>
|
||||||
|
<QueryList
|
||||||
|
id='inversion'
|
||||||
|
labelKey='inversion'
|
||||||
|
name={t('modelType.inversion')}
|
||||||
|
query={{
|
||||||
|
result: models,
|
||||||
|
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||||
|
}}
|
||||||
|
value={params.correction}
|
||||||
|
onChange={(correction) => {
|
||||||
|
// noop
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<QueryList
|
||||||
|
id='lora'
|
||||||
|
labelKey='lora'
|
||||||
|
name={t('modelType.lora')}
|
||||||
|
query={{
|
||||||
|
result: models,
|
||||||
|
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||||
|
}}
|
||||||
|
value={params.correction}
|
||||||
|
onChange={(correction) => {
|
||||||
|
// noop
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</Stack>;
|
</Stack>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,7 @@ export interface Config<T = ClientParams> {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
||||||
export const PARAM_VERSION = '>=0.4.0';
|
export const PARAM_VERSION = '>=0.9.0';
|
||||||
|
|
||||||
export const STALE_TIME = 300_000; // 5 minutes
|
export const STALE_TIME = 300_000; // 5 minutes
|
||||||
export const POLL_TIME = 5_000; // 5 seconds
|
export const POLL_TIME = 5_000; // 5 seconds
|
||||||
|
|
Loading…
Reference in New Issue