feat: show additional networks in client
This commit is contained in:
parent
e5862d178c
commit
2d112104fb
|
@ -0,0 +1,18 @@
|
|||
from typing import Literal

# The kinds of additional networks the server can expose to the client.
NetworkType = Literal["inversion", "lora"]


class NetworkModel:
    """A named additional-network model (Textual Inversion or LoRA).

    Carries the model's display name and its network type, and can
    serialize itself to the dict shape the client API expects.
    """

    # display name of the model on disk
    name: str
    # which kind of network this is
    type: NetworkType

    def __init__(self, name: str, type: NetworkType) -> None:
        self.name = name
        self.type = type

    def tojson(self):
        """Return a JSON-serializable dict with the model's name and type."""
        return dict(name=self.name, type=self.type)
|
|
@ -39,8 +39,8 @@ from .load import (
|
|||
get_correction_models,
|
||||
get_diffusion_models,
|
||||
get_extra_strings,
|
||||
get_inversion_models,
|
||||
get_mask_filters,
|
||||
get_network_models,
|
||||
get_noise_sources,
|
||||
get_upscaling_models,
|
||||
)
|
||||
|
@ -111,7 +111,7 @@ def list_models(context: ServerContext):
|
|||
{
|
||||
"correction": get_correction_models(),
|
||||
"diffusion": get_diffusion_models(),
|
||||
"inversion": get_inversion_models(),
|
||||
"networks": get_network_models(),
|
||||
"upscaling": get_upscaling_models(),
|
||||
}
|
||||
)
|
||||
|
|
|
@ -2,7 +2,7 @@ from functools import cmp_to_key
|
|||
from glob import glob
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
@ -20,6 +20,7 @@ from ..image import ( # mask filters; noise sources
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from ..models import NetworkModel
|
||||
from ..params import DeviceParams
|
||||
from ..torch_before_ort import get_available_providers
|
||||
from ..utils import merge
|
||||
|
@ -58,7 +59,7 @@ available_platforms: List[DeviceParams] = []
|
|||
# loaded from model_path
|
||||
correction_models: List[str] = []
|
||||
diffusion_models: List[str] = []
|
||||
inversion_models: List[str] = []
|
||||
network_models: List[NetworkModel] = []
|
||||
upscaling_models: List[str] = []
|
||||
|
||||
# Loaded from extra_models
|
||||
|
@ -81,8 +82,8 @@ def get_diffusion_models():
|
|||
return diffusion_models
|
||||
|
||||
|
||||
def get_inversion_models():
    """Return the Textual Inversion model names loaded from the model path."""
    return inversion_models
|
||||
def get_network_models():
    """Return the additional-network models (inversion and LoRA) loaded from disk."""
    return network_models
|
||||
|
||||
|
||||
def get_upscaling_models():
|
||||
|
@ -184,10 +185,12 @@ def load_extras(context: ServerContext):
|
|||
extra_strings = strings
|
||||
|
||||
|
||||
def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
||||
def list_model_globs(
|
||||
context: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||
) -> List[str]:
|
||||
models = []
|
||||
for pattern in globs:
|
||||
pattern_path = path.join(context.model_path, pattern)
|
||||
pattern_path = path.join(base_path or context.model_path, pattern)
|
||||
logger.debug("loading models from %s", pattern_path)
|
||||
|
||||
models.extend([get_model_name(f) for f in glob(pattern_path)])
|
||||
|
@ -200,9 +203,10 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
|||
def load_models(context: ServerContext) -> None:
|
||||
global correction_models
|
||||
global diffusion_models
|
||||
global inversion_models
|
||||
global network_models
|
||||
global upscaling_models
|
||||
|
||||
# main categories
|
||||
diffusion_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
|
@ -220,14 +224,6 @@ def load_models(context: ServerContext) -> None:
|
|||
)
|
||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||
|
||||
inversion_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"inversion-*",
|
||||
],
|
||||
)
|
||||
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
||||
|
||||
upscaling_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
|
@ -236,6 +232,29 @@ def load_models(context: ServerContext) -> None:
|
|||
)
|
||||
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||
|
||||
# additional networks
|
||||
inversion_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"*",
|
||||
],
|
||||
base_path=path.join(context.model_path, "inversion"),
|
||||
)
|
||||
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||
network_models.extend(
|
||||
[NetworkModel(model, "inversion") for model in inversion_models]
|
||||
)
|
||||
|
||||
lora_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"*",
|
||||
],
|
||||
base_path=path.join(context.model_path, "lora"),
|
||||
)
|
||||
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
||||
|
||||
|
||||
def load_params(context: ServerContext) -> None:
|
||||
global config_params
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"version": "0.8.1",
|
||||
"version": "0.9.0",
|
||||
"batch": {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
|
|
|
@ -181,13 +181,20 @@ export interface ReadyResponse {
|
|||
ready: boolean;
|
||||
}
|
||||
|
||||
/**
 * An additional network model reported by the server: a name plus its
 * network type (Textual Inversion or LoRA).
 */
export interface NetworkModel {
  name: string;
  type: 'inversion' | 'lora';
  // TODO: add token
  // TODO: add layer/token count
}
|
||||
|
||||
/**
|
||||
* List of available models.
|
||||
*/
|
||||
export interface ModelsResponse {
|
||||
diffusion: Array<string>;
|
||||
correction: Array<string>;
|
||||
inversion: Array<string>;
|
||||
diffusion: Array<string>;
|
||||
networks: Array<NetworkModel>;
|
||||
upscaling: Array<string>;
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ export function ModelControl() {
|
|||
/>
|
||||
<QueryList
|
||||
id='diffusion'
|
||||
labelKey='model'
|
||||
labelKey='diffusion'
|
||||
name={t('modelType.diffusion')}
|
||||
query={{
|
||||
result: models,
|
||||
|
@ -55,25 +55,9 @@ export function ModelControl() {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='inversion'
|
||||
labelKey='model'
|
||||
name={t('modelType.inversion')}
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.inversion,
|
||||
}}
|
||||
showEmpty={true}
|
||||
value={params.inversion}
|
||||
onChange={(inversion) => {
|
||||
setModel({
|
||||
inversion,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='upscaling'
|
||||
labelKey='model'
|
||||
labelKey='upscaling'
|
||||
name={t('modelType.upscaling')}
|
||||
query={{
|
||||
result: models,
|
||||
|
@ -88,7 +72,7 @@ export function ModelControl() {
|
|||
/>
|
||||
<QueryList
|
||||
id='correction'
|
||||
labelKey='model'
|
||||
labelKey='correction'
|
||||
name={t('modelType.correction')}
|
||||
query={{
|
||||
result: models,
|
||||
|
@ -113,5 +97,31 @@ export function ModelControl() {
|
|||
}}
|
||||
/>}
|
||||
/>
|
||||
<QueryList
|
||||
id='inversion'
|
||||
labelKey='inversion'
|
||||
name={t('modelType.inversion')}
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
|
||||
}}
|
||||
value={params.correction}
|
||||
onChange={(correction) => {
|
||||
// noop
|
||||
}}
|
||||
/>
|
||||
<QueryList
|
||||
id='lora'
|
||||
labelKey='lora'
|
||||
name={t('modelType.lora')}
|
||||
query={{
|
||||
result: models,
|
||||
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
|
||||
}}
|
||||
value={params.correction}
|
||||
onChange={(correction) => {
|
||||
// noop
|
||||
}}
|
||||
/>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ export interface Config<T = ClientParams> {
|
|||
}
|
||||
|
||||
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
||||
export const PARAM_VERSION = '>=0.4.0';
|
||||
export const PARAM_VERSION = '>=0.9.0';
|
||||
|
||||
export const STALE_TIME = 300_000; // 5 minutes
|
||||
export const POLL_TIME = 5_000; // 5 seconds
|
||||
|
|
Loading…
Reference in New Issue