1
0
Fork 0

fix(api): correctly cache diffusers scheduler

This commit is contained in:
Sean Sube 2023-02-12 09:33:13 -06:00
parent 1179092028
commit 9c5043e9d0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 6 additions and 6 deletions

View File

@ -55,7 +55,7 @@ def get_tile_latents(
def load_pipeline( def load_pipeline(
pipeline: DiffusionPipeline, pipeline: DiffusionPipeline,
model: str, model: str,
scheduler: Any, scheduler_type: Any,
device: DeviceParams, device: DeviceParams,
lpw: bool, lpw: bool,
): ):
@ -79,7 +79,7 @@ def load_pipeline(
custom_pipeline = None custom_pipeline = None
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler.from_pretrained( scheduler = scheduler_type.from_pretrained(
model, model,
provider=device.provider, provider=device.provider,
provider_options=device.options, provider_options=device.options,
@ -100,11 +100,11 @@ def load_pipeline(
last_pipeline_instance = pipe last_pipeline_instance = pipe
last_pipeline_options = options last_pipeline_options = options
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler_type
if last_pipeline_scheduler != scheduler: if last_pipeline_scheduler != scheduler:
logger.debug("loading new diffusion scheduler") logger.debug("loading new diffusion scheduler")
scheduler = scheduler.from_pretrained( scheduler = scheduler_type.from_pretrained(
model, model,
provider=device.provider, provider=device.provider,
provider_options=device.options, provider_options=device.options,
@ -112,10 +112,10 @@ def load_pipeline(
) )
if device is not None and hasattr(scheduler, "to"): if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device) scheduler = scheduler.to(device.torch_device())
pipe.scheduler = scheduler pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler last_pipeline_scheduler = scheduler_type
run_gc() run_gc()
return pipe return pipe