put latent mirroring behind feature flag, fix means
This commit is contained in:
parent
bdf6f401a6
commit
2773ab0965
|
@ -652,11 +652,12 @@ def patch_pipeline(
|
|||
logger.debug("patching SD pipeline")
|
||||
|
||||
if not params.is_lpw() and not params.is_xl():
|
||||
logger.debug("patching prompt encoder")
|
||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||
|
||||
logger.debug("patching pipeline scheduler")
|
||||
original_scheduler = pipe.scheduler
|
||||
pipe.scheduler = SchedulerPatch(original_scheduler)
|
||||
pipe.scheduler = SchedulerPatch(server, original_scheduler)
|
||||
|
||||
logger.debug("patching pipeline UNet")
|
||||
original_unet = pipe.unet
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
@ -5,11 +6,17 @@ import torch
|
|||
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
||||
from torch import FloatTensor, Tensor
|
||||
|
||||
from ...server.context import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerPatch:
|
||||
server: ServerContext
|
||||
wrapped: Any
|
||||
|
||||
def __init__(self, scheduler):
|
||||
def __init__(self, server: ServerContext, scheduler):
|
||||
self.server = server
|
||||
self.wrapped = scheduler
|
||||
|
||||
def __getattr__(self, attr):
|
||||
|
@ -20,22 +27,33 @@ class SchedulerPatch:
|
|||
) -> SchedulerOutput:
|
||||
result = self.wrapped.step(model_output, timestep, sample)
|
||||
|
||||
white_point = result.prev_sample.shape[2] // 8
|
||||
black_point = result.prev_sample.shape[2] // 4
|
||||
if self.server.has_feature("mirror-latents"):
|
||||
logger.info("using experimental latent mirroring")
|
||||
|
||||
white_point = 0
|
||||
black_point = result.prev_sample.shape[2] // 8
|
||||
center_line = result.prev_sample.shape[2] // 2
|
||||
# direction = "horizontal"
|
||||
|
||||
gradient = linear_gradient(white_point, black_point, center_line)
|
||||
latents = result.prev_sample.numpy()
|
||||
latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient
|
||||
|
||||
# mirrored_latents = mirror_latents(
|
||||
# result.prev_sample.numpy(), gradient, center_line, direction
|
||||
# )
|
||||
gradiated_latents = np.multiply(latents, gradient)
|
||||
inverse_gradiated_latents = np.multiply(np.flip(latents, axis=3), gradient)
|
||||
latents += gradiated_latents + inverse_gradiated_latents
|
||||
|
||||
mask = np.ones_like(latents).astype(np.float32)
|
||||
gradiated_mask = np.multiply(mask, gradient)
|
||||
# flipping the mask would do nothing, we need to flip the gradient for this one
|
||||
inverse_gradiated_mask = np.multiply(mask, np.flip(gradient, axis=3))
|
||||
mask += gradiated_mask + inverse_gradiated_mask
|
||||
|
||||
latents = np.where(mask > 0, latents / mask, latents)
|
||||
|
||||
return SchedulerOutput(
|
||||
prev_sample=torch.from_numpy(latents),
|
||||
)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def linear_gradient(
|
||||
|
|
Loading…
Reference in New Issue