put latent mirroring behind feature flag, fix means

parent bdf6f401a6
commit 2773ab0965
@@ -652,11 +652,12 @@ def patch_pipeline(
     logger.debug("patching SD pipeline")
 
     if not params.is_lpw() and not params.is_xl():
+        logger.debug("patching prompt encoder")
         pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
 
     logger.debug("patching pipeline scheduler")
     original_scheduler = pipe.scheduler
-    pipe.scheduler = SchedulerPatch(original_scheduler)
+    pipe.scheduler = SchedulerPatch(server, original_scheduler)
 
     logger.debug("patching pipeline UNet")
     original_unet = pipe.unet
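A note on the change above: patch_pipeline now hands the ServerContext to SchedulerPatch, so step() can check server-level feature flags at denoising time. A minimal sketch of that gating pattern, using a hypothetical FeatureServer stand-in for ServerContext (only has_feature() itself appears in this commit):

class FeatureServer:
    # hypothetical stand-in for ServerContext; the commit only
    # relies on has_feature()
    def __init__(self, features):
        self.features = set(features)

    def has_feature(self, name: str) -> bool:
        return name in self.features

# with the flag off, the patched step() simply returns the wrapped
# scheduler's result unchanged
server = FeatureServer(["mirror-latents"])
print(server.has_feature("mirror-latents"))  # True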
@@ -1,3 +1,4 @@
+from logging import getLogger
 from typing import Any, Literal
 
 import numpy as np
@@ -5,11 +6,17 @@ import torch
 from diffusers.schedulers.scheduling_utils import SchedulerOutput
 from torch import FloatTensor, Tensor
 
+from ...server.context import ServerContext
+
+logger = getLogger(__name__)
+
 
 class SchedulerPatch:
+    server: ServerContext
     wrapped: Any
 
-    def __init__(self, scheduler):
+    def __init__(self, server: ServerContext, scheduler):
+        self.server = server
         self.wrapped = scheduler
 
     def __getattr__(self, attr):
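A note on the class above: SchedulerPatch acts as a transparent proxy around the real scheduler. __getattr__ (its body is truncated in this hunk) only fires for attributes not found on the wrapper itself, so the patched step() takes priority and everything else falls through. A minimal sketch of that pattern, assuming a plain getattr forward (SchedulerProxy is an illustrative name):

class SchedulerProxy:
    # delegation sketch; the getattr forward is an assumption, since
    # the hunk cuts off before the __getattr__ body
    def __init__(self, server, scheduler):
        self.server = server
        self.wrapped = scheduler

    def __getattr__(self, attr):
        # only invoked for names not found on the proxy itself, so the
        # patched step() wins and everything else (set_timesteps,
        # config, ...) falls through to the real scheduler
        return getattr(self.wrapped, attr)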
@@ -20,22 +27,33 @@ class SchedulerPatch:
     ) -> SchedulerOutput:
         result = self.wrapped.step(model_output, timestep, sample)
 
-        white_point = result.prev_sample.shape[2] // 8
-        black_point = result.prev_sample.shape[2] // 4
-        center_line = result.prev_sample.shape[2] // 2
-        # direction = "horizontal"
+        if self.server.has_feature("mirror-latents"):
+            logger.info("using experimental latent mirroring")
 
-        gradient = linear_gradient(white_point, black_point, center_line)
-        latents = result.prev_sample.numpy()
-        latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient
+            white_point = 0
+            black_point = result.prev_sample.shape[2] // 8
+            center_line = result.prev_sample.shape[2] // 2
+
+            gradient = linear_gradient(white_point, black_point, center_line)
+            latents = result.prev_sample.numpy()
 
-        # mirrored_latents = mirror_latents(
-        #     result.prev_sample.numpy(), gradient, center_line, direction
-        # )
+            gradiated_latents = np.multiply(latents, gradient)
+            inverse_gradiated_latents = np.multiply(np.flip(latents, axis=3), gradient)
+            latents += gradiated_latents + inverse_gradiated_latents
+
+            mask = np.ones_like(latents).astype(np.float32)
+            gradiated_mask = np.multiply(mask, gradient)
+            # flipping the mask would do nothing, we need to flip the gradient for this one
+            inverse_gradiated_mask = np.multiply(mask, np.flip(gradient, axis=3))
+            mask += gradiated_mask + inverse_gradiated_mask
+
+            latents = np.where(mask > 0, latents / mask, latents)
 
-        return SchedulerOutput(
-            prev_sample=torch.from_numpy(latents),
-        )
+            return SchedulerOutput(
+                prev_sample=torch.from_numpy(latents),
+            )
+        else:
+            return result
 
 
 def linear_gradient(
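The "fix means" half of the commit message refers to the arithmetic in the hunk above: the old code added np.mean([latents, flipped], axis=0) * gradient on top of the latents, inflating values wherever the gradient was non-zero. The new code instead accumulates gradient-weighted copies of the latents and their mirror, runs a mask of ones through the same operations, and divides by it, so each position ends up as a properly normalized weighted mean. A standalone sketch of that normalization (mirror_blend is illustrative only; the gradient is assumed to be a 4-D array broadcastable against NCHW latents, since linear_gradient's body is cut off in this diff):

import numpy as np

def mirror_blend(latents: np.ndarray, gradient: np.ndarray) -> np.ndarray:
    # same sequence as the patched step(): blend the latents with
    # their flip along the width axis, weighted by the gradient
    blended = (
        latents
        + np.multiply(latents, gradient)
        + np.multiply(np.flip(latents, axis=3), gradient)
    )
    # push a mask of ones through the same weighting, flipping the
    # gradient rather than the all-ones mask (as the commit notes,
    # flipping the mask itself would do nothing)
    weights = np.ones_like(latents, dtype=np.float32)
    weights = (
        weights
        + np.multiply(weights, gradient)
        + np.multiply(weights, np.flip(gradient, axis=3))
    )
    # dividing by the accumulated weight turns the sum back into a
    # weighted mean, which is the "fix means" in the commit message
    return np.where(weights > 0, blended / weights, blended)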