feat(api): support CPU model offloading specifically for SDXL
This commit is contained in:
parent
853b92e88b
commit
f5506b17f0
|
@ -1,10 +1,10 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..output import save_image, save_result
|
from ..output import save_result
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
|
|
|
@ -229,7 +229,7 @@ class ImageMetadata:
|
||||||
|
|
||||||
if self.models is not None:
|
if self.models is not None:
|
||||||
for name, weight in self.models:
|
for name, weight in self.models:
|
||||||
name, hash = self.get_model_hash()
|
name, hash = self.get_model_hash(server)
|
||||||
json["models"].append({"name": name, "weight": weight, "hash": hash})
|
json["models"].append({"name": name, "weight": weight, "hash": hash})
|
||||||
|
|
||||||
return json
|
return json
|
||||||
|
|
|
@ -47,7 +47,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
pipe = OnnxModel(
|
pipe = OnnxModel(
|
||||||
server,
|
server,
|
||||||
model_path,
|
model_path,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider("bsrgan"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
model = OnnxRRDBNet(
|
model = OnnxRRDBNet(
|
||||||
server,
|
server,
|
||||||
model_file,
|
model_file,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider("esrgan"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
pipe = OnnxModel(
|
pipe = OnnxModel(
|
||||||
server,
|
server,
|
||||||
model_path,
|
model_path,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider("swinir"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,7 @@ def load_pipeline(
|
||||||
logger.debug("loading new diffusion scheduler")
|
logger.debug("loading new diffusion scheduler")
|
||||||
scheduler = scheduler_type.from_pretrained(
|
scheduler = scheduler_type.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider("scheduler"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
@ -179,7 +179,7 @@ def load_pipeline(
|
||||||
logger.debug("loading new diffusion pipeline from %s", model)
|
logger.debug("loading new diffusion pipeline from %s", model)
|
||||||
scheduler = scheduler_type.from_pretrained(
|
scheduler = scheduler_type.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider("scheduler"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
|
@ -320,7 +320,7 @@ def load_controlnet(server: ServerContext, device: DeviceParams, params: ImagePa
|
||||||
components["controlnet"] = OnnxRuntimeModel(
|
components["controlnet"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
cnet_path,
|
cnet_path,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider("controlnet"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -448,7 +448,7 @@ def load_text_encoders(
|
||||||
# session for te1
|
# session for te1
|
||||||
text_encoder_session = InferenceSession(
|
text_encoder_session = InferenceSession(
|
||||||
text_encoder.SerializeToString(),
|
text_encoder.SerializeToString(),
|
||||||
providers=[device.ort_provider("text-encoder")],
|
providers=[device.ort_provider("text-encoder", "sdxl")],
|
||||||
sess_options=text_encoder_opts,
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||||
|
@ -457,7 +457,7 @@ def load_text_encoders(
|
||||||
# session for te2
|
# session for te2
|
||||||
text_encoder_2_session = InferenceSession(
|
text_encoder_2_session = InferenceSession(
|
||||||
text_encoder_2.SerializeToString(),
|
text_encoder_2.SerializeToString(),
|
||||||
providers=[device.ort_provider("text-encoder")],
|
providers=[device.ort_provider("text-encoder", "sdxl")],
|
||||||
sess_options=text_encoder_2_opts,
|
sess_options=text_encoder_2_opts,
|
||||||
)
|
)
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
|
@ -511,7 +511,7 @@ def load_unet(
|
||||||
if params.is_xl():
|
if params.is_xl():
|
||||||
unet_session = InferenceSession(
|
unet_session = InferenceSession(
|
||||||
unet_model.SerializeToString(),
|
unet_model.SerializeToString(),
|
||||||
providers=[device.ort_provider("unet")],
|
providers=[device.ort_provider("unet", "sdxl")],
|
||||||
sess_options=unet_opts,
|
sess_options=unet_opts,
|
||||||
)
|
)
|
||||||
unet_session._model_path = path.join(model, "unet")
|
unet_session._model_path = path.join(model, "unet")
|
||||||
|
@ -551,7 +551,7 @@ def load_vae(
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
vae_decoder,
|
vae_decoder,
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae", "sdxl"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components["vae_decoder_session"]._model_path = vae_decoder
|
components["vae_decoder_session"]._model_path = vae_decoder
|
||||||
|
@ -559,7 +559,7 @@ def load_vae(
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
vae_encoder,
|
vae_encoder,
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae", "sdxl"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components["vae_encoder_session"]._model_path = vae_encoder
|
components["vae_encoder_session"]._model_path = vae_encoder
|
||||||
|
|
|
@ -84,7 +84,7 @@ def save_image(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
output: str,
|
output: str,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
metadata: ImageMetadata,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(server.output_path, output)
|
path = base_join(server.output_path, output)
|
||||||
|
|
||||||
|
|
|
@ -136,13 +136,22 @@ class DeviceParams:
|
||||||
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
||||||
|
|
||||||
def ort_provider(
|
def ort_provider(
|
||||||
self, model_type: Optional[str] = None
|
self,
|
||||||
|
model_type: str,
|
||||||
|
suffix: Optional[str] = None,
|
||||||
) -> Union[str, Tuple[str, Any]]:
|
) -> Union[str, Tuple[str, Any]]:
|
||||||
if model_type is not None:
|
# check if model has been pinned to CPU
|
||||||
# check if model has been pinned to CPU
|
# TODO: check whether the CPU device is allowed
|
||||||
# TODO: check whether the CPU device is allowed
|
if f"onnx-cpu-{model_type}" in self.optimizations:
|
||||||
if f"onnx-cpu-{model_type}" in self.optimizations:
|
logger.debug("pinning %s to CPU", model_type)
|
||||||
return "CPUExecutionProvider"
|
return "CPUExecutionProvider"
|
||||||
|
|
||||||
|
if (
|
||||||
|
suffix is not None
|
||||||
|
and f"onnx-cpu-{model_type}-{suffix}" in self.optimizations
|
||||||
|
):
|
||||||
|
logger.debug("pinning %s-%s to CPU", model_type, suffix)
|
||||||
|
return "CPUExecutionProvider"
|
||||||
|
|
||||||
if self.options is None:
|
if self.options is None:
|
||||||
return self.provider
|
return self.provider
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkWeight:
    """Placeholder element type for ``StructuredPrompt.networks``.

    No fields are defined yet.
    """


class PromptRegion:
    """Placeholder element type for ``StructuredPrompt.region_prompts``.

    No fields are defined yet.
    """


class PromptSeed:
    """Placeholder element type for ``StructuredPrompt.region_seeds``.

    No fields are defined yet.
    """


class StructuredPrompt:
    """Bundle a prompt with its negative prompt, networks, and region data.

    All attributes are plain mutable fields set by ``__init__``.
    """

    prompt: str
    negative_prompt: Optional[str]
    networks: List[NetworkWeight]
    region_prompts: List[PromptRegion]
    region_seeds: List[PromptSeed]

    def __init__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        networks: Optional[List[NetworkWeight]] = None,
        region_prompts: Optional[List[PromptRegion]] = None,
        region_seeds: Optional[List[PromptSeed]] = None,
    ) -> None:
        """Store the prompt fields, substituting fresh lists for missing ones.

        Args:
            prompt: The prompt text.
            negative_prompt: The negative prompt text, or ``None``.
            networks: Network weights; ``None`` or an empty list yields a
                new empty list, any other list is stored as-is (not copied).
            region_prompts: Region prompts; same ``None``/empty handling as
                ``networks``. Defaults to an empty list.
            region_seeds: Region seeds; same ``None``/empty handling as
                ``networks``. Defaults to an empty list.
        """
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        # `or []` gives every instance its own list when nothing is supplied,
        # so instances never share a default container.
        self.networks = networks or []
        self.region_prompts = region_prompts or []
        self.region_seeds = region_seeds or []

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(prompt={self.prompt!r}, "
            f"negative_prompt={self.negative_prompt!r}, "
            f"networks={self.networks!r}, "
            f"region_prompts={self.region_prompts!r}, "
            f"region_seeds={self.region_seeds!r})"
        )
|
Loading…
Reference in New Issue