feat(api): add GFPGAN and Real ESRGAN to model cache
commit 0709c1dbf0
parent e9472bc005
@@ -13,24 +13,19 @@ from ..utils import ServerContext, run_gc
 
 logger = getLogger(__name__)
 
 
-last_pipeline_instance: Optional[GFPGANer] = None
-last_pipeline_params: Optional[str] = None
-
-
 def load_gfpgan(
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     upscale: UpscaleParams,
-    device: DeviceParams,
+    _device: DeviceParams,
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
+    cache_key = (face_path,)
+    cache_pipe = server.cache.get("gfpgan", cache_key)
 
-    if last_pipeline_instance is not None and face_path == last_pipeline_params:
+    if cache_pipe is not None:
         logger.info("reusing existing GFPGAN pipeline")
-        return last_pipeline_instance
+        return cache_pipe
 
     logger.debug("loading GFPGAN model from %s", face_path)
 
@@ -43,8 +38,7 @@ def load_gfpgan(
         upscale=upscale.face_outscale,
     )
 
-    last_pipeline_instance = gfpgan
-    last_pipeline_params = face_path
+    server.cache.set("gfpgan", cache_key, gfpgan)
     run_gc()
 
     return gfpgan
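Taken together, the hunks above replace the module-level last_pipeline_instance / last_pipeline_params globals with the shared server.cache. A minimal sketch of the resulting load_gfpgan flow, assembled from the diff; the GFPGANer constructor arguments are elided because they fall outside these hunks:

def load_gfpgan(
    server: ServerContext,
    _stage: StageParams,
    upscale: UpscaleParams,
    _device: DeviceParams,
):
    face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
    cache_key = (face_path,)
    cache_pipe = server.cache.get("gfpgan", cache_key)
    if cache_pipe is not None:
        logger.info("reusing existing GFPGAN pipeline")
        return cache_pipe

    logger.debug("loading GFPGAN model from %s", face_path)
    gfpgan = GFPGANer(
        # ... constructor arguments elided, outside this diff ...
        upscale=upscale.face_outscale,
    )
    server.cache.set("gfpgan", cache_key, gfpgan)
    run_gc()
    return gfpgan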
@@ -21,21 +21,20 @@ x4_v3_tag = "real-esrgan-x4-v3"
 
 
 def load_resrgan(
-    ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
+    server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
     model_file = "%s.%s" % (params.upscale_model, params.format)
-    model_path = path.join(ctx.model_path, model_file)
+    model_path = path.join(server.model_path, model_file)
+
+    cache_key = (model_path, params.format)
+    cache_pipe = server.cache.get("resrgan", cache_key)
+    if cache_pipe is not None:
+        logger.info("reusing existing Real ESRGAN pipeline")
+        return cache_pipe
 
     if not path.isfile(model_path):
         raise Exception("Real ESRGAN model not found at %s" % model_path)
 
-    cache_params = (model_path, params.format)
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
-        logger.info("reusing existing Real ESRGAN pipeline")
-        return last_pipeline_instance
-
     if x4_v3_tag in model_file:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(
@@ -49,7 +48,7 @@ def load_resrgan(
     elif params.format == "onnx":
         # use ONNX acceleration, if available
         model = OnnxNet(
-            ctx, model_file, provider=device.provider, provider_options=device.options
+            server, model_file, provider=device.provider, provider_options=device.options
         )
     elif params.format == "pth":
         model = RRDBNet(
@@ -72,7 +71,7 @@ def load_resrgan(
     logger.debug("loading Real ESRGAN upscale model from %s", model_path)
 
     # TODO: shouldn't need the PTH file
-    model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
+    model_path_pth = path.join(server.model_path, "%s.pth" % params.upscale_model)
     upsampler = RealESRGANer(
         scale=params.scale,
         model_path=model_path_pth,
@@ -84,8 +83,7 @@ def load_resrgan(
         half=params.half,
    )
 
-    last_pipeline_instance = upsampler
-    last_pipeline_params = cache_params
+    server.cache.set("resrgan", cache_key, upsampler)
     run_gc()
 
     return upsampler
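Cache entries are addressed by a string tag plus a key value, so a caller that knows the key can evict a specific pipeline. A hypothetical invalidation call, reusing the same key tuple built in load_resrgan above:

# hypothetical: evict a cached Real ESRGAN pipeline, e.g. after its model file is replaced
server.cache.drop("resrgan", (model_path, params.format))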
@@ -15,22 +15,18 @@ from ..utils import ServerContext, run_gc
 
 logger = getLogger(__name__)
 
 
-last_pipeline_instance = None
-last_pipeline_params = (None, None)
-
-
 def load_stable_diffusion(
-    ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams
+    server: ServerContext, upscale: UpscaleParams, device: DeviceParams
 ):
-    global last_pipeline_instance
-    global last_pipeline_params
-
-    model_path = path.join(ctx.model_path, upscale.upscale_model)
-    cache_params = (model_path, upscale.format)
+    model_path = path.join(server.model_path, upscale.upscale_model)
+    cache_key = (model_path, upscale.format)
+    cache_pipe = server.cache.get("diffusion", cache_key)
 
-    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+    if cache_pipe is not None:
         logger.debug("reusing existing Stable Diffusion upscale pipeline")
-        return last_pipeline_instance
+        return cache_pipe
 
     if upscale.format == "onnx":
         logger.debug(
@@ -38,7 +34,7 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider, provider_options=device.options
         )
     else:
@@ -47,15 +43,14 @@ def load_stable_diffusion(
             model_path,
             device.provider,
         )
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+        pipe = StableDiffusionUpscalePipeline.from_pretrained(
             model_path, provider=device.provider
         )
 
-    last_pipeline_instance = pipeline
-    last_pipeline_params = cache_params
+    server.cache.set("diffusion", cache_key, pipe)
     run_gc()
 
-    return pipeline
+    return pipe
 
 
 def upscale_stable_diffusion(
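Because all three loaders now consult the same server-level cache, a second call with the same model path and format should return the identical pipeline object instead of reloading it. A small sketch of the expected behavior (variable names assumed for illustration):

pipe_a = load_stable_diffusion(server, upscale, device)
pipe_b = load_stable_diffusion(server, upscale, device)
# the second call logs "reusing existing Stable Diffusion upscale pipeline"
# and skips from_pretrained entirely
assert pipe_a is pipe_b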
@@ -12,8 +12,9 @@ class ModelCache:
         self.limit = limit
 
     def drop(self, tag: str, key: Any) -> None:
-        self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
+        self.cache = [
+            model for model in self.cache if model[0] != tag and model[1] != key
+        ]
 
     def get(self, tag: str, key: Any) -> Any:
         for t, k, v in self.cache:
@@ -38,7 +39,9 @@ class ModelCache:
     def prune(self):
         total = len(self.cache)
         if total > self.limit:
-            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            logger.info(
+                "Removing models from cache, %s of %s", (total - self.limit), total
+            )
             self.cache[:] = self.cache[: self.limit]
         else:
             logger.debug("Model cache below limit, %s of %s", total, self.limit)
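For context, ModelCache stores (tag, key, value) tuples in a bounded list. A minimal sketch reconstructed from the methods visible in this diff; the set method and the body of get fall outside the hunks, so those parts are assumptions:

from logging import getLogger
from typing import Any, List, Tuple

logger = getLogger(__name__)


class ModelCache:
    cache: List[Tuple[str, Any, Any]]

    def __init__(self, limit: int) -> None:
        self.cache = []
        self.limit = limit

    def drop(self, tag: str, key: Any) -> None:
        # note: as written in the diff, an entry survives only if BOTH its
        # tag and key differ from the arguments
        self.cache = [
            model for model in self.cache if model[0] != tag and model[1] != key
        ]

    def get(self, tag: str, key: Any) -> Any:
        # the loop header is shown in the diff; the match-and-return body is assumed
        for t, k, v in self.cache:
            if t == tag and k == key:
                return v
        return None

    def set(self, tag: str, key: Any, value: Any) -> None:
        # assumed implementation: replace any existing entry, then prune to the limit
        self.drop(tag, key)
        self.cache.append((tag, key, value))
        self.prune()

    def prune(self):
        total = len(self.cache)
        if total > self.limit:
            logger.info(
                "Removing models from cache, %s of %s", (total - self.limit), total
            )
            self.cache[:] = self.cache[: self.limit]
        else:
            logger.debug("Model cache below limit, %s of %s", total, self.limit)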