fix gradient point order
This commit is contained in:
parent
51217eae8a
commit
40dac93e18
|
@ -6,15 +6,18 @@ from torch import FloatTensor, Tensor
|
||||||
|
|
||||||
|
|
||||||
class SchedulerPatch:
|
class SchedulerPatch:
|
||||||
scheduler: Any
|
wrapped: Any
|
||||||
|
|
||||||
def __init__(self, scheduler):
|
def __init__(self, scheduler):
|
||||||
self.scheduler = scheduler
|
self.wrapped = scheduler
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
return getattr(self.wrapped, attr)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
|
self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
|
||||||
) -> DDIMSchedulerOutput:
|
) -> DDIMSchedulerOutput:
|
||||||
result = self.scheduler.step(model_output, timestep, sample)
|
result = self.wrapped.step(model_output, timestep, sample)
|
||||||
|
|
||||||
white_point = 0
|
white_point = 0
|
||||||
black_point = 8
|
black_point = 8
|
||||||
|
@ -22,7 +25,7 @@ class SchedulerPatch:
|
||||||
direction = "horizontal"
|
direction = "horizontal"
|
||||||
|
|
||||||
mirrored_latents = mirror_latents(
|
mirrored_latents = mirror_latents(
|
||||||
result.prev_sample, black_point, white_point, center_line, direction
|
result.prev_sample, white_point, black_point, center_line, direction
|
||||||
)
|
)
|
||||||
|
|
||||||
return DDIMSchedulerOutput(
|
return DDIMSchedulerOutput(
|
||||||
|
@ -33,14 +36,14 @@ class SchedulerPatch:
|
||||||
|
|
||||||
def mirror_latents(
|
def mirror_latents(
|
||||||
latents: np.ndarray,
|
latents: np.ndarray,
|
||||||
black_point: int,
|
|
||||||
white_point: int,
|
white_point: int,
|
||||||
|
black_point: int,
|
||||||
center_line: int,
|
center_line: int,
|
||||||
direction: Literal["horizontal", "vertical"],
|
direction: Literal["horizontal", "vertical"],
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
gradient = np.linspace(1, 0, white_point - black_point).astype(np.float32)
|
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||||
gradient = np.pad(
|
gradient = np.pad(
|
||||||
gradient, (black_point, center_line - white_point), mode="constant"
|
gradient, (white_point, center_line - black_point), mode="constant"
|
||||||
)
|
)
|
||||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||||
gradient = np.expand_dims(gradient, (0, 1, 2))
|
gradient = np.expand_dims(gradient, (0, 1, 2))
|
||||||
|
|
|
@ -24,12 +24,12 @@ class SchedulerPatchTests(unittest.TestCase):
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
black_point = 0
|
white_point = 0
|
||||||
white_point = 1
|
black_point = 1
|
||||||
center_line = 2
|
center_line = 2
|
||||||
direction = "horizontal"
|
direction = "horizontal"
|
||||||
mirrored_latents = mirror_latents(
|
mirrored_latents = mirror_latents(
|
||||||
latents, black_point, white_point, center_line, direction
|
latents, white_point, black_point, center_line, direction
|
||||||
)
|
)
|
||||||
assert np.array_equal(mirrored_latents, latents)
|
assert np.array_equal(mirrored_latents, latents)
|
||||||
|
|
||||||
|
@ -41,12 +41,12 @@ class SchedulerPatchTests(unittest.TestCase):
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
black_point = 0
|
white_point = 0
|
||||||
white_point = 1
|
black_point = 1
|
||||||
center_line = 3
|
center_line = 3
|
||||||
direction = "vertical"
|
direction = "vertical"
|
||||||
mirrored_latents = mirror_latents(
|
mirrored_latents = mirror_latents(
|
||||||
latents, black_point, white_point, center_line, direction
|
latents, white_point, black_point, center_line, direction
|
||||||
)
|
)
|
||||||
assert np.array_equal(
|
assert np.array_equal(
|
||||||
mirrored_latents,
|
mirrored_latents,
|
||||||
|
|
Loading…
Reference in New Issue