feat(api): backend support for multiple GPUs in diffusion pipelines
This commit is contained in:
parent
88c5113e37
commit
a868c8cf6b
|
@ -6,7 +6,7 @@ from diffusers import (
|
||||||
OnnxStableDiffusionInpaintPipeline,
|
OnnxStableDiffusionInpaintPipeline,
|
||||||
)
|
)
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
from typing import Any
|
from typing import Any, Union
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
|
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_scheduler
|
global last_pipeline_scheduler
|
||||||
global last_pipeline_options
|
global last_pipeline_options
|
||||||
|
@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
|
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
pipe = pipe.to(device)
|
||||||
|
|
||||||
last_pipeline_instance = pipe
|
last_pipeline_instance = pipe
|
||||||
last_pipeline_options = options
|
last_pipeline_options = options
|
||||||
last_pipeline_scheduler = scheduler
|
last_pipeline_scheduler = scheduler
|
||||||
|
|
||||||
if last_pipeline_scheduler != scheduler:
|
if last_pipeline_scheduler != scheduler:
|
||||||
print('changing pipeline scheduler')
|
print('changing pipeline scheduler')
|
||||||
pipe.scheduler = scheduler.from_pretrained(
|
scheduler = scheduler.from_pretrained(
|
||||||
model, subfolder='scheduler')
|
model, subfolder='scheduler')
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
scheduler = scheduler.to(device)
|
||||||
|
|
||||||
|
pipe.scheduler = scheduler
|
||||||
last_pipeline_scheduler = scheduler
|
last_pipeline_scheduler = scheduler
|
||||||
|
|
||||||
print('running garbage collection during pipeline change')
|
print('running garbage collection during pipeline change')
|
||||||
|
|
Loading…
Reference in New Issue