From 1506f51ff41d940dd38756e7cd89987aa801fbd8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 16 Jun 2023 21:53:49 -0500 Subject: [PATCH] fix(api): handle more SwinIR models --- api/onnx_web/convert/upscaling/swinir.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/upscaling/swinir.py b/api/onnx_web/convert/upscaling/swinir.py index fbeb0aef..36b6168c 100644 --- a/api/onnx_web/convert/upscaling/swinir.py +++ b/api/onnx_web/convert/upscaling/swinir.py @@ -86,8 +86,10 @@ def convert_upscaling_swinir( torch_model = torch.load(source, map_location=conversion.map_location) if "params_ema" in torch_model: model.load_state_dict(torch_model["params_ema"], strict=False) - else: + elif "params" in torch_model: model.load_state_dict(torch_model["params"], strict=False) + else: + model.load_state_dict(torch_model, strict=False) model.to(conversion.training_device).train(False) model.eval()