diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 761db0a0..9256f1aa 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -55,7 +55,7 @@ def get_tile_latents( def load_pipeline( pipeline: DiffusionPipeline, model: str, - scheduler: Any, + scheduler_type: Any, device: DeviceParams, lpw: bool, ): @@ -79,7 +79,7 @@ def load_pipeline( custom_pipeline = None logger.debug("loading new diffusion pipeline from %s", model) - scheduler = scheduler.from_pretrained( + scheduler = scheduler_type.from_pretrained( model, provider=device.provider, provider_options=device.options, @@ -100,11 +100,11 @@ def load_pipeline( last_pipeline_instance = pipe last_pipeline_options = options - last_pipeline_scheduler = scheduler + last_pipeline_scheduler = scheduler_type if last_pipeline_scheduler != scheduler: logger.debug("loading new diffusion scheduler") - scheduler = scheduler.from_pretrained( + scheduler = scheduler_type.from_pretrained( model, provider=device.provider, provider_options=device.options, @@ -112,10 +112,10 @@ def load_pipeline( ) if device is not None and hasattr(scheduler, "to"): - scheduler = scheduler.to(device) + scheduler = scheduler.to(device.torch_device()) pipe.scheduler = scheduler - last_pipeline_scheduler = scheduler + last_pipeline_scheduler = scheduler_type run_gc() return pipe