diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index 643f5bf1..a7c10957 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -6,15 +6,18 @@ from torch import FloatTensor, Tensor class SchedulerPatch: - scheduler: Any + wrapped: Any def __init__(self, scheduler): - self.scheduler = scheduler + self.wrapped = scheduler + + def __getattr__(self, attr): + return getattr(self.wrapped, attr) def step( self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor ) -> DDIMSchedulerOutput: - result = self.scheduler.step(model_output, timestep, sample) + result = self.wrapped.step(model_output, timestep, sample) white_point = 0 black_point = 8 @@ -22,7 +25,7 @@ class SchedulerPatch: direction = "horizontal" mirrored_latents = mirror_latents( - result.prev_sample, black_point, white_point, center_line, direction + result.prev_sample, white_point, black_point, center_line, direction ) return DDIMSchedulerOutput( @@ -33,14 +36,14 @@ class SchedulerPatch: def mirror_latents( latents: np.ndarray, - black_point: int, white_point: int, + black_point: int, center_line: int, direction: Literal["horizontal", "vertical"], ) -> np.ndarray: - gradient = np.linspace(1, 0, white_point - black_point).astype(np.float32) + gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32) gradient = np.pad( - gradient, (black_point, center_line - white_point), mode="constant" + gradient, (white_point, center_line - black_point), mode="constant" ) gradient = np.reshape([gradient, np.flip(gradient)], -1) gradient = np.expand_dims(gradient, (0, 1, 2)) diff --git a/api/tests/test_diffusers/test_scheduler.py b/api/tests/test_diffusers/test_scheduler.py index 30e6d34b..0439722d 100644 --- a/api/tests/test_diffusers/test_scheduler.py +++ b/api/tests/test_diffusers/test_scheduler.py @@ -24,12 +24,12 @@ class SchedulerPatchTests(unittest.TestCase): ], ] ) - black_point = 0 - white_point = 1 + white_point = 0 + black_point = 1 center_line = 2 direction = "horizontal" mirrored_latents = mirror_latents( - latents, black_point, white_point, center_line, direction + latents, white_point, black_point, center_line, direction ) assert np.array_equal(mirrored_latents, latents) @@ -41,12 +41,12 @@ class SchedulerPatchTests(unittest.TestCase): ], ] ) - black_point = 0 - white_point = 1 + white_point = 0 + black_point = 1 center_line = 3 direction = "vertical" mirrored_latents = mirror_latents( - latents, black_point, white_point, center_line, direction + latents, white_point, black_point, center_line, direction ) assert np.array_equal( mirrored_latents,