fix tile latent axis
This commit is contained in:
parent
50d6dbb451
commit
2033cbb601
|
@ -33,14 +33,13 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
|||
|
||||
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
||||
x, y, tile = dims
|
||||
t = tile // 8
|
||||
x = x // 8
|
||||
y = y // 8
|
||||
|
||||
t = tile // 8
|
||||
xt = x + t
|
||||
yt = y + t
|
||||
|
||||
return full_latents[:,:,x:xt,y:yt]
|
||||
return full_latents[:,:,y:yt,x:xt]
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
from .onnx_net import (
|
||||
OnnxImage,
|
||||
OnnxNet,
|
||||
)
|
||||
from .pipeline_onnx_stable_diffusion_upscale import (
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
Loading…
Reference in New Issue