1
0
Fork 0

fix gradient point order

This commit is contained in:
Sean Sube 2024-01-21 21:45:51 -06:00
parent 51217eae8a
commit 40dac93e18
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 13 deletions

View File

@ -6,15 +6,18 @@ from torch import FloatTensor, Tensor
class SchedulerPatch: class SchedulerPatch:
scheduler: Any wrapped: Any
def __init__(self, scheduler): def __init__(self, scheduler):
self.scheduler = scheduler self.wrapped = scheduler
def __getattr__(self, attr):
return getattr(self.wrapped, attr)
def step( def step(
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
) -> DDIMSchedulerOutput: ) -> DDIMSchedulerOutput:
result = self.scheduler.step(model_output, timestep, sample) result = self.wrapped.step(model_output, timestep, sample)
white_point = 0 white_point = 0
black_point = 8 black_point = 8
@ -22,7 +25,7 @@ class SchedulerPatch:
direction = "horizontal" direction = "horizontal"
mirrored_latents = mirror_latents( mirrored_latents = mirror_latents(
result.prev_sample, black_point, white_point, center_line, direction result.prev_sample, white_point, black_point, center_line, direction
) )
return DDIMSchedulerOutput( return DDIMSchedulerOutput(
@ -33,14 +36,14 @@ class SchedulerPatch:
def mirror_latents( def mirror_latents(
latents: np.ndarray, latents: np.ndarray,
black_point: int,
white_point: int, white_point: int,
black_point: int,
center_line: int, center_line: int,
direction: Literal["horizontal", "vertical"], direction: Literal["horizontal", "vertical"],
) -> np.ndarray: ) -> np.ndarray:
gradient = np.linspace(1, 0, white_point - black_point).astype(np.float32) gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
gradient = np.pad( gradient = np.pad(
gradient, (black_point, center_line - white_point), mode="constant" gradient, (white_point, center_line - black_point), mode="constant"
) )
gradient = np.reshape([gradient, np.flip(gradient)], -1) gradient = np.reshape([gradient, np.flip(gradient)], -1)
gradient = np.expand_dims(gradient, (0, 1, 2)) gradient = np.expand_dims(gradient, (0, 1, 2))

View File

@ -24,12 +24,12 @@ class SchedulerPatchTests(unittest.TestCase):
], ],
] ]
) )
black_point = 0 white_point = 0
white_point = 1 black_point = 1
center_line = 2 center_line = 2
direction = "horizontal" direction = "horizontal"
mirrored_latents = mirror_latents( mirrored_latents = mirror_latents(
latents, black_point, white_point, center_line, direction latents, white_point, black_point, center_line, direction
) )
assert np.array_equal(mirrored_latents, latents) assert np.array_equal(mirrored_latents, latents)
@ -41,12 +41,12 @@ class SchedulerPatchTests(unittest.TestCase):
], ],
] ]
) )
black_point = 0 white_point = 0
white_point = 1 black_point = 1
center_line = 3 center_line = 3
direction = "vertical" direction = "vertical"
mirrored_latents = mirror_latents( mirrored_latents = mirror_latents(
latents, black_point, white_point, center_line, direction latents, white_point, black_point, center_line, direction
) )
assert np.array_equal( assert np.array_equal(
mirrored_latents, mirrored_latents,