fix(api): convert latents to numpy before using
This commit is contained in:
parent
98f8abbacd
commit
8d4410305e
|
@ -5,13 +5,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||||
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
|
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
|
||||||
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
||||||
StableDiffusionXLImg2ImgPipelineMixin,
|
StableDiffusionXLImg2ImgPipelineMixin,
|
||||||
)
|
)
|
||||||
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
|
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
|
||||||
|
|
||||||
from onnx_web.chain.tile import make_tile_mask
|
from onnx_web.chain.tile import make_tile_mask
|
||||||
|
|
||||||
|
@ -730,7 +730,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
# 3. Preprocess image
|
# 3. Preprocess image
|
||||||
processor = VaeImageProcessor()
|
processor = VaeImageProcessor()
|
||||||
image = processor.preprocess(image)
|
image = processor.preprocess(image).cpu().numpy()
|
||||||
|
|
||||||
# 4. Prepare timesteps
|
# 4. Prepare timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
|
@ -385,7 +385,12 @@ def run_inpaint_pipeline(
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(
|
images = chain.run(
|
||||||
worker, server, params, StageResult(images=[source]), callback=progress, latents=latents
|
worker,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
StageResult(images=[source]),
|
||||||
|
callback=progress,
|
||||||
|
latents=latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
Loading…
Reference in New Issue