
put latent mirroring behind feature flag, fix means

Sean Sube 2024-01-27 09:25:00 -06:00
parent bdf6f401a6
commit 2773ab0965
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 34 additions and 15 deletions

View File

@@ -652,11 +652,12 @@ def patch_pipeline(
     logger.debug("patching SD pipeline")
     if not params.is_lpw() and not params.is_xl():
         logger.debug("patching prompt encoder")
         pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
     logger.debug("patching pipeline scheduler")
     original_scheduler = pipe.scheduler
-    pipe.scheduler = SchedulerPatch(original_scheduler)
+    pipe.scheduler = SchedulerPatch(server, original_scheduler)
     logger.debug("patching pipeline UNet")
     original_unet = pipe.unet
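
The constructor change above threads the server context into the scheduler patch so the mirroring code in the second file can be gated on a named feature flag (server.has_feature("mirror-latents")). The diff does not show how has_feature is implemented; a hypothetical minimal stand-in, purely for illustration, could keep a set of enabled feature names:

class FeatureFlags:
    # hypothetical stand-in for the server context's feature checks,
    # not the project's ServerContext implementation
    def __init__(self, features):
        self.features = set(features)

    def has_feature(self, name: str) -> bool:
        return name in self.features


server = FeatureFlags({"mirror-latents"})
if server.has_feature("mirror-latents"):
    print("experimental latent mirroring enabled")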

View File

@@ -1,3 +1,4 @@
+from logging import getLogger
 from typing import Any, Literal
 import numpy as np
@@ -5,11 +6,17 @@ import torch
 from diffusers.schedulers.scheduling_utils import SchedulerOutput
 from torch import FloatTensor, Tensor
+from ...server.context import ServerContext
+logger = getLogger(__name__)
 class SchedulerPatch:
+    server: ServerContext
     wrapped: Any
-    def __init__(self, scheduler):
+    def __init__(self, server: ServerContext, scheduler):
+        self.server = server
         self.wrapped = scheduler
     def __getattr__(self, attr):
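
SchedulerPatch acts as a transparent proxy around the original scheduler: __getattr__ is only invoked for attributes the patch itself does not define, so every scheduler property and method is forwarded unchanged while the overridden step intercepts denoising. A minimal sketch of that delegation pattern, using hypothetical stand-in classes rather than the project's code:

from typing import Any


class DelegatingWrapper:
    # forward any attribute the wrapper does not define to the wrapped object
    def __init__(self, wrapped: Any):
        self.wrapped = wrapped

    def __getattr__(self, attr: str) -> Any:
        # only called when normal lookup fails, so overrides still win
        return getattr(self.wrapped, attr)

    def step(self, *args: Any, **kwargs: Any) -> Any:
        result = self.wrapped.step(*args, **kwargs)
        # post-process the wrapped result here
        return result


class FakeScheduler:
    order = 1

    def step(self, sample: Any) -> Any:
        return sample


proxy = DelegatingWrapper(FakeScheduler())
assert proxy.order == 1                # forwarded via __getattr__
assert proxy.step("noise") == "noise"  # goes through the override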
@@ -20,22 +27,33 @@ class SchedulerPatch:
     ) -> SchedulerOutput:
         result = self.wrapped.step(model_output, timestep, sample)
-        white_point = result.prev_sample.shape[2] // 8
-        black_point = result.prev_sample.shape[2] // 4
-        center_line = result.prev_sample.shape[2] // 2
-        # direction = "horizontal"
+        if self.server.has_feature("mirror-latents"):
+            logger.info("using experimental latent mirroring")
-        gradient = linear_gradient(white_point, black_point, center_line)
-        latents = result.prev_sample.numpy()
-        latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient
+            white_point = 0
+            black_point = result.prev_sample.shape[2] // 8
+            center_line = result.prev_sample.shape[2] // 2
-        # mirrored_latents = mirror_latents(
-        #     result.prev_sample.numpy(), gradient, center_line, direction
-        # )
+            gradient = linear_gradient(white_point, black_point, center_line)
+            latents = result.prev_sample.numpy()
-        return SchedulerOutput(
-            prev_sample=torch.from_numpy(latents),
-        )
+            gradiated_latents = np.multiply(latents, gradient)
+            inverse_gradiated_latents = np.multiply(np.flip(latents, axis=3), gradient)
+            latents += gradiated_latents + inverse_gradiated_latents
+            mask = np.ones_like(latents).astype(np.float32)
+            gradiated_mask = np.multiply(mask, gradient)
+            # flipping the mask would do nothing, we need to flip the gradient for this one
+            inverse_gradiated_mask = np.multiply(mask, np.flip(gradient, axis=3))
+            mask += gradiated_mask + inverse_gradiated_mask
+            latents = np.where(mask > 0, latents / mask, latents)
+            return SchedulerOutput(
+                prev_sample=torch.from_numpy(latents),
+            )
+        else:
+            return result
 def linear_gradient(
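
Read as a whole, the new branch blends each latent with its mirror image, weighted by a linear gradient, then renormalizes with a mask built from the same weights so the overall magnitude stays stable. The sketch below is self-contained and uses a simplified stand-in for linear_gradient, since the real function body is cut off at the end of this diff; the shapes and parameter choices are illustrative assumptions:

import numpy as np


def linear_gradient_sketch(white_point: int, black_point: int, size: int) -> np.ndarray:
    # ramp from 0 at white_point up to 1 at black_point, broadcastable over NCHW latents
    ramp = np.clip(
        (np.arange(size) - white_point) / max(black_point - white_point, 1),
        0.0,
        1.0,
    ).astype(np.float32)
    return ramp.reshape(1, 1, 1, size)


# fake latents: batch=1, channels=4, 8x8 spatial grid
latents = np.random.randn(1, 4, 8, 8).astype(np.float32)
gradient = linear_gradient_sketch(white_point=0, black_point=2, size=8)

# blend each latent with its horizontal mirror, weighted by the gradient
gradiated_latents = np.multiply(latents, gradient)
inverse_gradiated_latents = np.multiply(np.flip(latents, axis=3), gradient)
latents = latents + gradiated_latents + inverse_gradiated_latents

# accumulate the same weights over a field of ones; flipping the all-ones
# mask would change nothing, so the gradient is flipped instead
mask = np.ones_like(latents)
mask = mask + np.multiply(mask, gradient) + np.multiply(mask, np.flip(gradient, axis=3))

# renormalize so regions that received extra contributions are scaled back down
latents = np.where(mask > 0, latents / mask, latents)
print(latents.shape)  # (1, 4, 8, 8)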