fix(api): trim model names relative to model path
This commit is contained in:
parent
23a9d5afb7
commit
4472a6fd24
|
@ -203,17 +203,26 @@ def check_paths(context: ServerContext):
|
||||||
makedirs(context.output_path)
|
makedirs(context.output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_name(model: str) -> str:
|
||||||
|
base = path.basename(model)
|
||||||
|
(file, _ext) = path.splitext(base)
|
||||||
|
return file
|
||||||
|
|
||||||
|
|
||||||
def load_models(context: ServerContext):
|
def load_models(context: ServerContext):
|
||||||
global diffusion_models
|
global diffusion_models
|
||||||
global correction_models
|
global correction_models
|
||||||
global upscaling_models
|
global upscaling_models
|
||||||
|
|
||||||
diffusion_models = glob(path.join(context.model_path, 'diffusion-*'))
|
diffusion_models = [get_model_name(f) for f in glob(
|
||||||
diffusion_models.extend(glob(path.join(context.model_path, 'stable-diffusion-*')))
|
path.join(context.model_path, 'diffusion-*'))]
|
||||||
|
diffusion_models.extend([
|
||||||
correction_models = glob(path.join(context.model_path, 'correction-*'))
|
get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))])
|
||||||
upscaling_models = glob(path.join(context.model_path, 'upscaling-*'))
|
|
||||||
|
|
||||||
|
correction_models = [
|
||||||
|
get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))]
|
||||||
|
upscaling_models = [
|
||||||
|
get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))]
|
||||||
|
|
||||||
|
|
||||||
def load_params(context: ServerContext):
|
def load_params(context: ServerContext):
|
||||||
|
|
Loading…
Reference in New Issue