
feat(api): add GFPGAN and Real ESRGAN to model cache

Author: Sean Sube, 2023-02-13 18:10:11 -06:00
parent e9472bc005
commit 0709c1dbf0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 35 additions and 45 deletions


@@ -13,24 +13,19 @@ from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

-last_pipeline_instance: Optional[GFPGANer] = None
-last_pipeline_params: Optional[str] = None
-

 def load_gfpgan(
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     upscale: UpscaleParams,
-    device: DeviceParams,
+    _device: DeviceParams,
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
-
-    if last_pipeline_instance is not None and face_path == last_pipeline_params:
+    cache_key = (face_path,)
+    cache_pipe = server.cache.get("gfpgan", cache_key)
+    if cache_pipe is not None:
         logger.info("reusing existing GFPGAN pipeline")
-        return last_pipeline_instance
+        return cache_pipe

     logger.debug("loading GFPGAN model from %s", face_path)

@@ -43,8 +38,7 @@ def load_gfpgan(
         upscale=upscale.face_outscale,
     )

-    last_pipeline_instance = gfpgan
-    last_pipeline_params = face_path
+    server.cache.set("gfpgan", cache_key, gfpgan)
     run_gc()

     return gfpgan
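Both hunks above follow the same shape: probe the shared cache, return early on a hit, otherwise build the pipeline and publish it with set. A minimal sketch of that flow, assuming only the cache.get/cache.set calls visible in the diff; load_cached and its loader argument are hypothetical helpers, not part of this commit:

    from typing import Any, Callable, Tuple


    def load_cached(server: Any, tag: str, cache_key: Tuple, loader: Callable[[], Any]) -> Any:
        # hit: reuse the pipeline an earlier request already built
        cached = server.cache.get(tag, cache_key)
        if cached is not None:
            return cached

        # miss: pay the load cost once, then share the result
        pipe = loader()
        server.cache.set(tag, cache_key, pipe)
        return pipe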


@@ -21,21 +21,20 @@ x4_v3_tag = "real-esrgan-x4-v3"

 def load_resrgan(
-    ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
+    server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     model_file = "%s.%s" % (params.upscale_model, params.format)
-    model_path = path.join(ctx.model_path, model_file)
+    model_path = path.join(server.model_path, model_file)
+
+    cache_key = (model_path, params.format)
+    cache_pipe = server.cache.get("resrgan", cache_key)
+    if cache_pipe is not None:
+        logger.info("reusing existing Real ESRGAN pipeline")
+        return cache_pipe
+
     if not path.isfile(model_path):
         raise Exception("Real ESRGAN model not found at %s" % model_path)

-    cache_params = (model_path, params.format)
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
-        logger.info("reusing existing Real ESRGAN pipeline")
-        return last_pipeline_instance
-
     if x4_v3_tag in model_file:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(

@@ -49,7 +48,7 @@ def load_resrgan(
     elif params.format == "onnx":
         # use ONNX acceleration, if available
         model = OnnxNet(
-            ctx, model_file, provider=device.provider, provider_options=device.options
+            server, model_file, provider=device.provider, provider_options=device.options
         )
     elif params.format == "pth":
         model = RRDBNet(

@@ -72,7 +71,7 @@ def load_resrgan(
     logger.debug("loading Real ESRGAN upscale model from %s", model_path)

     # TODO: shouldn't need the PTH file
-    model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
+    model_path_pth = path.join(server.model_path, "%s.pth" % params.upscale_model)
     upsampler = RealESRGANer(
         scale=params.scale,
         model_path=model_path_pth,

@@ -84,8 +83,7 @@ def load_resrgan(
         half=params.half,
     )

-    last_pipeline_instance = upsampler
-    last_pipeline_params = cache_params
+    server.cache.set("resrgan", cache_key, upsampler)
     run_gc()

     return upsampler
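Note that the Real ESRGAN key is the pair (model_path, params.format), so ONNX and PTH builds of the same model occupy separate cache slots. A hypothetical illustration of the keying; _StubCache is a stand-in with only the get/set shape used above, not the ModelCache from this commit, and the paths are made up:

    class _StubCache:
        # stand-in for server.cache, keyed on (tag, key) pairs
        def __init__(self):
            self._entries = {}

        def get(self, tag, key):
            return self._entries.get((tag, key))

        def set(self, tag, key, value):
            self._entries[(tag, key)] = value


    cache = _StubCache()

    # the format is part of the key, so both builds can be cached at once
    onnx_key = ("/models/real-esrgan-x4.onnx", "onnx")
    pth_key = ("/models/real-esrgan-x4.pth", "pth")
    cache.set("resrgan", onnx_key, "onnx upsampler")
    cache.set("resrgan", pth_key, "pth upsampler")
    assert cache.get("resrgan", onnx_key) == "onnx upsampler"
    assert cache.get("resrgan", pth_key) == "pth upsampler"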


@@ -15,22 +15,18 @@ from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

-last_pipeline_instance = None
-last_pipeline_params = (None, None)
-

 def load_stable_diffusion(
-    ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
+    server: ServerContext, upscale: UpscaleParams, device: DeviceParams
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
-    model_path = path.join(ctx.model_path, upscale.upscale_model)
+    model_path = path.join(server.model_path, upscale.upscale_model)

-    cache_params = (model_path, upscale.format)
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+    cache_key = (model_path, upscale.format)
+    cache_pipe = server.cache.get("diffusion", cache_key)
+    if cache_pipe is not None:
         logger.debug("reusing existing Stable Diffusion upscale pipeline")
-        return last_pipeline_instance
+        return cache_pipe

     if upscale.format == "onnx":
         logger.debug(

@@ -38,7 +34,7 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider, provider_options=device.options
         )
     else:

@@ -47,15 +43,14 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+        pipe = StableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider
         )

-    last_pipeline_instance = pipeline
-    last_pipeline_params = cache_params
+    server.cache.set("diffusion", cache_key, pipe)
     run_gc()

-    return pipeline
+    return pipe


 def upscale_stable_diffusion(
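Each loader writes under its own tag ("gfpgan", "resrgan", "diffusion"), so the tag acts as a namespace: entries from different loaders cannot collide even if their key tuples happened to be equal. A small illustration, reusing the hypothetical _StubCache sketched after the Real ESRGAN hunks:

    key = ("/models/some-model", "onnx")  # hypothetical key shared by two loaders

    cache.set("resrgan", key, "upsampler")
    cache.set("diffusion", key, "sd upscale pipeline")

    # same key, different tags: both entries survive and resolve independently
    assert cache.get("resrgan", key) == "upsampler"
    assert cache.get("diffusion", key) == "sd upscale pipeline"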


@@ -12,8 +12,9 @@ class ModelCache:
         self.limit = limit

     def drop(self, tag: str, key: Any) -> None:
-        self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
+        self.cache = [
+            model for model in self.cache if model[0] != tag and model[1] != key
+        ]

     def get(self, tag: str, key: Any) -> Any:
         for t, k, v in self.cache:

@@ -38,7 +39,9 @@ class ModelCache:
     def prune(self):
         total = len(self.cache)
         if total > self.limit:
-            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            logger.info(
+                "Removing models from cache, %s of %s", (total - self.limit), total
+            )
             self.cache[:] = self.cache[: self.limit]
         else:
             logger.debug("Model cache below limit, %s of %s", total, self.limit)
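For reference, a self-contained sketch of the cache these loaders now share. Only drop, the get loop, and prune appear in the hunks above; the constructor and set are assumptions filled in to make the sketch runnable:

    from logging import getLogger
    from typing import Any, List, Tuple

    logger = getLogger(__name__)


    class ModelCache:
        # entries are (tag, key, value) tuples in a plain list
        cache: List[Tuple[str, Any, Any]]

        def __init__(self, limit: int) -> None:
            self.cache = []
            self.limit = limit

        def drop(self, tag: str, key: Any) -> None:
            # as written in the hunk, an entry is kept only when BOTH fields
            # differ, so this drops anything matching the tag OR the key
            self.cache = [
                model for model in self.cache if model[0] != tag and model[1] != key
            ]

        def get(self, tag: str, key: Any) -> Any:
            for t, k, v in self.cache:
                if t == tag and k == key:
                    return v
            return None

        def set(self, tag: str, key: Any, value: Any) -> None:
            # assumed: replace any previous entry, then enforce the limit
            self.drop(tag, key)
            self.cache.append((tag, key, value))
            self.prune()

        def prune(self):
            total = len(self.cache)
            if total > self.limit:
                logger.info(
                    "Removing models from cache, %s of %s", (total - self.limit), total
                )
                # as committed, this keeps the first `limit` entries in list order
                self.cache[:] = self.cache[: self.limit]
            else:
                logger.debug("Model cache below limit, %s of %s", total, self.limit)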