1
0
Fork 0

feat(api): add GFPGAN and Real ESRGAN to model cache

This commit is contained in:
Sean Sube 2023-02-13 18:10:11 -06:00
parent e9472bc005
commit 0709c1dbf0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 35 additions and 45 deletions

View File

@ -13,24 +13,19 @@ from ..utils import ServerContext, run_gc
logger = getLogger(__name__)
last_pipeline_instance: Optional[GFPGANer] = None
last_pipeline_params: Optional[str] = None
def load_gfpgan(
server: ServerContext,
stage: StageParams,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
_device: DeviceParams,
):
global last_pipeline_instance
global last_pipeline_params
face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
cache_key = (face_path,)
cache_pipe = server.cache.get("gfpgan", cache_key)
if last_pipeline_instance is not None and face_path == last_pipeline_params:
if cache_pipe is not None:
logger.info("reusing existing GFPGAN pipeline")
return last_pipeline_instance
return cache_pipe
logger.debug("loading GFPGAN model from %s", face_path)
@ -43,8 +38,7 @@ def load_gfpgan(
upscale=upscale.face_outscale,
)
last_pipeline_instance = gfpgan
last_pipeline_params = face_path
server.cache.set("gfpgan", cache_key, gfpgan)
run_gc()
return gfpgan

View File

@ -21,21 +21,20 @@ x4_v3_tag = "real-esrgan-x4-v3"
def load_resrgan(
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
global last_pipeline_instance
global last_pipeline_params
model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(ctx.model_path, model_file)
model_path = path.join(server.model_path, model_file)
cache_key = (model_path, params.format)
cache_pipe = server.cache.get("resrgan", cache_key)
if cache_pipe is not None:
logger.info("reusing existing Real ESRGAN pipeline")
return cache_pipe
if not path.isfile(model_path):
raise Exception("Real ESRGAN model not found at %s" % model_path)
cache_params = (model_path, params.format)
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
logger.info("reusing existing Real ESRGAN pipeline")
return last_pipeline_instance
if x4_v3_tag in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
@ -49,7 +48,7 @@ def load_resrgan(
elif params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet(
ctx, model_file, provider=device.provider, provider_options=device.options
server, model_file, provider=device.provider, provider_options=device.options
)
elif params.format == "pth":
model = RRDBNet(
@ -72,7 +71,7 @@ def load_resrgan(
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
# TODO: shouldn't need the PTH file
model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
model_path_pth = path.join(server.model_path, "%s.pth" % params.upscale_model)
upsampler = RealESRGANer(
scale=params.scale,
model_path=model_path_pth,
@ -84,8 +83,7 @@ def load_resrgan(
half=params.half,
)
last_pipeline_instance = upsampler
last_pipeline_params = cache_params
server.cache.set("resrgan", cache_key, upsampler)
run_gc()
return upsampler

View File

@ -15,22 +15,18 @@ from ..utils import ServerContext, run_gc
logger = getLogger(__name__)
last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_stable_diffusion(
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
server: ServerContext, upscale: UpscaleParams, device: DeviceParams
):
global last_pipeline_instance
global last_pipeline_params
model_path = path.join(server.model_path, upscale.upscale_model)
model_path = path.join(ctx.model_path, upscale.upscale_model)
cache_params = (model_path, upscale.format)
cache_key = (model_path, upscale.format)
cache_pipe = server.cache.get("diffusion", cache_key)
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
if cache_pipe is not None:
logger.debug("reusing existing Stable Diffusion upscale pipeline")
return last_pipeline_instance
return cache_pipe
if upscale.format == "onnx":
logger.debug(
@ -38,7 +34,7 @@ def load_stable_diffusion(
model_path,
device.provider,
)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider, provider_options=device.options
)
else:
@ -47,15 +43,14 @@ def load_stable_diffusion(
model_path,
device.provider,
)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
pipe = StableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider
)
last_pipeline_instance = pipeline
last_pipeline_params = cache_params
server.cache.set("diffusion", cache_key, pipe)
run_gc()
return pipeline
return pipe
def upscale_stable_diffusion(

View File

@ -12,8 +12,9 @@ class ModelCache:
self.limit = limit
def drop(self, tag: str, key: Any) -> None:
self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
self.cache = [
model for model in self.cache if model[0] != tag and model[1] != key
]
def get(self, tag: str, key: Any) -> Any:
for t, k, v in self.cache:
@ -38,7 +39,9 @@ class ModelCache:
def prune(self):
total = len(self.cache)
if total > self.limit:
logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
logger.info(
"Removing models from cache, %s of %s", (total - self.limit), total
)
self.cache[:] = self.cache[: self.limit]
else:
logger.debug("Model cache below limit, %s of %s", total, self.limit)