fix tile latent axis
This commit is contained in:
parent
50d6dbb451
commit
2033cbb601
|
@ -33,14 +33,13 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
||||||
|
|
||||||
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray:
|
||||||
x, y, tile = dims
|
x, y, tile = dims
|
||||||
|
t = tile // 8
|
||||||
x = x // 8
|
x = x // 8
|
||||||
y = y // 8
|
y = y // 8
|
||||||
|
|
||||||
t = tile // 8
|
|
||||||
xt = x + t
|
xt = x + t
|
||||||
yt = y + t
|
yt = y + t
|
||||||
|
|
||||||
return full_latents[:,:,x:xt,y:yt]
|
return full_latents[:,:,y:yt,x:xt]
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
|
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
|
||||||
|
|
|
@ -1,7 +1,4 @@
|
||||||
from .onnx_net import (
|
from .onnx_net import (
|
||||||
OnnxImage,
|
OnnxImage,
|
||||||
OnnxNet,
|
OnnxNet,
|
||||||
)
|
|
||||||
from .pipeline_onnx_stable_diffusion_upscale import (
|
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
|
||||||
)
|
)
|
Loading…
Reference in New Issue