1
0
Fork 0

lint recent changes

This commit is contained in:
Sean Sube 2023-02-06 17:26:51 -06:00
parent 833fc5c2f8
commit 651acf6991
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 14 additions and 5 deletions

View File

@ -1,10 +1,10 @@
from logging import getLogger from logging import getLogger
from typing import Optional
import numpy as np import numpy as np
import torch import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image from PIL import Image
from typing import Optional
from ..device_pool import JobContext from ..device_pool import JobContext
from ..diffusion.load import load_pipeline from ..diffusion.load import load_pipeline

View File

@ -1,10 +1,10 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Optional
import numpy as np import numpy as np
from gfpgan import GFPGANer from gfpgan import GFPGANer
from PIL import Image from PIL import Image
from typing import Optional
from ..device_pool import JobContext from ..device_pool import JobContext
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
@ -18,7 +18,12 @@ last_pipeline_instance: Optional[GFPGANer] = None
last_pipeline_params: Optional[str] = None last_pipeline_params: Optional[str] = None
def load_gfpgan(server: ServerContext, stage: StageParams, upscale: UpscaleParams, device: DeviceParams): def load_gfpgan(
server: ServerContext,
stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params

View File

@ -35,7 +35,9 @@ def run_upscale_correction(
if upscale.scale > 1: if upscale.scale > 1:
if "esrgan" in upscale.upscale_model: if "esrgan" in upscale.upscale_model:
resr_stage = StageParams(tile_size=stage.tile_size, outscale=upscale.outscale) resr_stage = StageParams(
tile_size=stage.tile_size, outscale=upscale.outscale
)
chain.append((upscale_resrgan, resr_stage, None)) chain.append((upscale_resrgan, resr_stage, None))
elif "stable-diffusion" in upscale.upscale_model: elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size) mini_tile = min(SizeChart.mini, stage.tile_size)
@ -45,7 +47,9 @@ def run_upscale_correction(
logger.warn("unknown upscaling model: %s", upscale.upscale_model) logger.warn("unknown upscaling model: %s", upscale.upscale_model)
if upscale.faces: if upscale.faces:
face_stage = StageParams(tile_size=stage.tile_size, outscale=upscale.face_outscale) face_stage = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model: if "codeformer" in upscale.correction_model:
chain.append((correct_codeformer, face_stage, None)) chain.append((correct_codeformer, face_stage, None))
elif "gfpgan" in upscale.correction_model: elif "gfpgan" in upscale.correction_model: