1
0
Fork 0

fix(api): correctly cache diffusers scheduler

This commit is contained in:
Sean Sube 2023-02-12 09:33:13 -06:00
parent 1179092028
commit 9c5043e9d0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 6 additions and 6 deletions

View File

@@ -55,7 +55,7 @@ def get_tile_latents(
def load_pipeline(
pipeline: DiffusionPipeline,
model: str,
scheduler: Any,
scheduler_type: Any,
device: DeviceParams,
lpw: bool,
):
@@ -79,7 +79,7 @@ def load_pipeline(
custom_pipeline = None
logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained(
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
@@ -100,11 +100,11 @@ def load_pipeline(
last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler
last_pipeline_scheduler = scheduler_type
if last_pipeline_scheduler != scheduler:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler.from_pretrained(
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
@@ -112,10 +112,10 @@ def load_pipeline(
)
if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device)
scheduler = scheduler.to(device.torch_device())
pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler
last_pipeline_scheduler = scheduler_type
run_gc()
return pipe