fix(api): improve handling of non-square images around tile size
This commit is contained in:
parent
95cad909fc
commit
d9dd1e4b57
|
@ -6,12 +6,7 @@ import torch
|
|||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import (
|
||||
encode_prompt,
|
||||
get_latents_from_seed,
|
||||
get_tile_latents,
|
||||
parse_prompt,
|
||||
)
|
||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt
|
||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
@ -54,23 +49,13 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
)
|
||||
|
||||
tile_size = params.tiles
|
||||
|
||||
if max(size) > tile_size:
|
||||
latent_size = Size(tile_size, tile_size)
|
||||
pipe_width = pipe_height = tile_size
|
||||
else:
|
||||
latent_size = Size(size.width, size.height)
|
||||
pipe_width = size.width
|
||||
pipe_height = size.height
|
||||
latent_size = size.min(tile_size, tile_size)
|
||||
|
||||
# generate new latents or slice existing
|
||||
if latents is None:
|
||||
# generate new latents
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
else:
|
||||
# slice existing latents
|
||||
latents = get_tile_latents(latents, dims, Size(tile_size, tile_size))
|
||||
pipe_width = pipe_height = tile_size
|
||||
latents = get_tile_latents(latents, dims, latent_size)
|
||||
|
||||
pipe_type = params.get_valid_pipeline("txt2img")
|
||||
pipe = load_pipeline(
|
||||
|
@ -87,8 +72,8 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.text2img(
|
||||
prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
|
@ -108,8 +93,8 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
|
|
|
@ -71,24 +71,15 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
outputs.append(source)
|
||||
continue
|
||||
|
||||
size = Size(*source.size)
|
||||
tile_size = params.tiles
|
||||
if max(size) > tile_size:
|
||||
latent_size = Size(tile_size, tile_size)
|
||||
pipe_width = pipe_height = tile_size
|
||||
else:
|
||||
latent_size = Size(size.width, size.height)
|
||||
pipe_width = size.width
|
||||
pipe_height = size.height
|
||||
size = Size(*source.size)
|
||||
latent_size = size.min(tile_size, tile_size)
|
||||
|
||||
# generate new latents or slice existing
|
||||
if latents is None:
|
||||
# generate new latents
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
else:
|
||||
# slice existing latents and make sure there is a complete tile
|
||||
latents = get_tile_latents(latents, dims, Size(tile_size, tile_size))
|
||||
pipe_width = pipe_height = tile_size
|
||||
latents = get_tile_latents(latents, dims, latent_size)
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
|
@ -98,8 +89,8 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
tile_mask,
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
generator=rng,
|
||||
|
@ -119,8 +110,8 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
source,
|
||||
tile_mask,
|
||||
negative_prompt=negative_prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
generator=rng,
|
||||
|
|
|
@ -81,6 +81,9 @@ class Size:
|
|||
border.top + self.height + border.bottom,
|
||||
)
|
||||
|
||||
def min(self, width: int, height: int):
|
||||
return Size(min(self.width, width), min(self.height, height))
|
||||
|
||||
def round_to_tile(self, tile=512):
|
||||
return Size(
|
||||
ceil(self.width / tile) * tile,
|
||||
|
|
Loading…
Reference in New Issue