fix(api): include model scale
This commit is contained in:
parent
556d5b84d6
commit
dba6113c09
|
@ -16,28 +16,30 @@ sources: Dict[str, List[Tuple[str, str]]] = {
|
|||
'diffusers': [
|
||||
# v1.x
|
||||
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
|
||||
('stable-diffusion-onnx-v1-inpainting', 'runwayml/stable-diffusion-inpainting'),
|
||||
('stable-diffusion-onnx-v1-inpainting',
|
||||
'runwayml/stable-diffusion-inpainting'),
|
||||
# v2.x
|
||||
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
|
||||
('stable-diffusion-onnx-v2-inpainting', 'stabilityai/stable-diffusion-2-inpainting'),
|
||||
('stable-diffusion-onnx-v2-inpainting',
|
||||
'stabilityai/stable-diffusion-2-inpainting'),
|
||||
],
|
||||
'gfpgan': [
|
||||
('GFPGANv1.3', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
|
||||
('correction-gfpgan-v1-3',
|
||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
|
||||
],
|
||||
'real_esrgan': [
|
||||
('RealESRGAN_x4plus', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'),
|
||||
('upscaling-real-esrgan-x4-plus',
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
|
||||
],
|
||||
}
|
||||
|
||||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models'))
|
||||
|
||||
|
||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_real_esrgan(name: str, url: str, opset: int):
|
||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(model_path, name)
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
print('converting Real ESRGAN model: %s -> %s' % (name, dest_path))
|
||||
|
@ -53,7 +55,7 @@ def convert_real_esrgan(name: str, url: str, opset: int):
|
|||
|
||||
print('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=4)
|
||||
num_block=23, num_grow_ch=32, scale=scale)
|
||||
model.load_state_dict(torch.load(dest_path)['params_ema'])
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
|
Loading…
Reference in New Issue