1
0
Fork 0

feat(api): support CPU model offloading specifically for SDXL

This commit is contained in:
Sean Sube 2024-01-12 18:58:26 -06:00
parent 853b92e88b
commit f5506b17f0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 61 additions and 22 deletions

View File

@@ -1,10 +1,10 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
from ..output import save_image, save_result from ..output import save_result
from ..params import ImageParams, Size, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .base import BaseStage from .base import BaseStage

View File

@@ -229,7 +229,7 @@ class ImageMetadata:
if self.models is not None: if self.models is not None:
for name, weight in self.models: for name, weight in self.models:
name, hash = self.get_model_hash() name, hash = self.get_model_hash(server)
json["models"].append({"name": name, "weight": weight, "hash": hash}) json["models"].append({"name": name, "weight": weight, "hash": hash})
return json return json

View File

@ -47,7 +47,7 @@ class UpscaleBSRGANStage(BaseStage):
pipe = OnnxModel( pipe = OnnxModel(
server, server,
model_path, model_path,
provider=device.ort_provider(), provider=device.ort_provider("bsrgan"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )

View File

@ -64,7 +64,7 @@ class UpscaleRealESRGANStage(BaseStage):
model = OnnxRRDBNet( model = OnnxRRDBNet(
server, server,
model_file, model_file,
provider=device.ort_provider(), provider=device.ort_provider("esrgan"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )

View File

@ -40,7 +40,7 @@ class UpscaleSwinIRStage(BaseStage):
pipe = OnnxModel( pipe = OnnxModel(
server, server,
model_path, model_path,
provider=device.ort_provider(), provider=device.ort_provider("swinir"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )

View File

@ -158,7 +158,7 @@ def load_pipeline(
logger.debug("loading new diffusion scheduler") logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained( scheduler = scheduler_type.from_pretrained(
model, model,
provider=device.ort_provider(), provider=device.ort_provider("scheduler"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
subfolder="scheduler", subfolder="scheduler",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@ -179,7 +179,7 @@ def load_pipeline(
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler_type.from_pretrained( scheduler = scheduler_type.from_pretrained(
model, model,
provider=device.ort_provider(), provider=device.ort_provider("scheduler"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
subfolder="scheduler", subfolder="scheduler",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
@ -320,7 +320,7 @@ def load_controlnet(server: ServerContext, device: DeviceParams, params: ImagePa
components["controlnet"] = OnnxRuntimeModel( components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model( OnnxRuntimeModel.load_model(
cnet_path, cnet_path,
provider=device.ort_provider(), provider=device.ort_provider("controlnet"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
) )
@ -448,7 +448,7 @@ def load_text_encoders(
# session for te1 # session for te1
text_encoder_session = InferenceSession( text_encoder_session = InferenceSession(
text_encoder.SerializeToString(), text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder", "sdxl")],
sess_options=text_encoder_opts, sess_options=text_encoder_opts,
) )
text_encoder_session._model_path = path.join(model, "text_encoder") text_encoder_session._model_path = path.join(model, "text_encoder")
@ -457,7 +457,7 @@ def load_text_encoders(
# session for te2 # session for te2
text_encoder_2_session = InferenceSession( text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(), text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder", "sdxl")],
sess_options=text_encoder_2_opts, sess_options=text_encoder_2_opts,
) )
text_encoder_2_session._model_path = path.join(model, "text_encoder_2") text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
@ -511,7 +511,7 @@ def load_unet(
if params.is_xl(): if params.is_xl():
unet_session = InferenceSession( unet_session = InferenceSession(
unet_model.SerializeToString(), unet_model.SerializeToString(),
providers=[device.ort_provider("unet")], providers=[device.ort_provider("unet", "sdxl")],
sess_options=unet_opts, sess_options=unet_opts,
) )
unet_session._model_path = path.join(model, "unet") unet_session._model_path = path.join(model, "unet")
@ -551,7 +551,7 @@ def load_vae(
logger.debug("loading VAE decoder from %s", vae_decoder) logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder_session"] = OnnxRuntimeModel.load_model( components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
vae_decoder, vae_decoder,
provider=device.ort_provider("vae"), provider=device.ort_provider("vae", "sdxl"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
components["vae_decoder_session"]._model_path = vae_decoder components["vae_decoder_session"]._model_path = vae_decoder
@ -559,7 +559,7 @@ def load_vae(
logger.debug("loading VAE encoder from %s", vae_encoder) logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model( components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
vae_encoder, vae_encoder,
provider=device.ort_provider("vae"), provider=device.ort_provider("vae", "sdxl"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
components["vae_encoder_session"]._model_path = vae_encoder components["vae_encoder_session"]._model_path = vae_encoder

View File

@ -84,7 +84,7 @@ def save_image(
server: ServerContext, server: ServerContext,
output: str, output: str,
image: Image.Image, image: Image.Image,
metadata: ImageMetadata, metadata: Optional[ImageMetadata] = None,
) -> str: ) -> str:
path = base_join(server.output_path, output) path = base_join(server.output_path, output)

View File

@ -136,13 +136,22 @@ class DeviceParams:
return "%s - %s (%s)" % (self.device, self.provider, self.options) return "%s - %s (%s)" % (self.device, self.provider, self.options)
def ort_provider( def ort_provider(
self, model_type: Optional[str] = None self,
model_type: str,
suffix: Optional[str] = None,
) -> Union[str, Tuple[str, Any]]: ) -> Union[str, Tuple[str, Any]]:
if model_type is not None: # check if model has been pinned to CPU
# check if model has been pinned to CPU # TODO: check whether the CPU device is allowed
# TODO: check whether the CPU device is allowed if f"onnx-cpu-{model_type}" in self.optimizations:
if f"onnx-cpu-{model_type}" in self.optimizations: logger.debug("pinning %s to CPU", model_type)
return "CPUExecutionProvider" return "CPUExecutionProvider"
if (
suffix is not None
and f"onnx-cpu-{model_type}-{suffix}" in self.optimizations
):
logger.debug("pinning %s-%s to CPU", model_type, suffix)
return "CPUExecutionProvider"
if self.options is None: if self.options is None:
return self.provider return self.provider

View File

@@ -0,0 +1,30 @@
from typing import List, Optional
class NetworkWeight:
    """Placeholder for a network reference carried by a prompt.

    NOTE(review): stub with no fields yet — presumably a named extra
    network and its weight; confirm against the real definition.
    """

    pass


class PromptRegion:
    """Placeholder for a region-specific prompt entry (definition pending)."""

    pass


class PromptSeed:
    """Placeholder for a region-specific seed entry (definition pending)."""

    pass


class StructuredPrompt:
    """Parsed prompt data: text, negative text, networks, and region lists.

    ``region_prompts`` and ``region_seeds`` always start empty and are
    meant to be populated after construction.
    """

    # positive prompt text
    prompt: str
    # optional negative prompt text; None when the caller has none
    negative_prompt: Optional[str]
    # networks referenced by the prompt; never None after construction
    networks: List[NetworkWeight]
    # per-region prompt overrides, filled in later
    region_prompts: List[PromptRegion]
    # per-region seed overrides, filled in later
    region_seeds: List[PromptSeed]

    def __init__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        networks: Optional[List[NetworkWeight]] = None,
    ) -> None:
        """Store the prompt parts.

        Args:
            prompt: the positive prompt text.
            negative_prompt: optional negative prompt text (defaults to None,
                matching its Optional annotation so callers may omit it).
            networks: networks referenced by the prompt; None or any falsy
                value is normalized to an empty list, avoiding a shared
                mutable default.
        """
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.networks = networks or []
        self.region_prompts = []
        self.region_seeds = []