fix(api): handle more SwinIR models
This commit is contained in:
parent
12e489b761
commit
1506f51ff4
|
@ -86,8 +86,10 @@ def convert_upscaling_swinir(
|
|||
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"], strict=False)
|
||||
else:
|
||||
elif "params" in torch_model:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch_model, strict=False)
|
||||
|
||||
model.to(conversion.training_device).train(False)
|
||||
model.eval()
|
||||
|
|
Loading…
Reference in New Issue