fix(api): trim model names relative to model path
This commit is contained in:
parent
23a9d5afb7
commit
4472a6fd24
|
@ -203,17 +203,26 @@ def check_paths(context: ServerContext):
|
|||
makedirs(context.output_path)
|
||||
|
||||
|
||||
def get_model_name(model: str) -> str:
|
||||
base = path.basename(model)
|
||||
(file, _ext) = path.splitext(base)
|
||||
return file
|
||||
|
||||
|
||||
def load_models(context: ServerContext):
|
||||
global diffusion_models
|
||||
global correction_models
|
||||
global upscaling_models
|
||||
|
||||
diffusion_models = glob(path.join(context.model_path, 'diffusion-*'))
|
||||
diffusion_models.extend(glob(path.join(context.model_path, 'stable-diffusion-*')))
|
||||
|
||||
correction_models = glob(path.join(context.model_path, 'correction-*'))
|
||||
upscaling_models = glob(path.join(context.model_path, 'upscaling-*'))
|
||||
diffusion_models = [get_model_name(f) for f in glob(
|
||||
path.join(context.model_path, 'diffusion-*'))]
|
||||
diffusion_models.extend([
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))])
|
||||
|
||||
correction_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))]
|
||||
upscaling_models = [
|
||||
get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))]
|
||||
|
||||
|
||||
def load_params(context: ServerContext):
|
||||
|
|
Loading…
Reference in New Issue