fix gradient point order
This commit is contained in:
parent
51217eae8a
commit
40dac93e18
|
@ -6,15 +6,18 @@ from torch import FloatTensor, Tensor
|
|||
|
||||
|
||||
class SchedulerPatch:
|
||||
scheduler: Any
|
||||
wrapped: Any
|
||||
|
||||
def __init__(self, scheduler):
|
||||
self.scheduler = scheduler
|
||||
self.wrapped = scheduler
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
||||
def step(
|
||||
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
|
||||
) -> DDIMSchedulerOutput:
|
||||
result = self.scheduler.step(model_output, timestep, sample)
|
||||
result = self.wrapped.step(model_output, timestep, sample)
|
||||
|
||||
white_point = 0
|
||||
black_point = 8
|
||||
|
@ -22,7 +25,7 @@ class SchedulerPatch:
|
|||
direction = "horizontal"
|
||||
|
||||
mirrored_latents = mirror_latents(
|
||||
result.prev_sample, black_point, white_point, center_line, direction
|
||||
result.prev_sample, white_point, black_point, center_line, direction
|
||||
)
|
||||
|
||||
return DDIMSchedulerOutput(
|
||||
|
@ -33,14 +36,14 @@ class SchedulerPatch:
|
|||
|
||||
def mirror_latents(
|
||||
latents: np.ndarray,
|
||||
black_point: int,
|
||||
white_point: int,
|
||||
black_point: int,
|
||||
center_line: int,
|
||||
direction: Literal["horizontal", "vertical"],
|
||||
) -> np.ndarray:
|
||||
gradient = np.linspace(1, 0, white_point - black_point).astype(np.float32)
|
||||
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||
gradient = np.pad(
|
||||
gradient, (black_point, center_line - white_point), mode="constant"
|
||||
gradient, (white_point, center_line - black_point), mode="constant"
|
||||
)
|
||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||
gradient = np.expand_dims(gradient, (0, 1, 2))
|
||||
|
|
|
@ -24,12 +24,12 @@ class SchedulerPatchTests(unittest.TestCase):
|
|||
],
|
||||
]
|
||||
)
|
||||
black_point = 0
|
||||
white_point = 1
|
||||
white_point = 0
|
||||
black_point = 1
|
||||
center_line = 2
|
||||
direction = "horizontal"
|
||||
mirrored_latents = mirror_latents(
|
||||
latents, black_point, white_point, center_line, direction
|
||||
latents, white_point, black_point, center_line, direction
|
||||
)
|
||||
assert np.array_equal(mirrored_latents, latents)
|
||||
|
||||
|
@ -41,12 +41,12 @@ class SchedulerPatchTests(unittest.TestCase):
|
|||
],
|
||||
]
|
||||
)
|
||||
black_point = 0
|
||||
white_point = 1
|
||||
white_point = 0
|
||||
black_point = 1
|
||||
center_line = 3
|
||||
direction = "vertical"
|
||||
mirrored_latents = mirror_latents(
|
||||
latents, black_point, white_point, center_line, direction
|
||||
latents, white_point, black_point, center_line, direction
|
||||
)
|
||||
assert np.array_equal(
|
||||
mirrored_latents,
|
||||
|
|
Loading…
Reference in New Issue