1
0
Fork 0

fix(api): include model scale

This commit is contained in:
Sean Sube 2023-01-16 20:10:29 -06:00
parent 556d5b84d6
commit dba6113c09
1 changed files with 10 additions and 8 deletions

View File

@ -16,28 +16,30 @@ sources: Dict[str, List[Tuple[str, str]]] = {
'diffusers': [
# v1.x
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
('stable-diffusion-onnx-v1-inpainting', 'runwayml/stable-diffusion-inpainting'),
('stable-diffusion-onnx-v1-inpainting',
'runwayml/stable-diffusion-inpainting'),
# v2.x
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
('stable-diffusion-onnx-v2-inpainting', 'stabilityai/stable-diffusion-2-inpainting'),
('stable-diffusion-onnx-v2-inpainting',
'stabilityai/stable-diffusion-2-inpainting'),
],
'gfpgan': [
('GFPGANv1.3', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
('correction-gfpgan-v1-3',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
],
'real_esrgan': [
('RealESRGAN_x4plus', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'),
('upscaling-real-esrgan-x4-plus',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
],
}
model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
@torch.no_grad()
def convert_real_esrgan(name: str, url: str, opset: int):
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name)
dest_onnx = path.join(model_path, name + '.onnx')
print('converting Real ESRGAN model: %s -> %s' % (name, dest_path))
@ -53,7 +55,7 @@ def convert_real_esrgan(name: str, url: str, opset: int):
print('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)
num_block=23, num_grow_ch=32, scale=scale)
model.load_state_dict(torch.load(dest_path)['params_ema'])
model.to(training_device).train(False)
model.eval()