
feat(api): support CPU model offloading specifically for SDXL

Sean Sube 2024-01-12 18:58:26 -06:00
parent 853b92e88b
commit f5506b17f0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 61 additions and 22 deletions
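The heart of the change is a suffix-qualified form of the existing CPU-pinning optimization: onnx-cpu-{model_type} pins a model type to CPU for every pipeline, while the new onnx-cpu-{model_type}-{suffix} form applies only when a loader passes a matching suffix (the SDXL loaders below pass "sdxl"). A minimal, self-contained sketch of that lookup, with a plain set standing in for the device's configured optimization list:

from typing import Optional

# Standalone sketch of the new flag semantics; this set stands in for the
# device's optimization list.
optimizations = {"onnx-cpu-text-encoder-sdxl"}

def pinned_to_cpu(model_type: str, suffix: Optional[str] = None) -> bool:
    # an unqualified pin applies to every pipeline that loads this model type
    if f"onnx-cpu-{model_type}" in optimizations:
        return True
    # a suffix-qualified pin applies only when the loader passes that suffix
    return suffix is not None and f"onnx-cpu-{model_type}-{suffix}" in optimizations

assert pinned_to_cpu("text-encoder", "sdxl")  # SDXL text encoder goes to CPU
assert not pinned_to_cpu("text-encoder")      # non-SDXL pipelines unaffected
assert not pinned_to_cpu("unet", "sdxl")      # no flag set for the UNet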

View File

@@ -1,10 +1,10 @@
 from logging import getLogger
-from typing import List, Optional
+from typing import Optional

 from PIL import Image

-from ..output import save_image, save_result
-from ..params import ImageParams, Size, SizeChart, StageParams
+from ..output import save_result
+from ..params import ImageParams, SizeChart, StageParams
 from ..server import ServerContext
 from ..worker import WorkerContext
 from .base import BaseStage

View File

@@ -229,7 +229,7 @@ class ImageMetadata:

         if self.models is not None:
             for name, weight in self.models:
-                name, hash = self.get_model_hash()
+                name, hash = self.get_model_hash(server)
                 json["models"].append({"name": name, "weight": weight, "hash": hash})

         return json

View File

@@ -47,7 +47,7 @@ class UpscaleBSRGANStage(BaseStage):
         pipe = OnnxModel(
             server,
             model_path,
-            provider=device.ort_provider(),
+            provider=device.ort_provider("bsrgan"),
             sess_options=device.sess_options(),
         )

View File

@@ -64,7 +64,7 @@ class UpscaleRealESRGANStage(BaseStage):
         model = OnnxRRDBNet(
             server,
             model_file,
-            provider=device.ort_provider(),
+            provider=device.ort_provider("esrgan"),
             sess_options=device.sess_options(),
         )

View File

@@ -40,7 +40,7 @@ class UpscaleSwinIRStage(BaseStage):
         pipe = OnnxModel(
             server,
             model_path,
-            provider=device.ort_provider(),
+            provider=device.ort_provider("swinir"),
             sess_options=device.sess_options(),
         )

View File

@@ -158,7 +158,7 @@ def load_pipeline(
         logger.debug("loading new diffusion scheduler")
         scheduler = scheduler_type.from_pretrained(
             model,
-            provider=device.ort_provider(),
+            provider=device.ort_provider("scheduler"),
             sess_options=device.sess_options(),
             subfolder="scheduler",
             torch_dtype=torch_dtype,
@@ -179,7 +179,7 @@ def load_pipeline(
         logger.debug("loading new diffusion pipeline from %s", model)
         scheduler = scheduler_type.from_pretrained(
             model,
-            provider=device.ort_provider(),
+            provider=device.ort_provider("scheduler"),
             sess_options=device.sess_options(),
             subfolder="scheduler",
             torch_dtype=torch_dtype,
@@ -320,7 +320,7 @@ def load_controlnet(server: ServerContext, device: DeviceParams, params: ImagePa
     components["controlnet"] = OnnxRuntimeModel(
         OnnxRuntimeModel.load_model(
             cnet_path,
-            provider=device.ort_provider(),
+            provider=device.ort_provider("controlnet"),
             sess_options=device.sess_options(),
         )
     )
@@ -448,7 +448,7 @@ def load_text_encoders(
     # session for te1
     text_encoder_session = InferenceSession(
         text_encoder.SerializeToString(),
-        providers=[device.ort_provider("text-encoder")],
+        providers=[device.ort_provider("text-encoder", "sdxl")],
        sess_options=text_encoder_opts,
     )
     text_encoder_session._model_path = path.join(model, "text_encoder")
@@ -457,7 +457,7 @@ def load_text_encoders(
     # session for te2
     text_encoder_2_session = InferenceSession(
         text_encoder_2.SerializeToString(),
-        providers=[device.ort_provider("text-encoder")],
+        providers=[device.ort_provider("text-encoder", "sdxl")],
         sess_options=text_encoder_2_opts,
     )
     text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
@@ -511,7 +511,7 @@ def load_unet(
     if params.is_xl():
         unet_session = InferenceSession(
             unet_model.SerializeToString(),
-            providers=[device.ort_provider("unet")],
+            providers=[device.ort_provider("unet", "sdxl")],
             sess_options=unet_opts,
         )
         unet_session._model_path = path.join(model, "unet")
@@ -551,7 +551,7 @@ def load_vae(
         logger.debug("loading VAE decoder from %s", vae_decoder)
         components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
             vae_decoder,
-            provider=device.ort_provider("vae"),
+            provider=device.ort_provider("vae", "sdxl"),
             sess_options=device.sess_options(),
         )
         components["vae_decoder_session"]._model_path = vae_decoder
@@ -559,7 +559,7 @@ def load_vae(
         logger.debug("loading VAE encoder from %s", vae_encoder)
         components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
             vae_encoder,
-            provider=device.ort_provider("vae"),
+            provider=device.ort_provider("vae", "sdxl"),
             sess_options=device.sess_options(),
         )
         components["vae_encoder_session"]._model_path = vae_encoder

View File

@@ -84,7 +84,7 @@ def save_image(
     server: ServerContext,
     output: str,
     image: Image.Image,
-    metadata: ImageMetadata,
+    metadata: Optional[ImageMetadata] = None,
 ) -> str:
     path = base_join(server.output_path, output)
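With the new default, callers that have no metadata yet can simply omit the argument; both calls below are valid against the signature above (output names are illustrative, not from the commit):

save_image(server, "preview.png", image)                         # metadata defaults to None
save_image(server, "final.png", image, metadata=image_metadata)  # existing callers unchanged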

View File

@@ -136,12 +136,21 @@ class DeviceParams:
         return "%s - %s (%s)" % (self.device, self.provider, self.options)

     def ort_provider(
-        self, model_type: Optional[str] = None
+        self,
+        model_type: str,
+        suffix: Optional[str] = None,
     ) -> Union[str, Tuple[str, Any]]:
-        if model_type is not None:
-            # check if model has been pinned to CPU
-            # TODO: check whether the CPU device is allowed
-            if f"onnx-cpu-{model_type}" in self.optimizations:
-                logger.debug("pinning %s to CPU", model_type)
-                return "CPUExecutionProvider"
+        # check if model has been pinned to CPU
+        # TODO: check whether the CPU device is allowed
+        if f"onnx-cpu-{model_type}" in self.optimizations:
+            logger.debug("pinning %s to CPU", model_type)
+            return "CPUExecutionProvider"
+
+        if (
+            suffix is not None
+            and f"onnx-cpu-{model_type}-{suffix}" in self.optimizations
+        ):
+            logger.debug("pinning %s-%s to CPU", model_type, suffix)
+            return "CPUExecutionProvider"
+
         if self.options is None:
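The fallthrough after the pinning checks is cut off above, but the Union[str, Tuple[str, Any]] return type implies the existing contract: the bare provider name when the device has no session options, otherwise a (name, options) tuple. Both forms are valid entries in an onnxruntime providers list. A stand-in sketch of that assumed fallthrough:

from typing import Any, Optional, Tuple, Union

def resolve_provider(
    provider: str, options: Optional[dict] = None
) -> Union[str, Tuple[str, Any]]:
    # mirrors the tail of ort_provider once no onnx-cpu-* flag has matched
    if options is None:
        return provider
    return (provider, options)

print(resolve_provider("CUDAExecutionProvider"))                    # 'CUDAExecutionProvider'
print(resolve_provider("ROCMExecutionProvider", {"device_id": 1}))  # tuple form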

View File

@@ -0,0 +1,30 @@
+from typing import List, Optional
+
+
+class NetworkWeight:
+    pass
+
+
+class PromptRegion:
+    pass
+
+
+class PromptSeed:
+    pass
+
+
+class StructuredPrompt:
+    prompt: str
+    negative_prompt: Optional[str]
+    networks: List[NetworkWeight]
+    region_prompts: List[PromptRegion]
+    region_seeds: List[PromptSeed]
+
+    def __init__(
+        self, prompt: str, negative_prompt: Optional[str], networks: List[NetworkWeight]
+    ) -> None:
+        self.prompt = prompt
+        self.negative_prompt = negative_prompt
+        self.networks = networks or []
+        self.region_prompts = []
+        self.region_seeds = []
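A hypothetical construction of the new container, given the module above. The region prompt and seed lists start empty; the parser that fills them is not part of this commit:

prompt = StructuredPrompt(
    "an astronaut riding a horse",
    negative_prompt="blurry, low quality",
    networks=[],
)
print(prompt.prompt)          # an astronaut riding a horse
print(prompt.region_prompts)  # []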