1
0
Fork 0

fix(api): convert latents to numpy before using

This commit is contained in:
Sean Sube 2023-11-24 10:36:53 -06:00
parent 98f8abbacd
commit 8d4410305e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 8 additions and 3 deletions

View File

@ -5,13 +5,13 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import ( from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
StableDiffusionXLImg2ImgPipelineMixin, StableDiffusionXLImg2ImgPipelineMixin,
) )
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from diffusers.image_processor import VaeImageProcessor
from onnx_web.chain.tile import make_tile_mask from onnx_web.chain.tile import make_tile_mask
@ -730,7 +730,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# 3. Preprocess image # 3. Preprocess image
processor = VaeImageProcessor() processor = VaeImageProcessor()
image = processor.preprocess(image) image = processor.preprocess(image).cpu().numpy()
# 4. Prepare timesteps # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)

View File

@ -385,7 +385,12 @@ def run_inpaint_pipeline(
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain.run( images = chain.run(
worker, server, params, StageResult(images=[source]), callback=progress, latents=latents worker,
server,
params,
StageResult(images=[source]),
callback=progress,
latents=latents,
) )
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)