fix(api): correctly cache diffusers scheduler
This commit is contained in:
parent
1179092028
commit
9c5043e9d0
|
@ -55,7 +55,7 @@ def get_tile_latents(
|
||||||
def load_pipeline(
|
def load_pipeline(
|
||||||
pipeline: DiffusionPipeline,
|
pipeline: DiffusionPipeline,
|
||||||
model: str,
|
model: str,
|
||||||
scheduler: Any,
|
scheduler_type: Any,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
lpw: bool,
|
lpw: bool,
|
||||||
):
|
):
|
||||||
|
@ -79,7 +79,7 @@ def load_pipeline(
|
||||||
custom_pipeline = None
|
custom_pipeline = None
|
||||||
|
|
||||||
logger.debug("loading new diffusion pipeline from %s", model)
|
logger.debug("loading new diffusion pipeline from %s", model)
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler_type.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
provider_options=device.options,
|
provider_options=device.options,
|
||||||
|
@ -100,11 +100,11 @@ def load_pipeline(
|
||||||
|
|
||||||
last_pipeline_instance = pipe
|
last_pipeline_instance = pipe
|
||||||
last_pipeline_options = options
|
last_pipeline_options = options
|
||||||
last_pipeline_scheduler = scheduler
|
last_pipeline_scheduler = scheduler_type
|
||||||
|
|
||||||
if last_pipeline_scheduler != scheduler:
|
if last_pipeline_scheduler != scheduler:
|
||||||
logger.debug("loading new diffusion scheduler")
|
logger.debug("loading new diffusion scheduler")
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler_type.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
provider_options=device.options,
|
provider_options=device.options,
|
||||||
|
@ -112,10 +112,10 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(scheduler, "to"):
|
if device is not None and hasattr(scheduler, "to"):
|
||||||
scheduler = scheduler.to(device)
|
scheduler = scheduler.to(device.torch_device())
|
||||||
|
|
||||||
pipe.scheduler = scheduler
|
pipe.scheduler = scheduler
|
||||||
last_pipeline_scheduler = scheduler
|
last_pipeline_scheduler = scheduler_type
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
Loading…
Reference in New Issue