1
0
Fork 0

fix(api): remove size limit on inpainting stage

This commit is contained in:
Sean Sube 2023-04-29 14:23:00 -05:00
parent 0666d81b66
commit f782f39cce
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 28 additions and 6 deletions

View File

@ -56,7 +56,7 @@ def upscale_outpaint(
draw_mask = ImageDraw.Draw(stage_mask) draw_mask = ImageDraw.Draw(stage_mask)
full_size = Size(*full_dims) full_size = Size(*full_dims)
full_latents = get_latents_from_seed(params.seed, full_size) full_latents = get_latents_from_seed(params.seed, full_size.latent_size())
if is_debug(): if is_debug():
save_image(server, "last-source.png", source) save_image(server, "last-source.png", source)

View File

@ -21,6 +21,19 @@ class TileCallback(Protocol):
pass pass
def complete_tile(
source: Image.Image,
tile: int,
) -> Image.Image:
if source.width < tile or source.height < tile:
full_source = Image.new(source.mode, (tile, tile))
full_source.paste(source)
return full_source
return source
def process_tile_grid( def process_tile_grid(
source: Image.Image, source: Image.Image,
tile: int, tile: int,
@ -29,10 +42,10 @@ def process_tile_grid(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
width, height = source.size width, height = source.size
image = Image.new("RGB", (width * scale, height * scale)) image = Image.new(source.mode, (width * scale, height * scale))
tiles_x = width // tile tiles_x = ceil(width / tile)
tiles_y = height // tile tiles_y = ceil(height / tile)
total = tiles_x * tiles_y total = tiles_x * tiles_y
for y in range(tiles_y): for y in range(tiles_y):
@ -41,7 +54,9 @@ def process_tile_grid(
left = x * tile left = x * tile
top = y * tile top = y * tile
logger.debug("processing tile %s of %s, %s.%s", idx + 1, total, y, x) logger.debug("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
tile_image = source.crop((left, top, left + tile, top + tile)) tile_image = source.crop((left, top, left + tile, top + tile))
tile_image = complete_tile(tile_image)
for filter in filters: for filter in filters:
tile_image = filter(tile_image, (left, top, tile)) tile_image = filter(tile_image, (left, top, tile))
@ -74,6 +89,7 @@ def process_tile_spiral(
logger.debug("processing tile %s of %s, %sx%s", counter, len(tiles), left, top) logger.debug("processing tile %s of %s, %sx%s", counter, len(tiles), left, top)
tile_image = image.crop((left, top, left + tile, top + tile)) tile_image = image.crop((left, top, left + tile, top + tile))
tile_image = complete_tile(tile_image)
for filter in filters: for filter in filters:
tile_image = filter(tile_image, (left, top, tile)) tile_image = filter(tile_image, (left, top, tile))

View File

@ -22,7 +22,6 @@ from diffusers import (
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
StableDiffusionControlNetPipeline, StableDiffusionControlNetPipeline,
StableDiffusionInstructPix2PixPipeline, StableDiffusionInstructPix2PixPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
) )
@ -42,7 +41,7 @@ available_pipelines = {
"img2img": StableDiffusionPipeline, "img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline, "inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline, "lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPanoramaPipeline, "panorama": StableDiffusionPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline, "pix2pix": StableDiffusionInstructPix2PixPipeline,
"txt2img": StableDiffusionPipeline, "txt2img": StableDiffusionPipeline,
"upscale": StableDiffusionUpscalePipeline, "upscale": StableDiffusionUpscalePipeline,

View File

@ -1,5 +1,6 @@
from enum import IntEnum from enum import IntEnum
from logging import getLogger from logging import getLogger
from math import ceil
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from .models.meta import NetworkModel from .models.meta import NetworkModel
@ -78,6 +79,12 @@ class Size:
border.top + self.height + border.bottom, border.top + self.height + border.bottom,
) )
def latent_size(self):
return Size(
ceil(self.width / 8),
ceil(self.height / 8),
)
def tojson(self) -> Dict[str, int]: def tojson(self) -> Dict[str, int]:
return { return {
"height": self.height, "height": self.height,