diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 0e5acfa1..51667b5a 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -105,6 +105,7 @@ class ChainPipeline: len(sources), ) else: + sources = [None] logger.info("running pipeline without source images") stage_sources = sources @@ -112,20 +113,11 @@ class ChainPipeline: name = stage_params.name or stage_pipe.__class__.__name__ kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} - - if len(stage_sources) > 0: - logger.debug( - "running stage %s with %s source images, parameters: %s", - name, - len(stage_sources), - kwargs.keys(), - ) - else: - logger.debug( - "running stage %s without source images, parameters: %s", - name, - kwargs.keys(), - ) + logger.debug( + "running stage %s, parameters: %s", + name, + kwargs.keys(), + ) # the stage must be split and tiled if any image is larger than the selected/max tile size must_tile = any( diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 3042a114..18058ad9 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -32,6 +32,7 @@ class SourceTxt2ImgStage(BaseStage): ) -> Image.Image: params = params.with_args(**kwargs) size = size.with_args(**kwargs) + logger.info( "generating image using txt2img, %s steps: %s", params.steps, params.prompt ) @@ -45,7 +46,17 @@ class SourceTxt2ImgStage(BaseStage): params ) - latents = get_latents_from_seed(params.seed, size, params.batch) + tile_size = params.tiles + if max(size) > tile_size: + latent_size = Size(tile_size, tile_size) + latents = get_latents_from_seed(params.seed, latent_size, params.batch) + pipe_width = pipe_height = tile_size + else: + latent_size = Size(size.width, size.height) + latents = get_latents_from_seed(params.seed, latent_size, params.batch) + pipe_width = size.width + pipe_height = size.height + pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( 
server, @@ -61,8 +72,8 @@ class SourceTxt2ImgStage(BaseStage): rng = torch.manual_seed(params.seed) result = pipe.text2img( prompt, - height=size.height, - width=size.width, + height=pipe_height, + width=pipe_width, generator=rng, guidance_scale=params.cfg, latents=latents, @@ -82,8 +93,8 @@ class SourceTxt2ImgStage(BaseStage): rng = np.random.RandomState(params.seed) result = pipe( prompt, - height=size.height, - width=size.width, + height=pipe_height, + width=pipe_width, generator=rng, guidance_scale=params.cfg, latents=latents, diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 4714f81e..5e3aa2c7 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -233,13 +233,17 @@ def process_tile_spiral( ) -> Image.Image: width, height = kwargs.get("size", source.size if source else None) - # spiral uses the previous run and needs a scratch texture for 3x memory - tiles: List[Tuple[int, int, Image.Image]] = [] # tile tuples is source, multiply by scale for dest counter = 0 tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap) + + if len(tile_coords) == 1: + single_tile = True + else: + single_tile = False + for left, top in tile_coords: counter += 1 logger.info( @@ -266,21 +270,29 @@ def process_tile_spiral( bottom_margin = height - bottom if needs_margin: - base_image = ( - source.crop( - ( - left + left_margin, - top + top_margin, - right - right_margin, - bottom - bottom_margin, + # in the special case where the image is smaller than the specified tile size, just use the image + if single_tile: + logger.debug("creating and processing single-tile subtile") + tile_image = source + # otherwise add histogram noise outside of the image border + else: + logger.debug("tiling and adding margin") + base_image = ( + source.crop( + ( + left + left_margin, + top + top_margin, + right - right_margin, + bottom - bottom_margin, + ) ) + if source + else None ) - if source - else None - ) - tile_image = 
noise_source_histogram(base_image, (tile, tile), (0, 0)) - tile_image.paste(base_image, (left_margin, top_margin)) + tile_image = noise_source_histogram(base_image, (tile, tile), (0, 0)) + tile_image.paste(base_image, (left_margin, top_margin)) else: + logger.debug("tiling normally") tile_image = source.crop((left, top, right, bottom)) if source else None for image_filter in filters: @@ -288,7 +300,10 @@ def process_tile_spiral( tiles.append((left, top, tile_image)) - return blend_tiles(tiles, scale, width, height, tile, overlap) + if single_tile: + return tile_image + else: + return blend_tiles(tiles, scale, width, height, tile, overlap) def process_tile_order( @@ -326,17 +341,21 @@ def generate_tile_spiral( ) # dividing and then multiplying by 2 ensures this will be an even number, which is necessary for the initial tile placement calculation # calculate the number of tiles needed - width_tile_target = 1 - height_tile_target = 1 if width > tile: width_tile_target = 1 + ceil((width - tile) / tile_increment) + else: + width_tile_target = 1 if height > tile: height_tile_target = 1 + ceil((height - tile) / tile_increment) + else: + height_tile_target = 1 # calculate the start position of the tiling span_x = tile + (width_tile_target - 1) * tile_increment span_y = tile + (height_tile_target - 1) * tile_increment + logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y) + tile_left = ( width - span_x ) / 2 # guaranteed to be an integer because width and span will both be even