feat(api): support CPU model offloading specifically for SDXL
This commit is contained in:
parent
853b92e88b
commit
f5506b17f0
|
@ -1,10 +1,10 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..output import save_image, save_result
|
||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||
from ..output import save_result
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import WorkerContext
|
||||
from .base import BaseStage
|
||||
|
|
|
@ -229,7 +229,7 @@ class ImageMetadata:
|
|||
|
||||
if self.models is not None:
|
||||
for name, weight in self.models:
|
||||
name, hash = self.get_model_hash()
|
||||
name, hash = self.get_model_hash(server)
|
||||
json["models"].append({"name": name, "weight": weight, "hash": hash})
|
||||
|
||||
return json
|
||||
|
|
|
@ -47,7 +47,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
pipe = OnnxModel(
|
||||
server,
|
||||
model_path,
|
||||
provider=device.ort_provider(),
|
||||
provider=device.ort_provider("bsrgan"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
model = OnnxRRDBNet(
|
||||
server,
|
||||
model_file,
|
||||
provider=device.ort_provider(),
|
||||
provider=device.ort_provider("esrgan"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
pipe = OnnxModel(
|
||||
server,
|
||||
model_path,
|
||||
provider=device.ort_provider(),
|
||||
provider=device.ort_provider("swinir"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
|
||||
|
|
|
@ -158,7 +158,7 @@ def load_pipeline(
|
|||
logger.debug("loading new diffusion scheduler")
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.ort_provider(),
|
||||
provider=device.ort_provider("scheduler"),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch_dtype,
|
||||
|
@ -179,7 +179,7 @@ def load_pipeline(
|
|||
logger.debug("loading new diffusion pipeline from %s", model)
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.ort_provider(),
|
||||
provider=device.ort_provider("scheduler"),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch_dtype,
|
||||
|
@ -320,7 +320,7 @@ def load_controlnet(server: ServerContext, device: DeviceParams, params: ImagePa
|
|||
components["controlnet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
cnet_path,
|
||||
provider=device.ort_provider(),
|
||||
provider=device.ort_provider("controlnet"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
@ -448,7 +448,7 @@ def load_text_encoders(
|
|||
# session for te1
|
||||
text_encoder_session = InferenceSession(
|
||||
text_encoder.SerializeToString(),
|
||||
providers=[device.ort_provider("text-encoder")],
|
||||
providers=[device.ort_provider("text-encoder", "sdxl")],
|
||||
sess_options=text_encoder_opts,
|
||||
)
|
||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||
|
@ -457,7 +457,7 @@ def load_text_encoders(
|
|||
# session for te2
|
||||
text_encoder_2_session = InferenceSession(
|
||||
text_encoder_2.SerializeToString(),
|
||||
providers=[device.ort_provider("text-encoder")],
|
||||
providers=[device.ort_provider("text-encoder", "sdxl")],
|
||||
sess_options=text_encoder_2_opts,
|
||||
)
|
||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||
|
@ -511,7 +511,7 @@ def load_unet(
|
|||
if params.is_xl():
|
||||
unet_session = InferenceSession(
|
||||
unet_model.SerializeToString(),
|
||||
providers=[device.ort_provider("unet")],
|
||||
providers=[device.ort_provider("unet", "sdxl")],
|
||||
sess_options=unet_opts,
|
||||
)
|
||||
unet_session._model_path = path.join(model, "unet")
|
||||
|
@ -551,7 +551,7 @@ def load_vae(
|
|||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
||||
vae_decoder,
|
||||
provider=device.ort_provider("vae"),
|
||||
provider=device.ort_provider("vae", "sdxl"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["vae_decoder_session"]._model_path = vae_decoder
|
||||
|
@ -559,7 +559,7 @@ def load_vae(
|
|||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||
vae_encoder,
|
||||
provider=device.ort_provider("vae"),
|
||||
provider=device.ort_provider("vae", "sdxl"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["vae_encoder_session"]._model_path = vae_encoder
|
||||
|
|
|
@ -84,7 +84,7 @@ def save_image(
|
|||
server: ServerContext,
|
||||
output: str,
|
||||
image: Image.Image,
|
||||
metadata: ImageMetadata,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
) -> str:
|
||||
path = base_join(server.output_path, output)
|
||||
|
||||
|
|
|
@ -136,13 +136,22 @@ class DeviceParams:
|
|||
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
||||
|
||||
def ort_provider(
|
||||
self, model_type: Optional[str] = None
|
||||
self,
|
||||
model_type: str,
|
||||
suffix: Optional[str] = None,
|
||||
) -> Union[str, Tuple[str, Any]]:
|
||||
if model_type is not None:
|
||||
# check if model has been pinned to CPU
|
||||
# TODO: check whether the CPU device is allowed
|
||||
if f"onnx-cpu-{model_type}" in self.optimizations:
|
||||
return "CPUExecutionProvider"
|
||||
# check if model has been pinned to CPU
|
||||
# TODO: check whether the CPU device is allowed
|
||||
if f"onnx-cpu-{model_type}" in self.optimizations:
|
||||
logger.debug("pinning %s to CPU", model_type)
|
||||
return "CPUExecutionProvider"
|
||||
|
||||
if (
|
||||
suffix is not None
|
||||
and f"onnx-cpu-{model_type}-{suffix}" in self.optimizations
|
||||
):
|
||||
logger.debug("pinning %s-%s to CPU", model_type, suffix)
|
||||
return "CPUExecutionProvider"
|
||||
|
||||
if self.options is None:
|
||||
return self.provider
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
from typing import List, Optional
|
||||
|
||||
|
||||
class NetworkWeight:
    """Stub type for an extra-network weight entry attached to a prompt.

    Fields are not yet defined in this view — TODO confirm intended schema.
    """


class PromptRegion:
    """Stub type for a region-specific prompt entry. Fields TBD."""


class PromptSeed:
    """Stub type for a region-specific seed entry. Fields TBD."""


class StructuredPrompt:
    """A parsed prompt: positive text, optional negative text, any attached
    network weights, and (initially empty) per-region prompt and seed lists.
    """

    # positive prompt text
    prompt: str
    # optional negative prompt text; None when the caller supplied none
    negative_prompt: Optional[str]
    # network weights referenced by the prompt; never None after __init__
    networks: List[NetworkWeight]
    # region-specific prompt overrides; always starts empty
    region_prompts: List[PromptRegion]
    # region-specific seed overrides; always starts empty
    region_seeds: List[PromptSeed]

    def __init__(
        self,
        prompt: str,
        negative_prompt: Optional[str],
        networks: List[NetworkWeight],
    ) -> None:
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        # keep the caller's list when it is non-empty; otherwise store a
        # fresh empty list so the attribute is never None
        self.networks = networks if networks else []
        self.region_prompts = []
        self.region_seeds = []
|
Loading…
Reference in New Issue