feat(api): add model cache for diffusion models
commit e9472bc005
parent 7fa1783be4
@@ -19,7 +19,7 @@ from .image import (
     noise_source_uniform,
 )
 from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
-from .upscale import run_upscale_correction
+from .server.upscale import run_upscale_correction
 from .utils import (
     ServerContext,
     base_join,
@@ -5,9 +5,9 @@ from typing import Any, List, Optional, Protocol, Tuple

 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..output import save_image
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug
 from .utils import process_tile_order
@@ -6,9 +6,9 @@ import torch
 from diffusers import OnnxStableDiffusionImg2ImgPipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import load_pipeline
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -16,7 +16,7 @@ logger = getLogger(__name__)

 def blend_img2img(
     job: JobContext,
-    _server: ServerContext,
+    server: ServerContext,
     _stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
@@ -30,6 +30,7 @@ def blend_img2img(
     logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)

     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionImg2ImgPipeline,
         params.model,
         params.scheduler,
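Each stage that builds a diffusion pipeline now passes the ServerContext as the first argument to load_pipeline, which is where the new model cache lives. A minimal sketch of the updated call shape, with the trailing arguments elided as they are in the hunk above:

    pipe = load_pipeline(
        server,  # new first argument, carries server.cache
        OnnxStableDiffusionImg2ImgPipeline,
        params.model,
        params.scheduler,
        # ... device and LPW arguments are unchanged by this commit
    )

The same one-line change repeats in the other stages and pipelines below.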
@@ -6,11 +6,11 @@ import torch
 from diffusers import OnnxStableDiffusionInpaintPipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug
 from .utils import process_tile_order
@@ -65,6 +65,7 @@ def blend_inpaint(

     latents = get_latents_from_seed(params.seed, size)
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionInpaintPipeline,
         params.model,
         params.scheduler,
@@ -5,8 +5,8 @@ from PIL import Image

 from onnx_web.output import save_image

-from ..device_pool import JobContext, ProgressCallback
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug

 logger = getLogger(__name__)
@@ -2,8 +2,8 @@ from logging import getLogger

 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -6,8 +6,8 @@ import numpy as np
 from gfpgan import GFPGANer
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)
@@ -2,9 +2,9 @@ from logging import getLogger

 from PIL import Image

-from ..device_pool import JobContext
 from ..output import save_image
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -4,8 +4,8 @@ from logging import getLogger
 from boto3 import Session
 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -2,8 +2,8 @@ from logging import getLogger

 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -2,8 +2,8 @@ from logging import getLogger

 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -3,8 +3,8 @@ from typing import Callable

 from PIL import Image

-from ..device_pool import JobContext
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -5,9 +5,9 @@ import torch
 from diffusers import OnnxStableDiffusionPipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, load_pipeline
 from ..params import ImageParams, Size, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -16,7 +16,7 @@ logger = getLogger(__name__)
 def source_txt2img(
     job: JobContext,
     server: ServerContext,
-    stage: StageParams,
+    _stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
     *,
@@ -35,6 +35,7 @@ def source_txt2img(

     latents = get_latents_from_seed(params.seed, size)
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionPipeline,
         params.model,
         params.scheduler,
@@ -6,11 +6,11 @@ import torch
 from diffusers import OnnxStableDiffusionInpaintPipeline
 from PIL import Image, ImageDraw

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
 from ..image import expand_image, mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, is_debug
 from .utils import process_tile_grid, process_tile_order
@@ -73,6 +73,7 @@ def upscale_outpaint(

     latents = get_tile_latents(full_latents, dims)
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionInpaintPipeline,
         params.model,
         params.scheduler,
@@ -7,9 +7,9 @@ from PIL import Image
 from realesrgan import RealESRGANer
 from realesrgan.archs.srvgg_arch import SRVGGNetCompact

-from ..device_pool import JobContext
 from ..onnx import OnnxNet
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext
 from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)
@@ -5,11 +5,11 @@ import torch
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image

-from ..device_pool import JobContext, ProgressCallback
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
 from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..server.device_pool import JobContext, ProgressCallback
 from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)
@@ -19,20 +19,10 @@ from diffusers import (
 )

 from ..params import DeviceParams, Size
-from ..utils import run_gc
+from ..utils import ServerContext, run_gc

 logger = getLogger(__name__)

-last_pipeline_instance: Any = None
-last_pipeline_options: Tuple[
-    Optional[DiffusionPipeline],
-    Optional[str],
-    Optional[str],
-    Optional[str],
-    Optional[bool],
-] = (None, None, None, None, None)
-last_pipeline_scheduler: Any = None
-
 latent_channels = 4
 latent_factor = 8
@@ -90,24 +80,42 @@ def get_tile_latents(


 def load_pipeline(
+    server: ServerContext,
     pipeline: DiffusionPipeline,
     model: str,
     scheduler_type: Any,
     device: DeviceParams,
     lpw: bool,
 ):
-    global last_pipeline_instance
-    global last_pipeline_scheduler
-    global last_pipeline_options
+    pipe_key = (pipeline, model, device.device, device.provider, lpw)
+    scheduler_key = (scheduler_type,)

-    options = (pipeline, model, device.device, device.provider, lpw)
-    if last_pipeline_instance is not None and last_pipeline_options == options:
+    cache_pipe = server.cache.get("diffusion", pipe_key)
+
+    if cache_pipe is not None:
         logger.debug("reusing existing diffusion pipeline")
-        pipe = last_pipeline_instance
+        pipe = cache_pipe
+
+        cache_scheduler = server.cache.get("scheduler", scheduler_key)
+        if cache_scheduler is None:
+            logger.debug("loading new diffusion scheduler")
+            scheduler = scheduler_type.from_pretrained(
+                model,
+                provider=device.provider,
+                provider_options=device.options,
+                subfolder="scheduler",
+            )
+
+            if device is not None and hasattr(scheduler, "to"):
+                scheduler = scheduler.to(device.torch_device())
+
+            pipe.scheduler = scheduler
+            server.cache.set("scheduler", scheduler_key, scheduler)
+            run_gc()
+
     else:
         logger.debug("unloading previous diffusion pipeline")
-        last_pipeline_instance = None
-        last_pipeline_scheduler = None
+        server.cache.drop("diffusion", pipe_key)
         run_gc()

     if lpw:
@@ -135,24 +143,7 @@ def load_pipeline(
         if device is not None and hasattr(pipe, "to"):
             pipe = pipe.to(device.torch_device())

-        last_pipeline_instance = pipe
-        last_pipeline_options = options
-        last_pipeline_scheduler = scheduler_type
-
-    if last_pipeline_scheduler != scheduler_type:
-        logger.debug("loading new diffusion scheduler")
-        scheduler = scheduler_type.from_pretrained(
-            model,
-            provider=device.provider,
-            provider_options=device.options,
-            subfolder="scheduler",
-        )
-
-        if device is not None and hasattr(scheduler, "to"):
-            scheduler = scheduler.to(device.torch_device())
-
-        pipe.scheduler = scheduler
-        last_pipeline_scheduler = scheduler_type
-        run_gc()
+        server.cache.set("diffusion", pipe_key, pipe)
+        server.cache.set("scheduler", scheduler_key, scheduler)

     return pipe
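Taken together, the two load_pipeline hunks above replace the last_pipeline_* module globals with lookups against server.cache, keyed on everything that would force a reload. A condensed sketch of the resulting control flow, reconstructed from the diff rather than copied from the commit:

    pipe_key = (pipeline, model, device.device, device.provider, lpw)
    scheduler_key = (scheduler_type,)

    pipe = server.cache.get("diffusion", pipe_key)
    if pipe is not None:
        # cache hit: reuse the pipeline, swapping in a fresh scheduler if needed
        if server.cache.get("scheduler", scheduler_key) is None:
            scheduler = scheduler_type.from_pretrained(model, subfolder="scheduler")
            pipe.scheduler = scheduler
            server.cache.set("scheduler", scheduler_key, scheduler)
    else:
        # cache miss: drop any stale entry, load as before, then remember the pipeline
        server.cache.drop("diffusion", pipe_key)
        pipe = ...  # existing ONNX pipeline loading, unchanged by this commit
        server.cache.set("diffusion", pipe_key, pipe)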
@@ -10,10 +10,10 @@ from onnx_web.chain import blend_mask
 from onnx_web.chain.base import ChainProgress

 from ..chain import upscale_outpaint
-from ..device_pool import JobContext
 from ..output import save_image, save_params
 from ..params import Border, ImageParams, Size, StageParams
-from ..upscale import UpscaleParams, run_upscale_correction
+from ..server.device_pool import JobContext
+from ..server.upscale import UpscaleParams, run_upscale_correction
 from ..utils import ServerContext, run_gc
 from .load import get_latents_from_seed, load_pipeline
@@ -30,6 +30,7 @@ def run_txt2img_pipeline(
 ) -> None:
     latents = get_latents_from_seed(params.seed, size)
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionPipeline,
         params.model,
         params.scheduler,
@@ -97,6 +98,7 @@ def run_img2img_pipeline(
     strength: float,
 ) -> None:
     pipe = load_pipeline(
+        server,
         OnnxStableDiffusionImg2ImgPipeline,
         params.model,
         params.scheduler,
@@ -31,7 +31,6 @@ from .chain import (
     upscale_resrgan,
     upscale_stable_diffusion,
 )
-from .device_pool import DevicePoolExecutor
 from .diffusion.load import pipeline_schedulers
 from .diffusion.run import (
     run_blend_pipeline,
@@ -40,7 +39,6 @@ from .diffusion.run import (
     run_txt2img_pipeline,
     run_upscale_pipeline,
 )
-from .hacks import apply_patches
 from .image import ( # mask filters; noise sources
     mask_filter_gaussian_multiply,
     mask_filter_gaussian_screen,
@@ -62,6 +60,8 @@ from .params import (
     TileOrder,
     UpscaleParams,
 )
+from .server.device_pool import DevicePoolExecutor
+from .server.hacks import apply_patches
 from .utils import (
     ServerContext,
     base_join,
@@ -5,8 +5,8 @@ from multiprocessing import Value
 from traceback import format_exception
 from typing import Any, Callable, List, Optional, Tuple, Union

-from .params import DeviceParams
-from .utils import run_gc
+from ..params import DeviceParams
+from ..utils import run_gc

 logger = getLogger(__name__)
@@ -7,7 +7,7 @@ from urllib.parse import urlparse
 import basicsr.utils.download_util
 import codeformer.facelib.utils.misc

-from .utils import ServerContext
+from ..utils import ServerContext

 logger = getLogger(__name__)
@@ -0,0 +1,44 @@
+from logging import getLogger
+from typing import Any, List
+
+logger = getLogger(__name__)
+
+
+class ModelCache:
+    cache: List[(str, Any, Any)]
+    limit: int
+
+    def __init__(self, limit: int) -> None:
+        self.limit = limit
+
+    def drop(self, tag: str, key: Any) -> None:
+        self.cache = [model for model in self.cache if model[0] != tag and model[1] != key]
+
+
+    def get(self, tag: str, key: Any) -> Any:
+        for t, k, v in self.cache:
+            if tag == t and key == k:
+                return v
+
+        return None
+
+    def set(self, tag: str, key: Any, value: Any) -> None:
+        for i in range(len(self.cache)):
+            t, k, v = self.cache[i]
+            if tag == t:
+                if key != k:
+                    logger.debug("Updating model cache: %s", tag)
+                    self.cache[i] = v
+                return
+
+        logger.debug("Adding new model to cache: %s", tag)
+        self.cache.append((tag, key, value))
+        self.prune()
+
+    def prune(self):
+        total = len(self.cache)
+        if total > self.limit:
+            logger.info("Removing models from cache, %s of %s", (total - self.limit), total)
+            self.cache[:] = self.cache[: self.limit]
+        else:
+            logger.debug("Model cache below limit, %s of %s", total, self.limit)
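The new ModelCache is a small list of (tag, key, value) entries with a size limit: get and drop match on tag and key, set appends or updates an entry, and prune trims the list back down to the limit. A usage sketch, assuming the entry list starts empty (the diff above does not show self.cache being initialized in __init__, so that assumption has to hold for the calls below):

    pipe = object()  # stands in for a loaded diffusion pipeline
    cache = ModelCache(limit=2)
    cache.cache = []  # assumed initial state, not shown in the diff

    cache.set("diffusion", ("model-a", "cpu"), pipe)  # remember a loaded pipeline
    cache.get("diffusion", ("model-a", "cpu"))        # -> pipe (cache hit)
    cache.get("diffusion", ("model-b", "cpu"))        # -> None (cache miss)
    cache.drop("diffusion", ("model-a", "cpu"))       # forget it again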
@@ -2,16 +2,16 @@ from logging import getLogger

 from PIL import Image

-from .chain import (
+from ..chain import (
     ChainPipeline,
     correct_codeformer,
     correct_gfpgan,
     upscale_resrgan,
     upscale_stable_diffusion,
 )
+from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
+from ..utils import ServerContext
 from .device_pool import JobContext, ProgressCallback
-from .params import ImageParams, SizeChart, StageParams, UpscaleParams
-from .utils import ServerContext

 logger = getLogger(__name__)
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch

 from .params import SizeChart
+from .server.model_cache import ModelCache

 logger = getLogger(__name__)
@@ -23,6 +24,7 @@ class ServerContext:
         block_platforms: List[str] = [],
         default_platform: str = None,
         image_format: str = "png",
+        cache: ModelCache = None,
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -34,6 +36,7 @@ class ServerContext:
         self.block_platforms = block_platforms
         self.default_platform = default_platform
         self.image_format = image_format
+        self.cache = cache or ModelCache()

     @classmethod
     def from_environ(cls):
|
||||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
|
cache=ModelCache(limit=int(environ.get("ONNX_WEB_CACHE_MODELS", 3))),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
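from_environ wires the cache size to the ONNX_WEB_CACHE_MODELS environment variable, keeping up to three models by default. A small configuration sketch, assuming ServerContext and ModelCache are importable from the modules shown above:

    from os import environ

    environ["ONNX_WEB_CACHE_MODELS"] = "5"  # allow five cached models instead of three
    context = ServerContext.from_environ()  # context.cache becomes ModelCache(limit=5)

    # or attach a cache explicitly when building a ServerContext by hand
    context.cache = ModelCache(limit=5)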