
add experimental patch with latent mirroring

Sean Sube 2024-01-21 21:34:58 -06:00
parent ff11d75784
commit dcaadf1a31
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 34 additions and 254 deletions
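
The latent-mirroring behavior named in the commit message lives in the new .patches.scheduler module, which is imported below but not included in this diff. As a hedged illustration of the general idea only, blending a batch of NCHW latents with their own reflection could look roughly like the following sketch (the function name, mirror axis, and blend weight are assumptions, not taken from this commit):

import numpy as np


def mirror_latents(latents: np.ndarray, blend: float = 0.1) -> np.ndarray:
    # illustrative only: reflect each latent along its width axis and blend it back in
    mirrored = np.flip(latents, axis=-1)
    return (1.0 - blend) * latents + blend * mirrored


# example: one 4-channel 64x64 latent, softly symmetrized
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
blended = mirror_latents(latents, blend=0.25)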

View File

@@ -18,9 +18,11 @@ from ..params import DeviceParams, ImageParams
 from ..server import ModelTypes, ServerContext
 from ..torch_before_ort import InferenceSession
 from ..utils import run_gc
+from .patches.scheduler import SchedulerPatch
 from .patches.unet import UNetWrapper
 from .patches.vae import VAEWrapper
 from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
+from .pipelines.highres import OnnxStableDiffusionHighresPipeline
 from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
 from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
 from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
@@ -53,6 +55,7 @@ logger = getLogger(__name__)
 available_pipelines = {
     "controlnet": OnnxStableDiffusionControlNetPipeline,
+    "highres": OnnxStableDiffusionHighresPipeline,
     "img2img": OnnxStableDiffusionImg2ImgPipeline,
     "img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
     "inpaint": OnnxStableDiffusionInpaintPipeline,
@@ -651,9 +654,13 @@ def patch_pipeline(
     if not params.is_lpw() and not params.is_xl():
         pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)

+    logger.debug("patching pipeline scheduler")
+    original_scheduler = pipe.scheduler
+    pipe.scheduler = SchedulerPatch(original_scheduler)
+
+    logger.debug("patching pipeline UNet")
     original_unet = pipe.unet
     pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
-    logger.debug("patched UNet with wrapper")

     if hasattr(pipe, "vae_decoder"):
         original_decoder = pipe.vae_decoder

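The patch_pipeline hunk above swaps pipe.scheduler for SchedulerPatch(original_scheduler), the same way the UNet and VAE are already wrapped. The wrapper classes themselves are defined under .patches and are not shown in this diff; a minimal sketch of the delegation shape such a wrapper usually takes (the method bodies here are assumptions, not the actual onnx-web implementation) would be:

class SchedulerPatch:
    # hedged sketch, not the actual onnx-web implementation
    def __init__(self, scheduler):
        self.wrapped = scheduler

    def __getattr__(self, attr):
        # set_timesteps, timesteps, config, etc. fall through to the original scheduler
        return getattr(self.wrapped, attr)

    def step(self, model_output, timestep, sample, **kwargs):
        result = self.wrapped.step(model_output, timestep, sample, **kwargs)
        # post-process result.prev_sample here, e.g. with the mirroring sketched earlier
        return result

Because only step is intercepted, the rest of the pipeline keeps talking to the original scheduler unchanged.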
View File

@@ -13,18 +13,17 @@
 # limitations under the License.
 import inspect
+from logging import getLogger
 from math import ceil
 from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import PIL
 import torch
-from diffusers.configuration_utils import FrozenDict
 from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+from diffusers.utils import PIL_INTERPOLATION, deprecate
 from transformers import CLIPImageProcessor, CLIPTokenizer

 from ...chain.tile import make_tile_mask
@@ -37,8 +36,9 @@ from ..utils import (
     repair_nan,
     resize_latent_shape,
 )
+from .base import OnnxStableDiffusionBasePipeline

-logger = logging.get_logger(__name__)
+logger = getLogger(__name__)

 # inpaint constants
@@ -96,18 +96,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape):
     return mask, masked_image


-class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
-    vae_encoder: OnnxRuntimeModel
-    vae_decoder: OnnxRuntimeModel
-    text_encoder: OnnxRuntimeModel
-    tokenizer: CLIPTokenizer
-    unet: OnnxRuntimeModel
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
-    safety_checker: OnnxRuntimeModel
-    feature_extractor: CLIPImageProcessor
-
-    _optional_components = ["safety_checker", "feature_extractor"]
-
+class OnnxStableDiffusionPanoramaPipeline(OnnxStableDiffusionBasePipeline):
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -122,65 +111,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
         window: Optional[int] = None,
         stride: Optional[int] = None,
     ):
-        super().__init__()
-        self.window = window or DEFAULT_WINDOW
-        self.stride = stride or DEFAULT_STRIDE
-
-        if (
-            hasattr(scheduler.config, "steps_offset")
-            and scheduler.config.steps_offset != 1
-        ):
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
-                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
-                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
-                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                " file"
-            )
-            deprecate(
-                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
-            )
-            new_config = dict(scheduler.config)
-            new_config["steps_offset"] = 1
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if (
-            hasattr(scheduler.config, "clip_sample")
-            and scheduler.config.clip_sample is True
-        ):
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
-                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
-                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
-                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
-                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
-            )
-            deprecate(
-                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
-            )
-            new_config = dict(scheduler.config)
-            new_config["clip_sample"] = False
-            scheduler._internal_dict = FrozenDict(new_config)
-
-        if safety_checker is None and requires_safety_checker:
-            logger.warning(
-                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-            )
-
-        if safety_checker is not None and feature_extractor is None:
-            raise ValueError(
-                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
-                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
-            )
-
-        self.register_modules(
+        super().__init__(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
             text_encoder=text_encoder,
@@ -189,173 +120,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            requires_safety_checker=requires_safety_checker,
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
-
-    def _encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        num_images_per_prompt: Optional[int],
-        do_classifier_free_guidance: bool,
-        negative_prompt: Optional[str],
-        prompt_embeds: Optional[np.ndarray] = None,
-        negative_prompt_embeds: Optional[np.ndarray] = None,
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`):
-                prompt to be encoded
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
-            prompt_embeds (`np.ndarray`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`np.ndarray`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-        """
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            # get prompt text embeddings
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="np",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(
-                prompt, padding="max_length", return_tensors="np"
-            ).input_ids
-
-            if not np.array_equal(text_input_ids, untruncated_ids):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
-                )
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-
-            prompt_embeds = self.text_encoder(
-                input_ids=text_input_ids.astype(np.int32)
-            )[0]
-
-        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="np",
-            )
-            negative_prompt_embeds = self.text_encoder(
-                input_ids=uncond_input.input_ids.astype(np.int32)
-            )[0]
-
-        if do_classifier_free_guidance:
-            negative_prompt_embeds = np.repeat(
-                negative_prompt_embeds, num_images_per_prompt, axis=0
-            )
-
-            # For classifier free guidance, we need to do two forward passes.
-            # Here we concatenate the unconditional and text embeddings into a single batch
-            # to avoid doing two forward passes
-            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
-
-    def check_inputs(
-        self,
-        prompt: Union[str, List[str]],
-        height: Optional[int],
-        width: Optional[int],
-        callback_steps: int,
-        negative_prompt: Optional[str] = None,
-        prompt_embeds: Optional[np.ndarray] = None,
-        negative_prompt_embeds: Optional[np.ndarray] = None,
-    ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
-            )
-
-        if (callback_steps is None) or (
-            callback_steps is not None
-            and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prompt_embeds is None:
-            raise ValueError(
-                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-            )
-        elif prompt is not None and (
-            not isinstance(prompt, str) and not isinstance(prompt, list)
-        ):
-            raise ValueError(
-                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
-            )
-
-        if negative_prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if prompt_embeds is not None and negative_prompt_embeds is not None:
-            if prompt_embeds.shape != negative_prompt_embeds.shape:
-                raise ValueError(
-                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                    f" {negative_prompt_embeds.shape}."
-                )
+        self.window = window or DEFAULT_WINDOW
+        self.stride = stride or DEFAULT_STRIDE

     def get_views(
         self, panorama_height: int, panorama_width: int, window_size: int, stride: int
@@ -993,7 +762,6 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
         )
         for i, t in enumerate(self.progress_bar(timesteps)):
-            last = i == (len(timesteps) - 1)
             count.fill(0)
             value.fill(0)

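The panorama hunks above change OnnxStableDiffusionPanoramaPipeline to inherit from OnnxStableDiffusionBasePipeline and pass its modules to super().__init__, so the deleted scheduler-config fixups, safety-checker warnings, register_modules call, and the copied _encode_prompt/check_inputs helpers can presumably live once in the shared base. The .base module is not included in this diff; an assumed outline of the constructor it would need, matching the keyword arguments used above, is:

from diffusers.pipelines.pipeline_utils import DiffusionPipeline


class OnnxStableDiffusionBasePipeline(DiffusionPipeline):
    # assumed outline only; the real base class is defined in .base, outside this diff
    def __init__(
        self,
        vae_encoder,
        vae_decoder,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
        # shared home for the scheduler steps_offset/clip_sample fixups and
        # safety-checker warnings removed from panorama.py
        self.register_modules(
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)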
View File

@@ -835,7 +835,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         for i, t in enumerate(self.progress_bar(timesteps)):
-            last = i == (len(timesteps) - 1)
             count.fill(0)
             value.fill(0)

View File

@@ -100,8 +100,8 @@ export function HorizontalBody(props: BodyProps) {
     separator
     snap
   >
-    <TabGroup direction={props.direction} />
-    <Box className='box-history' sx={layout.history.style}>
+    <TabGroup direction={props.direction} panelClass='scroll-controls' />
+    <Box className='scroll-history' sx={layout.history.style}>
       <ImageHistory width={props.width} />
     </Box>
   </Allotment>;
@@ -113,7 +113,7 @@ export function VerticalBody(props: BodyProps) {
   return <Stack direction={layout.direction} spacing={STANDARD_SPACING}>
     <TabGroup direction={props.direction} />
     <Divider flexItem variant='middle' orientation={layout.divider} />
-    <Box className='box-history' sx={layout.history.style}>
+    <Box sx={layout.history.style}>
       <ImageHistory width={props.width} />
     </Box>
   </Stack>;
@@ -121,6 +121,7 @@ export function VerticalBody(props: BodyProps) {
 export interface TabGroupProps {
   direction: Layout;
+  panelClass?: string;
 }

 export function TabGroup(props: TabGroupProps) {
@@ -138,25 +139,25 @@ export function TabGroup(props: TabGroupProps) {
         {TAB_LABELS.map((name) => <Tab key={name} label={t(`tab.${name}`)} value={name} />)}
       </TabList>
     </Box>
-    <TabPanel value='txt2img'>
+    <TabPanel className={props.panelClass} value='txt2img'>
       <Txt2Img />
     </TabPanel>
-    <TabPanel value='img2img'>
+    <TabPanel className={props.panelClass} value='img2img'>
      <Img2Img />
    </TabPanel>
-    <TabPanel value='inpaint'>
+    <TabPanel className={props.panelClass} value='inpaint'>
      <Inpaint />
    </TabPanel>
-    <TabPanel value='upscale'>
+    <TabPanel className={props.panelClass} value='upscale'>
      <Upscale />
    </TabPanel>
-    <TabPanel value='blend'>
+    <TabPanel className={props.panelClass} value='blend'>
      <Blend />
    </TabPanel>
-    <TabPanel value='models'>
+    <TabPanel className={props.panelClass} value='models'>
      <Models />
    </TabPanel>
-    <TabPanel value='settings'>
+    <TabPanel className={props.panelClass} value='settings'>
      <Settings />
    </TabPanel>
   </TabContext>

View File

@@ -2,7 +2,12 @@
   height: 90vb;
 }

-.box-history {
+.scroll-history {
   max-height: 90vh;
   overflow-y: auto;
 }
+
+.scroll-controls {
+  max-height: 85vh;
+  overflow-y: auto;
+}