feat(api): add GFPGAN and Real ESRGAN to model cache
parent e9472bc005
commit 0709c1dbf0
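This change replaces the per-module `last_pipeline_instance` / `last_pipeline_params` globals with a shared, tagged `ModelCache` hung off the `ServerContext`. Every loader now follows the same lookup/load/store shape; a minimal sketch of that pattern (the `loader` callback is a placeholder for the per-model loading code, not part of this commit):

```python
from logging import getLogger

logger = getLogger(__name__)


def load_cached_pipeline(server, tag, cache_key, loader):
    # check the shared ModelCache first, under this loader's tag
    cache_pipe = server.cache.get(tag, cache_key)
    if cache_pipe is not None:
        logger.info("reusing existing %s pipeline", tag)
        return cache_pipe

    # cache miss: build the pipeline and store it for the next request
    pipe = loader()
    server.cache.set(tag, cache_key, pipe)
    return pipe
```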
@@ -13,24 +13,19 @@ from ..utils import ServerContext, run_gc
 logger = getLogger(__name__)
 
 
-last_pipeline_instance: Optional[GFPGANer] = None
-last_pipeline_params: Optional[str] = None
-
-
 def load_gfpgan(
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     upscale: UpscaleParams,
-    device: DeviceParams,
+    _device: DeviceParams,
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
+    cache_key = (face_path,)
+    cache_pipe = server.cache.get("gfpgan", cache_key)
 
-    if last_pipeline_instance is not None and face_path == last_pipeline_params:
+    if cache_pipe is not None:
         logger.info("reusing existing GFPGAN pipeline")
-        return last_pipeline_instance
+        return cache_pipe
 
     logger.debug("loading GFPGAN model from %s", face_path)
@@ -43,8 +38,7 @@ def load_gfpgan(
         upscale=upscale.face_outscale,
     )
 
-    last_pipeline_instance = gfpgan
-    last_pipeline_params = face_path
+    server.cache.set("gfpgan", cache_key, gfpgan)
     run_gc()
 
     return gfpgan
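Nothing in this commit calls `ModelCache.drop`, but given its signature in the cache hunks further down, evicting a stale GFPGAN pipeline would presumably look like this (a sketch; `face_path` as computed in `load_gfpgan` above):

```python
# rebuild the key exactly as load_gfpgan does: a 1-tuple of the model path
cache_key = (face_path,)

# evict the entry so the next load_gfpgan call reloads from disk
server.cache.drop("gfpgan", cache_key)
```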
@@ -21,21 +21,20 @@ x4_v3_tag = "real-esrgan-x4-v3"
 
 
 def load_resrgan(
-    ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
+    server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     model_file = "%s.%s" % (params.upscale_model, params.format)
-    model_path = path.join(ctx.model_path, model_file)
+    model_path = path.join(server.model_path, model_file)
 
+    cache_key = (model_path, params.format)
+    cache_pipe = server.cache.get("resrgan", cache_key)
+    if cache_pipe is not None:
+        logger.info("reusing existing Real ESRGAN pipeline")
+        return cache_pipe
+
     if not path.isfile(model_path):
         raise Exception("Real ESRGAN model not found at %s" % model_path)
 
-    cache_params = (model_path, params.format)
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
-        logger.info("reusing existing Real ESRGAN pipeline")
-        return last_pipeline_instance
-
     if x4_v3_tag in model_file:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(
@@ -49,7 +48,7 @@ def load_resrgan(
     elif params.format == "onnx":
         # use ONNX acceleration, if available
         model = OnnxNet(
-            ctx, model_file, provider=device.provider, provider_options=device.options
+            server, model_file, provider=device.provider, provider_options=device.options
         )
     elif params.format == "pth":
         model = RRDBNet(
@@ -72,7 +71,7 @@ def load_resrgan(
     logger.debug("loading Real ESRGAN upscale model from %s", model_path)
 
     # TODO: shouldn't need the PTH file
-    model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
+    model_path_pth = path.join(server.model_path, "%s.pth" % params.upscale_model)
     upsampler = RealESRGANer(
         scale=params.scale,
         model_path=model_path_pth,
@@ -84,8 +83,7 @@ def load_resrgan(
         half=params.half,
     )
 
-    last_pipeline_instance = upsampler
-    last_pipeline_params = cache_params
+    server.cache.set("resrgan", cache_key, upsampler)
     run_gc()
 
     return upsampler
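With the cache in place, repeated calls with the same model path and format should return the very same upsampler object; a quick usage sketch (assuming a `server` whose `cache` is a `ModelCache`, and `params`/`device` built elsewhere):

```python
# first call loads from disk and stores under the "resrgan" tag
upsampler = load_resrgan(server, params, device, tile=512)

# second call hits the cache and skips RealESRGANer construction entirely
again = load_resrgan(server, params, device, tile=512)
assert upsampler is again
```

Note that the cache lookup now runs before the `path.isfile` check, so a cached pipeline keeps working even if the model file is removed afterwards.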
@@ -15,22 +15,18 @@ from ..utils import ServerContext, run_gc
 logger = getLogger(__name__)
 
 
-last_pipeline_instance = None
-last_pipeline_params = (None, None)
-
-
 def load_stable_diffusion(
-    ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
+    server: ServerContext, upscale: UpscaleParams, device: DeviceParams
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
-    model_path = path.join(ctx.model_path, upscale.upscale_model)
-    cache_params = (model_path, upscale.format)
+    model_path = path.join(server.model_path, upscale.upscale_model)
+
+    cache_key = (model_path, upscale.format)
+    cache_pipe = server.cache.get("diffusion", cache_key)
 
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+    if cache_pipe is not None:
         logger.debug("reusing existing Stable Diffusion upscale pipeline")
-        return last_pipeline_instance
+        return cache_pipe
 
     if upscale.format == "onnx":
         logger.debug(
@@ -38,7 +34,7 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider, provider_options=device.options
         )
     else:
@@ -47,15 +43,14 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+        pipe = StableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider
         )
 
-    last_pipeline_instance = pipeline
-    last_pipeline_params = cache_params
+    server.cache.set("diffusion", cache_key, pipe)
     run_gc()
 
-    return pipeline
+    return pipe
 
 
 def upscale_stable_diffusion(
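All three loaders share one cache instance, namespaced by tag ("gfpgan", "resrgan", "diffusion"), and `for t, k, v in self.cache` in the hunks below implies entries are (tag, key, value) triples. A toy illustration of the cache contents after one load of each kind (paths are hypothetical; real values are pipeline objects, stubbed here as strings):

```python
cache = [
    ("gfpgan", ("/models/correction-model.pth",), "<GFPGANer>"),
    ("resrgan", ("/models/upscale-model.onnx", "onnx"), "<RealESRGANer>"),
    ("diffusion", ("/models/upscale-sd", "onnx"), "<OnnxStableDiffusionUpscalePipeline>"),
]
```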
@@ -12,8 +12,9 @@ class ModelCache:
         self.limit = limit
 
     def drop(self, tag: str, key: Any) -> None:
-        self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
+        self.cache = [
+            model for model in self.cache if model[0] != tag and model[1] != key
+        ]
 
     def get(self, tag: str, key: Any) -> Any:
         for t, k, v in self.cache:
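One thing to flag in `drop` (present both before and after the reflow): the comprehension keeps an entry only when its tag differs *and* its key differs, so `drop("gfpgan", some_key)` also removes every other "gfpgan" entry and every entry of any tag whose key happens to match. If the intent is to remove only the exact (tag, key) pair, the filter needs De Morgan's flip (a suggested fix, not part of this commit):

```python
# keep everything except the exact (tag, key) pair
self.cache[:] = [
    model for model in self.cache if not (model[0] == tag and model[1] == key)
]
```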
@@ -38,7 +39,9 @@ class ModelCache:
     def prune(self):
         total = len(self.cache)
         if total > self.limit:
-            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            logger.info(
+                "Removing models from cache, %s of %s", (total - self.limit), total
+            )
             self.cache[:] = self.cache[: self.limit]
         else:
             logger.debug("Model cache below limit, %s of %s", total, self.limit)
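Putting the visible fragments together, the whole class plausibly looks like the sketch below. `set` is never shown in this diff, so its body here is an assumption (replace, append, prune); `drop` uses the corrected filter from the note above:

```python
from logging import getLogger
from typing import Any, List, Tuple

logger = getLogger(__name__)


class ModelCache:
    cache: List[Tuple[str, Any, Any]]

    def __init__(self, limit: int) -> None:
        self.cache = []
        self.limit = limit

    def drop(self, tag: str, key: Any) -> None:
        # corrected filter: remove only the exact (tag, key) pair
        self.cache[:] = [
            model for model in self.cache if not (model[0] == tag and model[1] == key)
        ]

    def get(self, tag: str, key: Any) -> Any:
        # linear scan; fine for the handful of pipelines kept here
        for t, k, v in self.cache:
            if t == tag and k == key:
                return v
        return None

    def set(self, tag: str, key: Any, value: Any) -> None:
        # assumed implementation: replace any existing entry, then prune
        self.drop(tag, key)
        self.cache.append((tag, key, value))
        self.prune()

    def prune(self):
        total = len(self.cache)
        if total > self.limit:
            logger.info(
                "Removing models from cache, %s of %s", (total - self.limit), total
            )
            self.cache[:] = self.cache[: self.limit]
        else:
            logger.debug("Model cache below limit, %s of %s", total, self.limit)
```

Note that `prune` keeps the head of the list (`self.cache[: self.limit]`); if `set` appends, that retains the oldest entries and discards the newest, which may or may not be the intended eviction order.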