fix(api): correctly cache diffusers scheduler
This commit is contained in:
parent
1179092028
commit
9c5043e9d0
|
@ -55,7 +55,7 @@ def get_tile_latents(
|
|||
def load_pipeline(
|
||||
pipeline: DiffusionPipeline,
|
||||
model: str,
|
||||
scheduler: Any,
|
||||
scheduler_type: Any,
|
||||
device: DeviceParams,
|
||||
lpw: bool,
|
||||
):
|
||||
|
@ -79,7 +79,7 @@ def load_pipeline(
|
|||
custom_pipeline = None
|
||||
|
||||
logger.debug("loading new diffusion pipeline from %s", model)
|
||||
scheduler = scheduler.from_pretrained(
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
|
@ -100,11 +100,11 @@ def load_pipeline(
|
|||
|
||||
last_pipeline_instance = pipe
|
||||
last_pipeline_options = options
|
||||
last_pipeline_scheduler = scheduler
|
||||
last_pipeline_scheduler = scheduler_type
|
||||
|
||||
if last_pipeline_scheduler != scheduler:
|
||||
logger.debug("loading new diffusion scheduler")
|
||||
scheduler = scheduler.from_pretrained(
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
|
@ -112,10 +112,10 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
scheduler = scheduler.to(device)
|
||||
scheduler = scheduler.to(device.torch_device())
|
||||
|
||||
pipe.scheduler = scheduler
|
||||
last_pipeline_scheduler = scheduler
|
||||
last_pipeline_scheduler = scheduler_type
|
||||
run_gc()
|
||||
|
||||
return pipe
|
||||
|
|
Loading…
Reference in New Issue