1
0
Fork 0

fix(api): handle more SwinIR models

This commit is contained in:
Sean Sube 2023-06-16 21:53:49 -05:00
parent 12e489b761
commit 1506f51ff4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 3 additions and 1 deletions

View File

@ -86,8 +86,10 @@ def convert_upscaling_swinir(
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
else:
elif "params" in torch_model:
model.load_state_dict(torch_model["params"], strict=False)
else:
model.load_state_dict(torch_model, strict=False)
model.to(conversion.training_device).train(False)
model.eval()