diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py index 58cab46c..b8a23bee 100644 --- a/api/onnx_web/pipeline.py +++ b/api/onnx_web/pipeline.py @@ -6,7 +6,7 @@ from diffusers import ( OnnxStableDiffusionInpaintPipeline, ) from PIL import Image, ImageChops -from typing import Any +from typing import Any, Union import gc import numpy as np @@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray: return image_latents -def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any): +def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]): global last_pipeline_instance global last_pipeline_scheduler global last_pipeline_options @@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu safety_checker=None, scheduler=scheduler.from_pretrained(model, subfolder='scheduler') ) + + if device is not None: + pipe = pipe.to(device) + last_pipeline_instance = pipe last_pipeline_options = options last_pipeline_scheduler = scheduler if last_pipeline_scheduler != scheduler: print('changing pipeline scheduler') - pipe.scheduler = scheduler.from_pretrained( + scheduler = scheduler.from_pretrained( model, subfolder='scheduler') + + if device is not None: + scheduler = scheduler.to(device) + + pipe.scheduler = scheduler last_pipeline_scheduler = scheduler print('running garbage collection during pipeline change')