feat(api): backend support for multiple GPUs in diffusion pipelines
This commit is contained in:
parent
88c5113e37
commit
a868c8cf6b
|
@ -6,7 +6,7 @@ from diffusers import (
|
|||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from PIL import Image, ImageChops
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
|
@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
|||
return image_latents
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
|||
safety_checker=None,
|
||||
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
|
||||
)
|
||||
|
||||
if device is not None:
|
||||
pipe = pipe.to(device)
|
||||
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
print('changing pipeline scheduler')
|
||||
pipe.scheduler = scheduler.from_pretrained(
|
||||
scheduler = scheduler.from_pretrained(
|
||||
model, subfolder='scheduler')
|
||||
|
||||
if device is not None:
|
||||
scheduler = scheduler.to(device)
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
last_pipeline_scheduler = scheduler
|
||||
|
||||
print('running garbage collection during pipeline change')
|
||||
|
|
Loading…
Reference in New Issue