diff --git a/README.md b/README.md index cc386468..ed10a950 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,14 @@ with a CPU fallback capable of running on laptop-class machines. Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md). -![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png) +![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png) ## Features This is an incomplete list of new and interesting features, with links to the user guide: +- SDXL support +- LCM support - hardware acceleration on both AMD and Nvidia - tested on CUDA, DirectML, and ROCm - [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both @@ -37,6 +39,7 @@ This is an incomplete list of new and interesting features, with links to the us - [txt2img](docs/user-guide.md#txt2img-tab) - [img2img](docs/user-guide.md#img2img-tab) - [inpainting](docs/user-guide.md#inpaint-tab), with mask drawing and upload + - [panorama](docs/user-guide.md#panorama-pipeline), for both SD v1.5 and SDXL - [upscaling](docs/user-guide.md#upscale-tab), with ONNX acceleration - [add and use your own models](docs/user-guide.md#adding-your-own-models) - [convert models from diffusers and SD checkpoints](docs/converting-models.md) @@ -45,20 +48,24 @@ This is an incomplete list of new and interesting features, with links to the us - [permanent and prompt-based blending](docs/user-guide.md#permanently-blending-additional-networks) - [supports LoRA and LyCORIS weights](docs/user-guide.md#lora-tokens) - [supports Textual Inversion concepts and embeddings](docs/user-guide.md#textual-inversion-tokens) + - each layer of the embeddings can be controlled and used individually - ControlNet - image filters for edge detection and other methods - with ONNX acceleration - highres mode - runs img2img on the results of the other pipelines - multiple iterations can produce 8k images and larger +- [multi-stage](docs/user-guide.md#prompt-stages) and [region prompts](docs/user-guide.md#region-tokens) + - seamlessly combine multiple prompts in the same image + - provide prompts for different areas in the image and blend them together + - change the prompt for highres mode and refine details without recursion - infinite prompt length - [with long prompt weighting](docs/user-guide.md#long-prompt-weighting) - - expand and control Textual Inversions per-layer - [image blending mode](docs/user-guide.md#blend-tab) - combine images from history -- upscaling and face correction - - upscaling with Real ESRGAN or Stable Diffusion - - face correction with CodeFormer or GFPGAN +- upscaling and correction + - upscaling with Real ESRGAN, SwinIR, and Stable Diffusion + - face correction with CodeFormer and GFPGAN - [API server can be run remotely](docs/server-admin.md) - REST API can be served over HTTPS or HTTP - background processing for all image pipelines @@ -66,7 +73,7 @@ This is an incomplete list of new and interesting features, with links to the us - OCI containers provided - for all supported hardware accelerators - includes both the API and GUI bundle in a single container - - runs well on [RunPod](https://www.runpod.io/) and other GPU container hosting services + - runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other 
GPU container hosting services ## Contents diff --git a/api/.vscode/settings.json b/api/.vscode/settings.json new file mode 100644 index 00000000..beeac074 --- /dev/null +++ b/api/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} \ No newline at end of file diff --git a/api/Makefile b/api/Makefile index 283c1ef1..579c8546 100644 --- a/api/Makefile +++ b/api/Makefile @@ -1,4 +1,4 @@ -.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload +.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload style onnx_env: ## create virtual env python -v venv onnx_env @@ -18,9 +18,10 @@ pip-dev: check-venv test: python -m coverage erase - python -m coverage run -m unittest discover -s tests/ + python -m coverage run -m unittest discover -v -s tests/ python -m coverage html -i python -m coverage xml -i + python -m coverage report -i package: package-dist package-upload @@ -32,13 +33,21 @@ package-upload: lint-check: black --check onnx_web/ - isort --check-only --skip __init__.py --filter-files onnx_web + black --check tests/ flake8 onnx_web + flake8 tests + isort --check-only --skip __init__.py --filter-files onnx_web + isort --check-only --skip __init__.py --filter-files tests lint-fix: black onnx_web/ - isort --skip __init__.py --filter-files onnx_web + black tests/ flake8 onnx_web + flake8 tests + isort --skip __init__.py --filter-files onnx_web + isort --skip __init__.py --filter-files tests + +style: lint-fix typecheck: mypy onnx_web diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index e0e23a30..f222ce73 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,45 +1,2 @@ -from .base import ChainPipeline, PipelineStage, StageParams -from .blend_img2img import BlendImg2ImgStage -from .blend_linear import BlendLinearStage -from .blend_mask import BlendMaskStage -from .correct_codeformer import CorrectCodeformerStage -from .correct_gfpgan import CorrectGFPGANStage -from .persist_disk import PersistDiskStage -from .persist_s3 import PersistS3Stage -from .reduce_crop import ReduceCropStage -from .reduce_thumbnail import ReduceThumbnailStage -from .source_noise import SourceNoiseStage -from .source_s3 import SourceS3Stage -from .source_txt2img import SourceTxt2ImgStage -from .source_url import SourceURLStage -from .upscale_bsrgan import UpscaleBSRGANStage -from .upscale_highres import UpscaleHighresStage -from .upscale_outpaint import UpscaleOutpaintStage -from .upscale_resrgan import UpscaleRealESRGANStage -from .upscale_simple import UpscaleSimpleStage -from .upscale_stable_diffusion import UpscaleStableDiffusionStage -from .upscale_swinir import UpscaleSwinIRStage - -CHAIN_STAGES = { - "blend-img2img": BlendImg2ImgStage, - "blend-inpaint": UpscaleOutpaintStage, - "blend-linear": BlendLinearStage, - "blend-mask": BlendMaskStage, - "correct-codeformer": CorrectCodeformerStage, - "correct-gfpgan": CorrectGFPGANStage, - "persist-disk": PersistDiskStage, - "persist-s3": PersistS3Stage, - "reduce-crop": ReduceCropStage, - "reduce-thumbnail": ReduceThumbnailStage, - "source-noise": SourceNoiseStage, - "source-s3": SourceS3Stage, - "source-txt2img": SourceTxt2ImgStage, - "source-url": SourceURLStage, - "upscale-bsrgan": UpscaleBSRGANStage, - "upscale-highres": UpscaleHighresStage, - 
"upscale-outpaint": UpscaleOutpaintStage, - "upscale-resrgan": UpscaleRealESRGANStage, - "upscale-simple": UpscaleSimpleStage, - "upscale-stable-diffusion": UpscaleStableDiffusionStage, - "upscale-swinir": UpscaleSwinIRStage, -} +from .pipeline import ChainPipeline, PipelineStage, StageParams +from .stages import * # NOQA diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index c7ebfe7f..02ad6c3a 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -1,240 +1,39 @@ -from datetime import timedelta -from logging import getLogger -from time import monotonic -from typing import Any, List, Optional, Tuple +from typing import Optional from PIL import Image -from ..errors import RetryException -from ..output import save_image -from ..params import ImageParams, StageParams -from ..server import ServerContext -from ..utils import is_debug, run_gc -from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage -from .tile import needs_tile, process_tile_order - -logger = getLogger(__name__) +from ..params import ImageParams, Size, SizeChart, StageParams +from ..server.context import ServerContext +from ..worker.context import WorkerContext +from .result import StageResult -PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] - - -class ChainProgress: - def __init__(self, parent: ProgressCallback, start=0) -> None: - self.parent = parent - self.step = start - self.total = 0 - - def __call__(self, step: int, timestep: int, latents: Any) -> None: - if step < self.step: - # accumulate on resets - self.total += self.step - - self.step = step - self.parent(self.get_total(), timestep, latents) - - def get_total(self) -> int: - return self.step + self.total - - @classmethod - def from_progress(cls, parent: ProgressCallback): - start = parent.step if hasattr(parent, "step") else 0 - return ChainProgress(parent, start=start) - - -class ChainPipeline: - """ - Run many stages in series, passing the image results from each to the next, and processing - tiles as needed. - """ - - def __init__( - self, - stages: Optional[List[PipelineStage]] = None, - ): - """ - Create a new pipeline that will run the given stages. - """ - self.stages = list(stages or []) - - def append(self, stage: Optional[PipelineStage]): - """ - Append an additional stage to this pipeline. - - This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to - assemble the stage from loose arguments. 
- """ - if stage is not None: - self.stages.append(stage) +class BaseStage: + max_tile = SizeChart.auto def run( self, - worker: WorkerContext, - server: ServerContext, - params: ImageParams, - sources: List[Image.Image], - callback: Optional[ProgressCallback], - **kwargs - ) -> List[Image.Image]: - return self( - worker, server, params, sources=sources, callback=callback, **kwargs - ) + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + _sources: StageResult, + *, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> StageResult: + raise NotImplementedError() # noqa - def stage(self, callback: BaseStage, params: StageParams, **kwargs): - self.stages.append((callback, params, kwargs)) - return self - - def __call__( + def steps( self, - worker: WorkerContext, - server: ServerContext, - params: ImageParams, - sources: List[Image.Image], - callback: Optional[ProgressCallback] = None, - **pipeline_kwargs - ) -> List[Image.Image]: - """ - DEPRECATED: use `run` instead - """ - if callback is not None: - callback = ChainProgress.from_progress(callback) + _params: ImageParams, + _size: Size, + ) -> int: + return 1 # noqa - start = monotonic() - - if len(sources) > 0: - logger.info( - "running pipeline on %s source images", - len(sources), - ) - else: - sources = [None] - logger.info("running pipeline without source images") - - stage_sources = sources - for stage_pipe, stage_params, stage_kwargs in self.stages: - name = stage_params.name or stage_pipe.__class__.__name__ - kwargs = stage_kwargs or {} - kwargs = {**pipeline_kwargs, **kwargs} - logger.debug( - "running stage %s with %s source images, parameters: %s", - name, - len(stage_sources) - stage_sources.count(None), - kwargs.keys(), - ) - - # the stage must be split and tiled if any image is larger than the selected/max tile size - must_tile = any( - [ - needs_tile( - stage_pipe.max_tile, - stage_params.tile_size, - size=kwargs.get("size", None), - source=source, - ) - for source in stage_sources - ] - ) - - tile = stage_params.tile_size - if stage_pipe.max_tile > 0: - tile = min(stage_pipe.max_tile, stage_params.tile_size) - - if must_tile: - stage_outputs = [] - for source in stage_sources: - logger.info( - "image larger than tile size of %s, tiling stage", - tile, - ) - - def stage_tile( - source_tile: Image.Image, - tile_mask: Image.Image, - dims: Tuple[int, int, int], - ) -> Image.Image: - for i in range(worker.retries): - try: - output_tile = stage_pipe.run( - worker, - server, - stage_params, - params, - [source_tile], - tile_mask=tile_mask, - callback=callback, - dims=dims, - **kwargs, - )[0] - - if is_debug(): - save_image(server, "last-tile.png", output_tile) - - return output_tile - except Exception: - logger.exception( - "error while running stage pipeline for tile, retry %s of 3", - i, - ) - server.cache.clear() - run_gc([worker.get_device()]) - worker.retries = worker.retries - (i + 1) - - raise RetryException("exhausted retries on tile") - - output = process_tile_order( - stage_params.tile_order, - source, - tile, - stage_params.outscale, - [stage_tile], - **kwargs, - ) - stage_outputs.append(output) - - stage_sources = stage_outputs - else: - logger.debug("image within tile size of %s, running stage", tile) - for i in range(worker.retries): - try: - stage_outputs = stage_pipe.run( - worker, - server, - stage_params, - params, - stage_sources, - callback=callback, - **kwargs, - ) - # doing this on the same line as stage_pipe.run can leave sources as None, which the 
pipeline - # does not like, so it throws - stage_sources = stage_outputs - break - except Exception: - logger.exception( - "error while running stage pipeline, retry %s of 3", i - ) - server.cache.clear() - run_gc([worker.get_device()]) - worker.retries = worker.retries - (i + 1) - - if worker.retries <= 0: - raise RetryException("exhausted retries on stage") - - logger.debug( - "finished stage %s with %s results", - name, - len(stage_sources), - ) - - if is_debug(): - save_image(server, "last-stage.png", stage_sources[0]) - - end = monotonic() - duration = timedelta(seconds=(end - start)) - logger.info( - "finished pipeline in %s with %s results", - duration, - len(stage_sources), - ) - return stage_sources + def outputs( + self, + _params: ImageParams, + sources: int, + ) -> int: + return sources diff --git a/api/onnx_web/chain/blend_denoise.py b/api/onnx_web/chain/blend_denoise.py new file mode 100644 index 00000000..e40a30a2 --- /dev/null +++ b/api/onnx_web/chain/blend_denoise.py @@ -0,0 +1,40 @@ +from logging import getLogger +from typing import Optional + +import cv2 +from PIL import Image + +from ..params import ImageParams, SizeChart, StageParams +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + +logger = getLogger(__name__) + + +class BlendDenoiseStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + sources: StageResult, + *, + strength: int = 3, + stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> StageResult: + logger.info("denoising source images") + + results = [] + for source in sources.as_numpy(): + data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR) + data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength) + results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB)) + + return StageResult(arrays=results) diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py new file mode 100644 index 00000000..34e4f535 --- /dev/null +++ b/api/onnx_web/chain/blend_grid.py @@ -0,0 +1,62 @@ +from logging import getLogger +from typing import List, Optional + +from PIL import Image + +from ..params import ImageParams, Size, SizeChart, StageParams +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + +logger = getLogger(__name__) + + +class BlendGridStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + sources: StageResult, + *, + height: int, + width: int, + # rows: Optional[List[str]] = None, + # columns: Optional[List[str]] = None, + # title: Optional[str] = None, + order: Optional[List[int]] = None, + stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> StageResult: + logger.info("combining source images using grid layout") + + images = sources.as_image() + ref_image = images[0] + size = Size(*ref_image.size) + + output = Image.new(ref_image.mode, (size.width * width, size.height * height)) + + # TODO: labels + if order is None: + order = range(len(images)) + + for i in range(len(order)): + x = i % width + y = i // width + + n = order[i] + output.paste(images[n], (x * size.width, y * size.height)) + + return 
StageResult(images=[*images, output]) + + def outputs( + self, + _params: ImageParams, + sources: int, + ) -> int: + return sources + 1 diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index e3c249a9..4528946f 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional import numpy as np import torch @@ -10,13 +10,14 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) class BlendImg2ImgStage(BaseStage): - max_tile = SizeChart.unlimited + max_tile = SizeChart.max def run( self, @@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage): server: ServerContext, _stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, strength: float, callback: Optional[ProgressCallback] = None, stage_source: Optional[Image.Image] = None, prompt_index: Optional[int] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: params = params.with_args(**kwargs) # multi-stage prompting @@ -52,7 +53,7 @@ class BlendImg2ImgStage(BaseStage): params, pipe_type, worker.get_device(), - inversions=inversions, + embeddings=inversions, loras=loras, ) @@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage): pipe_params["strength"] = strength outputs = [] - for source in sources: + for source in sources.as_image(): if params.is_lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) @@ -81,11 +82,10 @@ class BlendImg2ImgStage(BaseStage): ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) @@ -102,4 +102,18 @@ class BlendImg2ImgStage(BaseStage): outputs.extend(result.images) - return outputs + return StageResult(images=outputs) + + def steps( + self, + params: ImageParams, + *args, + ) -> int: + return params.steps # TODO: multiply by strength + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 6317ef13..1b40a5fd 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -1,12 +1,13 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,13 +19,18 @@ class BlendLinearStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, alpha: float, stage_source: Optional[Image.Image] = None, _callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("blending source images using linear 
interpolation") - return [Image.blend(source, stage_source, alpha) for source in sources] + return StageResult( + images=[ + Image.blend(source, stage_source, alpha) + for source in sources.as_image() + ] + ) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 1038d3ea..4486bbf6 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image @@ -8,7 +8,8 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,16 +21,17 @@ class BlendMaskStage(BaseStage): server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, _callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("blending image using mask") - mult_mask = Image.new("RGBA", stage_mask.size, color="black") + # TODO: does this need an alpha channel? + mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black") mult_mask.alpha_composite(stage_mask) mult_mask = mult_mask.convert("L") @@ -37,4 +39,9 @@ class BlendMaskStage(BaseStage): save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) - return [Image.composite(stage_source, source, mult_mask) for source in sources] + return StageResult( + images=[ + Image.composite(stage_source, source, mult_mask) + for source in sources.as_image() + ] + ) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 121a5cb3..1169d4fb 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -1,12 +1,13 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, stage_source: Optional[Image.Image] = None, upscale: UpscaleParams, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: # must be within the load function for patch to take effect # TODO: rewrite and remove from codeformer import CodeFormer @@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage): device = worker.get_device() pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) - return [pipe(source) for source in sources] + return StageResult(images=[pipe(source) for source in sources.as_image()]) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 145ff36b..e1db8bcb 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -1,15 +1,15 @@ from logging import getLogger from os import path -from typing import List, Optional +from typing import Optional -import numpy as np from PIL import Image 
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -57,12 +57,12 @@ class CorrectGFPGANStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.correction_model is None: @@ -73,16 +73,15 @@ class CorrectGFPGANStage(BaseStage): device = worker.get_device() gfpgan = self.load(server, stage, upscale, device) - outputs = [] - for source in sources: - output = np.array(source) - _, _, output = gfpgan.enhance( - output, + outputs = [ + gfpgan.enhance( + source, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength, ) - outputs.append(Image.fromarray(output, "RGB")) + for source in sources.as_numpy() + ] - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index 87b52d9b..2a43e051 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -1,11 +1,11 @@ from logging import getLogger from typing import Optional -from ..chain.base import ChainPipeline from ..chain.blend_img2img import BlendImg2ImgStage from ..chain.upscale import stage_upscale_correction from ..chain.upscale_simple import UpscaleSimpleStage from ..params import HighresParams, ImageParams, StageParams, UpscaleParams +from .pipeline import ChainPipeline logger = getLogger(__name__) @@ -43,7 +43,7 @@ def stage_highres( outscale=highres.scale, ), chain=chain, - overlap=params.overlap, + overlap=params.vae_overlap, ) else: logger.debug("using simple upscaling for highres") @@ -51,14 +51,14 @@ def stage_highres( UpscaleSimpleStage(), stage, method=highres.method, - overlap=params.overlap, + overlap=params.vae_overlap, upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale), ) chain.stage( BlendImg2ImgStage(), - stage, - overlap=params.overlap, + stage.with_args(outscale=1), + overlap=params.vae_overlap, prompt_index=prompt_index + i, strength=highres.strength, ) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 794dbb37..28a08848 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -1,33 +1,38 @@ from logging import getLogger -from typing import List +from typing import List, Optional from PIL import Image from ..output import save_image -from ..params import ImageParams, StageParams +from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) class PersistDiskStage(BaseStage): + max_tile = SizeChart.max + def run( self, _worker: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, - output: str, - stage_source: Image.Image, + output: List[str], + size: Optional[Size] = None, + stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: - for source in sources: - # TODO: append index to output name - dest = 
save_image(server, output, source, params=params) + ) -> StageResult: + logger.info("persisting %s images to disk: %s", len(sources), output) + + for source, name in zip(sources.as_image(), output): + dest = save_image(server, name, source, params=params, size=size) logger.info("saved image to %s", dest) return sources diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 27f4026f..060afc4f 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -8,7 +8,8 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,26 +21,26 @@ class PersistS3Stage(BaseStage): server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, - output: str, + output: List[str], bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - for source in sources: + for source, name in zip(sources.as_image(), output): data = BytesIO() source.save(data, format=server.image_format) data.seek(0) try: - s3.upload_fileobj(data, bucket, output) - logger.info("saved image to s3://%s/%s", bucket, output) + s3.upload_fileobj(data, bucket, name) + logger.info("saved image to s3://%s/%s", bucket, name) except Exception: logger.exception("error saving image to S3") diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py new file mode 100644 index 00000000..ff3fae81 --- /dev/null +++ b/api/onnx_web/chain/pipeline.py @@ -0,0 +1,281 @@ +from datetime import timedelta +from logging import getLogger +from time import monotonic +from typing import Any, List, Optional, Tuple + +from PIL import Image + +from ..errors import CancelledException, RetryException +from ..output import save_image +from ..params import ImageParams, Size, StageParams +from ..server import ServerContext +from ..utils import is_debug, run_gc +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult +from .tile import needs_tile, process_tile_order + +logger = getLogger(__name__) + + +PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] + + +class ChainProgress: + def __init__(self, parent: ProgressCallback, start=0) -> None: + self.parent = parent + self.step = start + self.total = 0 + + def __call__(self, step: int, timestep: int, latents: Any) -> None: + if step < self.step: + # accumulate on resets + self.total += self.step + + self.step = step + self.parent(self.get_total(), timestep, latents) + + def get_total(self) -> int: + return self.step + self.total + + @classmethod + def from_progress(cls, parent: ProgressCallback): + start = parent.step if hasattr(parent, "step") else 0 + return ChainProgress(parent, start=start) + + +class ChainPipeline: + """ + Run many stages in series, passing the image results from each to the next, and processing + tiles as needed. + """ + + def __init__( + self, + stages: Optional[List[PipelineStage]] = None, + ): + """ + Create a new pipeline that will run the given stages. 
+ """ + self.stages = list(stages or []) + + def append(self, stage: Optional[PipelineStage]): + """ + Append an additional stage to this pipeline. + + This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to + assemble the stage from loose arguments. + """ + if stage is not None: + self.stages.append(stage) + + def run( + self, + worker: WorkerContext, + server: ServerContext, + params: ImageParams, + sources: StageResult, + callback: Optional[ProgressCallback], + **kwargs, + ) -> List[Image.Image]: + result = self( + worker, server, params, sources=sources, callback=callback, **kwargs + ) + return result.as_image() + + def stage(self, callback: BaseStage, params: StageParams, **kwargs): + self.stages.append((callback, params, kwargs)) + return self + + def steps(self, params: ImageParams, size: Size) -> int: + steps = 0 + for callback, _params, kwargs in self.stages: + steps += callback.steps(kwargs.get("params", params), size) + + return steps + + def outputs(self, params: ImageParams, sources: int) -> int: + outputs = sources + for callback, _params, kwargs in self.stages: + outputs = callback.outputs(kwargs.get("params", params), outputs) + + return outputs + + def __call__( + self, + worker: WorkerContext, + server: ServerContext, + params: ImageParams, + sources: StageResult, + callback: Optional[ProgressCallback] = None, + **pipeline_kwargs, + ) -> StageResult: + """ + DEPRECATED: use `.run()` instead + """ + if callback is None: + callback = worker.get_progress_callback() + else: + callback = ChainProgress.from_progress(callback) + + start = monotonic() + + if len(sources) > 0: + logger.info( + "running pipeline on %s source images", + len(sources), + ) + else: + logger.info("running pipeline without source images") + + stage_sources = sources + for stage_pipe, stage_params, stage_kwargs in self.stages: + name = stage_params.name or stage_pipe.__class__.__name__ + kwargs = stage_kwargs or {} + kwargs = {**pipeline_kwargs, **kwargs} + logger.debug( + "running stage %s with %s source images, parameters: %s", + name, + len(stage_sources), + kwargs.keys(), + ) + + per_stage_params = params + if "params" in kwargs: + per_stage_params = kwargs["params"] + kwargs.pop("params") + + # the stage must be split and tiled if any image is larger than the selected/max tile size + must_tile = has_mask(stage_kwargs) or any( + [ + needs_tile( + stage_pipe.max_tile, + stage_params.tile_size, + size=kwargs.get("size", None), + source=source, + ) + for source in stage_sources.as_image() + ] + ) + + tile = stage_params.tile_size + if stage_pipe.max_tile > 0: + tile = min(stage_pipe.max_tile, stage_params.tile_size) + + if must_tile: + logger.info( + "image contains sources or is larger than tile size of %s, tiling stage", + tile, + ) + + def stage_tile( + source_tile: List[Image.Image], + tile_mask: Image.Image, + dims: Tuple[int, int, int], + ) -> List[Image.Image]: + for _i in range(worker.retries): + try: + tile_result = stage_pipe.run( + worker, + server, + stage_params, + per_stage_params, + StageResult(images=source_tile), + tile_mask=tile_mask, + callback=callback, + dims=dims, + **kwargs, + ) + + if is_debug(): + for j, image in enumerate(tile_result.as_image()): + save_image(server, f"last-tile-{j}.png", image) + + return tile_result + except CancelledException as err: + worker.retries = 0 + logger.exception("job was cancelled while tiling") + raise err + except Exception: + worker.retries = worker.retries - 1 + logger.exception( + "error while 
running stage pipeline for tile, %s retries left", + worker.retries, + ) + server.cache.clear() + run_gc([worker.get_device()]) + + raise RetryException("exhausted retries on tile") + + stage_results = process_tile_order( + stage_params.tile_order, + stage_sources, + tile, + stage_params.outscale, + [stage_tile], + **kwargs, + ) + + stage_sources = StageResult(images=stage_results) + else: + logger.debug( + "image does not contain sources and is within tile size of %s, running stage", + tile, + ) + for _i in range(worker.retries): + try: + stage_result = stage_pipe.run( + worker, + server, + stage_params, + per_stage_params, + stage_sources, + callback=callback, + dims=(0, 0, tile), + **kwargs, + ) + # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline + # does not like, so it throws + stage_sources = stage_result + break + except CancelledException as err: + worker.retries = 0 + logger.exception("job was cancelled during stage") + raise err + except Exception: + worker.retries = worker.retries - 1 + logger.exception( + "error while running stage pipeline, %s retries left", + worker.retries, + ) + server.cache.clear() + run_gc([worker.get_device()]) + + if worker.retries <= 0: + raise RetryException("exhausted retries on stage") + + logger.debug( + "finished stage %s with %s results", + name, + len(stage_sources), + ) + + if is_debug(): + for j, image in enumerate(stage_sources.as_image()): + save_image(server, f"last-stage-{j}.png", image) + + end = monotonic() + duration = timedelta(seconds=(end - start)) + logger.info( + "finished pipeline in %s with %s results", + duration, + len(stage_sources), + ) + return stage_sources + + +MASK_KEYS = ["mask", "stage_mask", "tile_mask"] + + +def has_mask(args: List[str]) -> bool: + return any([key in args for key in MASK_KEYS]) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 2e258075..fe98fbd3 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -1,12 +1,13 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, origin: Size, size: Size, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: outputs = [] - for source in sources: + for source in sources.as_image(): image = source.crop((origin.width, origin.height, size.width, size.height)) logger.info( "created thumbnail with dimensions: %sx%s", image.width, image.height ) outputs.append(image) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index d7a0efee..9c65a819 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -1,12 +1,12 @@ from logging import getLogger -from typing import List from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult 
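
Editor's note: a small, self-contained sketch (not part of this patch) showing how `ChainProgress`, now housed in `chain/pipeline.py`, keeps the reported step count monotonically increasing when consecutive stages reset their schedulers back to step 0. It assumes the `onnx_web` package is importable.

```python
# Illustrative only: ChainProgress accumulates steps across stage resets.
from onnx_web.chain.pipeline import ChainProgress

reported = []

def parent(step, timestep, latents):
    # stand-in for the worker's real progress callback
    reported.append(step)

progress = ChainProgress(parent)

for step in range(5):   # first stage reports steps 0..4
    progress(step, 0, None)
for step in range(3):   # second stage restarts at 0 and reports steps 0..2
    progress(step, 0, None)

print(reported[-1])  # 6: 4 carried over from the first stage + 2 from the second
```

The reset check (`step < self.step`) is what lets the progress bar keep climbing across txt2img, highres, and upscaling stages instead of bouncing back to zero at each stage boundary.
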
logger = getLogger(__name__) @@ -18,15 +18,15 @@ class ReduceThumbnailStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, size: Size, stage_source: Image.Image, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: outputs = [] - for source in sources: + for source in sources.as_image(): image = source.copy() image = image.thumbnail((size.width, size.height)) @@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage): outputs.append(image) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py new file mode 100644 index 00000000..813f5863 --- /dev/null +++ b/api/onnx_web/chain/result.py @@ -0,0 +1,73 @@ +from typing import List, Optional + +import numpy as np +from PIL import Image + + +class StageResult: + """ + Chain pipeline stage result. + Can contain PIL images or numpy arrays, with helpers to convert between them. + This class intentionally does not provide `__iter__`, to ensure clients get results in the format + they are expected. + """ + + arrays: Optional[List[np.ndarray]] + images: Optional[List[Image.Image]] + + @staticmethod + def empty(): + return StageResult(images=[]) + + @staticmethod + def from_arrays(arrays: List[np.ndarray]): + return StageResult(arrays=arrays) + + @staticmethod + def from_images(images: List[Image.Image]): + return StageResult(images=images) + + def __init__(self, arrays=None, images=None) -> None: + if arrays is not None and images is not None: + raise ValueError("stages must only return one type of result") + elif arrays is None and images is None: + raise ValueError("stages must return results") + + self.arrays = arrays + self.images = images + + def __len__(self) -> int: + if self.arrays is not None: + return len(self.arrays) + elif self.images is not None: + return len(self.images) + else: + return 0 + + def as_numpy(self) -> List[np.ndarray]: + if self.arrays is not None: + return self.arrays + elif self.images is not None: + return [np.array(i) for i in self.images] + else: + return [] + + def as_image(self) -> List[Image.Image]: + if self.images is not None: + return self.images + elif self.arrays is not None: + return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays] + else: + return [] + + +def shape_mode(arr: np.ndarray) -> str: + if len(arr.shape) != 3: + raise ValueError("unknown array format") + + if arr.shape[-1] == 3: + return "RGB" + elif arr.shape[-1] == 4: + return "RGBA" + + raise ValueError("unknown image format") diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 1ee68f42..d1b2eac2 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -1,12 +1,13 @@ from logging import getLogger -from typing import Callable, List +from typing import Callable, Optional from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,25 +19,34 @@ class SourceNoiseStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, size: Size, noise_source: Callable, - stage_source: Image.Image, + stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("generating image 
from noise source") if len(sources) > 0: - logger.warning( - "source images were passed to a noise stage and will be discarded" + logger.info( + "source images were passed to a source stage, new images will be appended" ) outputs = [] - for source in sources: + + # TODO: looping over sources and ignoring params does not make much sense for a source stage + for source in sources.as_image(): output = noise_source(source, (size.width, size.height), (0, 0)) logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return outputs + return StageResult(images=outputs) + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index 900270a3..d9a53aca 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -8,7 +8,8 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,18 +21,23 @@ class SourceS3Stage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - _sources: List[Image.Image], + sources: StageResult, *, source_keys: List[str], bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - outputs = [] + if len(sources) > 0: + logger.info( + "source images were passed to a source stage, new images will be appended" + ) + + outputs = sources.as_image() for key in source_keys: try: logger.info("loading image from s3://%s/%s", bucket, key) @@ -43,4 +49,11 @@ class SourceS3Stage(BaseStage): except Exception: logger.exception("error loading image from S3") - return outputs + return StageResult(outputs) + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 # TODO: len(source_keys) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index cc642d55..571e58ad 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -3,26 +3,28 @@ from typing import Optional, Tuple import numpy as np import torch -from PIL import Image +from ..constants import LATENT_FACTOR from ..diffusers.load import load_pipeline from ..diffusers.utils import ( encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt, + parse_reseed, slice_prompt, ) from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) class SourceTxt2ImgStage(BaseStage): - max_tile = SizeChart.unlimited + max_tile = SizeChart.max def run( self, @@ -30,15 +32,15 @@ class SourceTxt2ImgStage(BaseStage): server: ServerContext, stage: StageParams, params: ImageParams, - _source: Image.Image, + sources: StageResult, *, - dims: Tuple[int, int, int], + dims: Tuple[int, int, int] = None, size: Size, callback: Optional[ProgressCallback] = None, latents: Optional[np.ndarray] = None, prompt_index: Optional[int] = None, **kwargs, - ) -> Image.Image: + ) -> StageResult: params = params.with_args(**kwargs) size = size.with_args(**kwargs) @@ 
-47,31 +49,58 @@ class SourceTxt2ImgStage(BaseStage): params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index)) logger.info( - "generating image using txt2img, %s steps: %s", params.steps, params.prompt + "generating image using txt2img, %s steps of %s: %s", + params.steps, + params.model, + params.prompt, ) - if "stage_source" in kwargs: - logger.warning( - "a source image was passed to a txt2img stage, and will be discarded" + if len(sources): + logger.info( + "source images were passed to a source stage, new images will be appended" ) prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( params ) - if params.is_xl(): - tile_size = max(stage.tile_size, params.tiles) + if params.is_panorama() or params.is_xl(): + tile_size = max(stage.tile_size, params.unet_tile) else: - tile_size = params.tiles + tile_size = params.unet_tile # this works for panorama as well, because tile_size is already max(tile_size, *size) latent_size = size.min(tile_size, tile_size) # generate new latents or slice existing if latents is None: - latents = get_latents_from_seed(params.seed, latent_size, params.batch) + latents = get_latents_from_seed(int(params.seed), latent_size, params.batch) else: - latents = get_tile_latents(latents, params.seed, latent_size, dims) + latents = get_tile_latents(latents, int(params.seed), latent_size, dims) + + # reseed latents as needed + reseed_rng = np.random.RandomState(params.seed) + prompt, reseed = parse_reseed(prompt) + for top, left, bottom, right, region_seed in reseed: + if region_seed == -1: + region_seed = reseed_rng.random_integers(2**32 - 1) + + logger.debug( + "reseed latent region: [:, :, %s:%s, %s:%s] with %s", + top, + left, + bottom, + right, + region_seed, + ) + latents[ + :, + :, + top // LATENT_FACTOR : bottom // LATENT_FACTOR, + left // LATENT_FACTOR : right // LATENT_FACTOR, + ] = get_latents_from_seed( + region_seed, Size(right - left, bottom - top), params.batch + ) pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( @@ -79,7 +108,7 @@ class SourceTxt2ImgStage(BaseStage): params, pipe_type, worker.get_device(), - inversions=inversions, + embeddings=inversions, loras=loras, ) @@ -101,11 +130,14 @@ class SourceTxt2ImgStage(BaseStage): ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - - if not params.is_xl(): + if params.is_panorama() or params.is_xl(): + logger.debug( + "prompt alternatives are not supported for panorama or SDXL" + ) + else: + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) @@ -123,4 +155,21 @@ class SourceTxt2ImgStage(BaseStage): callback=callback, ) - return result.images + outputs = sources.as_image() + outputs.extend(result.images) + logger.debug("produced %s outputs", len(outputs)) + return StageResult(images=outputs) + + def steps( + self, + params: ImageParams, + size: Size, + ) -> int: + return params.steps + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 5fa54b67..b6aa62cd 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -1,6 +1,6 @@ from io import BytesIO from logging import getLogger -from typing import List +from typing import List, Optional import requests from PIL import Image @@ -8,7 
+8,8 @@ from PIL import Image from ..params import ImageParams, StageParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,20 +21,20 @@ class SourceURLStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, source_urls: List[str], - stage_source: Image.Image, + stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("loading image from URL source") if len(sources) > 0: - logger.warning( - "a source image was passed to a source stage, and will be discarded" + logger.info( + "source images were passed to a source stage, new images will be appended" ) - outputs = [] + outputs = sources.as_image() for url in source_urls: response = requests.get(url) output = Image.open(BytesIO(response.content)) @@ -41,4 +42,11 @@ class SourceURLStage(BaseStage): logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return outputs + return StageResult(images=outputs) + + def outputs( + self, + params: ImageParams, + sources: int, + ) -> int: + return sources + 1 diff --git a/api/onnx_web/chain/stage.py b/api/onnx_web/chain/stage.py deleted file mode 100644 index 781b65de..00000000 --- a/api/onnx_web/chain/stage.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import List, Optional - -from PIL import Image - -from ..params import ImageParams, Size, SizeChart, StageParams -from ..server.context import ServerContext -from ..worker.context import WorkerContext - - -class BaseStage: - max_tile = SizeChart.auto - - def run( - self, - worker: WorkerContext, - server: ServerContext, - stage: StageParams, - _params: ImageParams, - sources: List[Image.Image], - *args, - stage_source: Optional[Image.Image] = None, - **kwargs, - ) -> List[Image.Image]: - raise NotImplementedError() - - def steps( - self, - _params: ImageParams, - size: Size, - ) -> int: - raise NotImplementedError() diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py new file mode 100644 index 00000000..4ae14346 --- /dev/null +++ b/api/onnx_web/chain/stages.py @@ -0,0 +1,64 @@ +from logging import getLogger + +from .base import BaseStage +from .blend_denoise import BlendDenoiseStage +from .blend_grid import BlendGridStage +from .blend_img2img import BlendImg2ImgStage +from .blend_linear import BlendLinearStage +from .blend_mask import BlendMaskStage +from .correct_codeformer import CorrectCodeformerStage +from .correct_gfpgan import CorrectGFPGANStage +from .persist_disk import PersistDiskStage +from .persist_s3 import PersistS3Stage +from .reduce_crop import ReduceCropStage +from .reduce_thumbnail import ReduceThumbnailStage +from .source_noise import SourceNoiseStage +from .source_s3 import SourceS3Stage +from .source_txt2img import SourceTxt2ImgStage +from .source_url import SourceURLStage +from .upscale_bsrgan import UpscaleBSRGANStage +from .upscale_highres import UpscaleHighresStage +from .upscale_outpaint import UpscaleOutpaintStage +from .upscale_resrgan import UpscaleRealESRGANStage +from .upscale_simple import UpscaleSimpleStage +from .upscale_stable_diffusion import UpscaleStableDiffusionStage +from .upscale_swinir import UpscaleSwinIRStage + +logger = getLogger(__name__) + +CHAIN_STAGES = { + "blend-denoise": BlendDenoiseStage, + "blend-img2img": BlendImg2ImgStage, + "blend-inpaint": 
UpscaleOutpaintStage, + "blend-grid": BlendGridStage, + "blend-linear": BlendLinearStage, + "blend-mask": BlendMaskStage, + "correct-codeformer": CorrectCodeformerStage, + "correct-gfpgan": CorrectGFPGANStage, + "persist-disk": PersistDiskStage, + "persist-s3": PersistS3Stage, + "reduce-crop": ReduceCropStage, + "reduce-thumbnail": ReduceThumbnailStage, + "source-noise": SourceNoiseStage, + "source-s3": SourceS3Stage, + "source-txt2img": SourceTxt2ImgStage, + "source-url": SourceURLStage, + "upscale-bsrgan": UpscaleBSRGANStage, + "upscale-highres": UpscaleHighresStage, + "upscale-outpaint": UpscaleOutpaintStage, + "upscale-resrgan": UpscaleRealESRGANStage, + "upscale-simple": UpscaleSimpleStage, + "upscale-stable-diffusion": UpscaleStableDiffusionStage, + "upscale-swinir": UpscaleSwinIRStage, +} + + +def add_stage(name: str, stage: BaseStage) -> bool: + global CHAIN_STAGES + + if name in CHAIN_STAGES: + logger.warning("cannot replace stage: %s", name) + return False + else: + CHAIN_STAGES[name] = stage + return True diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 8b7898c6..01ea9f2b 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -2,13 +2,14 @@ import itertools from enum import Enum from logging import getLogger from math import ceil -from typing import List, Optional, Protocol, Tuple +from typing import Any, Callable, List, Optional, Protocol, Tuple, Union import numpy as np from PIL import Image from ..image.noise_source import noise_source_histogram from ..params import Size, TileOrder +from .result import StageResult # from skimage.exposure import match_histograms @@ -16,12 +17,15 @@ from ..params import Size, TileOrder logger = getLogger(__name__) +TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]] + + class TileCallback(Protocol): """ Definition for a tile job function. """ - def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image: + def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult: """ Run this stage against a single tile. 
""" @@ -32,6 +36,9 @@ def complete_tile( source: Image.Image, tile: int, ) -> Image.Image: + """ + TODO: clean up + """ if source is None: return source @@ -50,6 +57,12 @@ def needs_tile( source: Optional[Image.Image] = None, ) -> bool: tile = min(max_tile, stage_tile) + logger.trace( + "checking image tile dimensions: %s, %s, %s", + tile, + source.width > tile or source.height > tile if source is not None else False, + size.width > tile or size.height > tile if size is not None else False, + ) if source is not None: return source.width > tile or source.height > tile @@ -60,7 +73,7 @@ def needs_tile( return False -def get_tile_grads( +def make_tile_grads( left: int, top: int, tile: int, @@ -85,6 +98,60 @@ def get_tile_grads( return (grad_x, grad_y) +def make_tile_mask( + shape: Any, + tile: Tuple[int, int], + overlap: float, + edges: Tuple[bool, bool, bool, bool], +) -> np.ndarray: + mask = np.ones(shape) + + tile_h, tile_w = tile + + adj_tile_h = int(float(tile_h) * (1.0 - overlap)) + adj_tile_w = int(float(tile_w) * (1.0 - overlap)) + + # sort gradient points + p1_h = adj_tile_h - 1 + p2_h = tile_h - adj_tile_h + points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h] + + p1_w = adj_tile_w - 1 + p2_w = tile_w - adj_tile_w + points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w] + + # build gradients + edge_t, edge_l, edge_b, edge_r = edges + grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [ + int(not edge_t), + 1, + 1, + int(not edge_b), + ] + logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y) + + mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)] + mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)] + + mask = ((mask * mult_x).T * mult_y).T + + return mask + + +def get_channels(image: Union[np.ndarray, Image.Image]) -> int: + if isinstance(image, np.ndarray): + return image.shape[-1] + + if image.mode == "RGBA": + return 4 + elif image.mode == "RGB": + return 3 + elif image.mode == "L": + return 1 + + raise ValueError("unknown image format") + + def blend_tiles( tiles: List[Tuple[int, int, Image.Image]], scale: int, @@ -98,23 +165,24 @@ def blend_tiles( "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap ) - scaled_size = (height * scale, width * scale, 3) + channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles]) + scaled_size = (height * scale, width * scale, channels) + count = np.zeros(scaled_size) value = np.zeros(scaled_size) for left, top, tile_image in tiles: - # histogram equalization equalized = np.array(tile_image).astype(np.float32) mask = np.ones_like(equalized[:, :, 0]) if adj_tile < tile: # sort gradient points - p1 = adj_tile * scale - p2 = (tile - adj_tile) * scale - points = [0, min(p1, p2), max(p1, p2), tile * scale] + p1 = (adj_tile * scale) - 1 + p2 = (tile - adj_tile - 1) * scale + points = [-1, min(p1, p2), max(p1, p2), (tile * scale)] # gradient blending - grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height) + grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height) logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y) mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)] @@ -169,7 +237,7 @@ def blend_tiles( margin_left : equalized.shape[1] + margin_right, np.newaxis, ], - 3, + channels, axis=2, ) @@ -178,60 +246,18 @@ def blend_tiles( return Image.fromarray(np.uint8(pixels)) -def process_tile_grid( - source: Image.Image, - tile: int, - scale: int, - filters: List[TileCallback], - 
overlap: float = 0.0, - **kwargs, -) -> Image.Image: - width, height = kwargs.get("size", source.size if source else None) - - adj_tile = int(float(tile) * (1.0 - overlap)) - tiles_x = ceil(width / adj_tile) - tiles_y = ceil(height / adj_tile) - total = tiles_x * tiles_y - logger.debug( - "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap", - total, - tiles_x, - tiles_y, - adj_tile, - overlap, - ) - - tiles: List[Tuple[int, int, Image.Image]] = [] - - for y in range(tiles_y): - for x in range(tiles_x): - idx = (y * tiles_x) + x - left = x * adj_tile - top = y * adj_tile - logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x) - - tile_image = ( - source.crop((left, top, left + tile, top + tile)) if source else None - ) - tile_image = complete_tile(tile_image, tile) - - for filter in filters: - tile_image = filter(tile_image, (left, top, tile)) - - tiles.append((left, top, tile_image)) - - return blend_tiles(tiles, scale, width, height, tile, overlap) - - -def process_tile_spiral( - source: Image.Image, +def process_tile_stack( + stack: StageResult, tile: int, scale: int, filters: List[TileCallback], + tile_generator: TileGenerator, overlap: float = 0.5, **kwargs, -) -> Image.Image: - width, height = kwargs.get("size", source.size if source else None) +) -> List[Image.Image]: + sources = stack.as_image() + + width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None) mask = kwargs.get("mask", None) noise_source = kwargs.get("noise_source", noise_source_histogram) fill_color = kwargs.get("fill_color", None) @@ -239,18 +265,10 @@ def process_tile_spiral( tile_mask = None tiles: List[Tuple[int, int, Image.Image]] = [] + tile_coords = tile_generator(width, height, tile, overlap) + single_tile = len(tile_coords) == 1 - # tile tuples is source, multiply by scale for dest - counter = 0 - tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap) - - if len(tile_coords) == 1: - single_tile = True - else: - single_tile = False - - for left, top in tile_coords: - counter += 1 + for counter, (left, top) in enumerate(tile_coords): logger.info( "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top ) @@ -274,26 +292,36 @@ def process_tile_spiral( needs_margin = True bottom_margin = height - bottom - # if no source given, we don't have a source image - if not source: - tile_image = None + if single_tile: + logger.debug("using single tile") + tile_stack = sources + if mask: + tile_mask = mask elif needs_margin: - # in the special case where the image is smaller than the specified tile size, just use the image - if single_tile: - logger.debug("creating and processing single-tile subtile") - tile_image = source - if mask: - tile_mask = mask - # otherwise use add histogram noise outside of the image border - else: - logger.debug( - "tiling and adding margins: %s, %s, %s, %s", - left_margin, - top_margin, - right_margin, - bottom_margin, - ) - base_image = source.crop( + logger.debug( + "tiling with added margins: %s, %s, %s, %s", + left_margin, + top_margin, + right_margin, + bottom_margin, + ) + tile_stack = add_margin( + stack.as_image(), + left, + top, + right, + bottom, + left_margin, + top_margin, + right_margin, + bottom_margin, + tile, + noise_source, + fill_color, + ) + + if mask: + base_mask = mask.crop( ( left + left_margin, top + top_margin, @@ -301,57 +329,60 @@ def process_tile_spiral( bottom + bottom_margin, ) ) - tile_image = noise_source( - base_image, (tile, tile), (0, 0), fill=fill_color - ) - 
tile_image.paste(base_image, (left_margin, top_margin)) - - if mask: - base_mask = mask.crop( - ( - left + left_margin, - top + top_margin, - right + right_margin, - bottom + bottom_margin, - ) - ) - tile_mask = Image.new("L", (tile, tile), color=0) - tile_mask.paste(base_mask, (left_margin, top_margin)) + tile_mask = Image.new("L", (tile, tile), color=0) + tile_mask.paste(base_mask, (left_margin, top_margin)) else: logger.debug("tiling normally") - tile_image = source.crop((left, top, right, bottom)) + tile_stack = get_result_tile(stack, (left, top), Size(tile, tile)) if mask: tile_mask = mask.crop((left, top, right, bottom)) for image_filter in filters: - tile_image = image_filter(tile_image, tile_mask, (left, top, tile)) + tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) - tiles.append((left, top, tile_image)) + if isinstance(tile_stack, list): + tile_stack = StageResult.from_images(tile_stack) - if single_tile: - return tile_image - else: - return blend_tiles(tiles, scale, width, height, tile, overlap) + tiles.append((left, top, tile_stack.as_image())) + + lefts, tops, stacks = list(zip(*tiles)) + coords = list(zip(lefts, tops)) + stacks = list(zip(*stacks)) + + result = [] + for stack in stacks: + stack_tiles = zip(coords, stack) + stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles] + result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap)) + + return result def process_tile_order( order: TileOrder, - source: Image.Image, + stack: StageResult, tile: int, scale: int, filters: List[TileCallback], **kwargs, ) -> Image.Image: + """ + TODO: needs to handle more than one image + """ if order == TileOrder.grid: logger.debug("using grid tile order with tile size: %s", tile) - return process_tile_grid(source, tile, scale, filters, **kwargs) + return process_tile_stack( + stack, tile, scale, filters, generate_tile_grid, **kwargs + ) elif order == TileOrder.kernel: logger.debug("using kernel tile order with tile size: %s", tile) raise NotImplementedError() elif order == TileOrder.spiral: logger.debug("using spiral tile order with tile size: %s", tile) - return process_tile_spiral(source, tile, scale, filters, **kwargs) + return process_tile_stack( + stack, tile, scale, filters, generate_tile_spiral, **kwargs + ) else: logger.warning("unknown tile order: %s", order) raise ValueError() @@ -445,3 +476,77 @@ def generate_tile_spiral( height_tile_target -= abs(state.value[1]) return tile_coords + + +def generate_tile_grid( + width: int, + height: int, + tile: int, + overlap: float = 0.0, +) -> List[Tuple[int, int]]: + adj_tile = int(float(tile) * (1.0 - overlap)) + tiles_x = ceil(width / adj_tile) + tiles_y = ceil(height / adj_tile) + total = tiles_x * tiles_y + logger.debug( + "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap", + total, + tiles_x, + tiles_y, + adj_tile, + overlap, + ) + + tiles: List[Tuple[int, int, Image.Image]] = [] + + for y in range(tiles_y): + for x in range(tiles_x): + left = x * adj_tile + top = y * adj_tile + + tiles.append((int(left), int(top))) + + return tiles + + +def get_result_tile( + result: StageResult, + origin: Tuple[int, int], + tile: Size, +) -> List[Image.Image]: + top, left = origin + return [ + layer.crop((top, left, top + tile.height, left + tile.width)) + for layer in result.as_image() + ] + + +def add_margin( + stack: List[Image.Image], + left: int, + top: int, + right: int, + bottom: int, + left_margin: int, + top_margin: int, + right_margin: int, + bottom_margin: int, + tile: int, 
+ noise_source, + fill_color, +) -> List[Image.Image]: + results = [] + for source in stack: + base_image = source.crop( + ( + left + left_margin, + top + top_margin, + right + right_margin, + bottom + bottom_margin, + ) + ) + tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color) + tile_image.paste(base_image, (left_margin, top_margin)) + results.append(tile_image) + + return results diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index f9f02f1e..08c07759 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -1,22 +1,30 @@ from logging import getLogger from os import path -from typing import List, Optional +from typing import Optional import numpy as np from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + ImageParams, + Size, + SizeChart, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) class UpscaleBSRGANStage(BaseStage): - max_tile = 64 + max_tile = SizeChart.micro def load( self, @@ -54,12 +62,12 @@ class UpscaleBSRGANStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.upscale_model is None: @@ -71,40 +79,38 @@ class UpscaleBSRGANStage(BaseStage): bsrgan = self.load(server, stage, upscale, device) outputs = [] - for source in sources: - image = np.array(source) / 255.0 + for source in sources.as_numpy(): + image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) logger.trace("BSRGAN input shape: %s", image.shape) scale = upscale.outscale - dest = np.zeros( + logger.trace( + "BSRGAN output shape: %s", ( image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale, - ) + ), ) - logger.trace("BSRGAN output shape: %s", dest.shape) - dest = bsrgan(image) + output = bsrgan(image) - dest = np.clip(np.squeeze(dest, axis=0), 0, 1) - dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) - dest = (dest * 255.0).round().astype(np.uint8) - - output = Image.fromarray(dest, "RGB") - logger.debug("output image size: %s x %s", output.width, output.height) + output = np.clip(np.squeeze(output, axis=0), 0, 1) + output = output[[2, 1, 0], :, :].transpose((1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + logger.debug("output image shape: %s", output.shape) outputs.append(output) - return outputs + return StageResult(arrays=outputs) def steps( self, params: ImageParams, size: Size, ) -> int: - tile = min(params.tiles, self.max_tile) + tile = min(params.unet_tile, self.max_tile) return size.width // tile * size.height // tile diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index e19f75fb..32f891a6 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image @@ -8,7 +8,8 @@ from ..params import 
HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext from ..worker.context import ProgressCallback -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage): server: ServerContext, stage: StageParams, params: ImageParams, - sources: List[Image.Image], - *args, + sources: StageResult, + *, highres: HighresParams, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: if highres.scale <= 1: return sources chain = stage_highres(stage, params, highres, upscale) - return [ + outputs = [ chain( worker, server, @@ -41,5 +42,7 @@ class UpscaleHighresStage(BaseStage): source, callback=callback, ) - for source in sources + for source in sources.as_image() ] + + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 71de2629..464f5920 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Callable, List, Optional, Tuple +from typing import Callable, Optional, Tuple import numpy as np import torch @@ -18,13 +18,14 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) class UpscaleOutpaintStage(BaseStage): - max_tile = SizeChart.unlimited + max_tile = SizeChart.max def run( self, @@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage): server: ServerContext, stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, border: Border, dims: Tuple[int, int, int], @@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage): stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( params ) @@ -56,12 +57,12 @@ class UpscaleOutpaintStage(BaseStage): params, pipe_type, worker.get_device(), - inversions=inversions, + embeddings=inversions, loras=loras, ) outputs = [] - for source in sources: + for source in sources.as_image(): if is_debug(): save_image(server, "tile-source.png", source) save_image(server, "tile-mask.png", tile_mask) @@ -71,7 +72,7 @@ class UpscaleOutpaintStage(BaseStage): outputs.append(source) continue - tile_size = params.tiles + tile_size = params.unet_tile size = Size(*source.size) latent_size = size.min(tile_size, tile_size) @@ -99,10 +100,11 @@ class UpscaleOutpaintStage(BaseStage): ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( @@ -121,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage): outputs.extend(result.images) - return outputs + return StageResult(images=outputs) diff --git 
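For reference while reading these hunks: every stage now accepts and returns a StageResult instead of List[Image.Image]. The real container lives in api/onnx_web/chain/result.py, which is not part of this diff; the following is only a minimal sketch of the interface these stages rely on, inferred from the calls used above (StageResult(images=...), StageResult(arrays=...), StageResult.from_images(...), .as_image(), .as_numpy()), and the actual implementation may differ.

# sketch only: a container holding either PIL images or numpy arrays,
# converting lazily when a stage asks for the other representation
from typing import List, Optional

import numpy as np
from PIL import Image


class StageResult:
    def __init__(
        self,
        images: Optional[List[Image.Image]] = None,
        arrays: Optional[List[np.ndarray]] = None,
    ):
        self.images = images
        self.arrays = arrays

    @staticmethod
    def from_images(images: List[Image.Image]) -> "StageResult":
        return StageResult(images=images)

    def as_image(self) -> List[Image.Image]:
        if self.images is not None:
            return self.images
        return [Image.fromarray(np.uint8(arr), "RGB") for arr in self.arrays]

    def as_numpy(self) -> List[np.ndarray]:
        if self.arrays is not None:
            return self.arrays
        return [np.array(img) for img in self.images]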
a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index e680af53..49b06ee4 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -1,8 +1,7 @@ from logging import getLogger from os import path -from typing import List, Optional +from typing import Optional -import numpy as np from PIL import Image from ..onnx import OnnxRRDBNet @@ -10,7 +9,8 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -77,25 +77,22 @@ class UpscaleRealESRGANStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) + upsampler = self.load( + server, upscale, worker.get_device(), tile=stage.tile_size + ) + outputs = [] - for source in sources: - output = np.array(source) - upsampler = self.load( - server, upscale, worker.get_device(), tile=stage.tile_size - ) - - output, _ = upsampler.enhance(output, outscale=upscale.outscale) - - output = Image.fromarray(output, "RGB") - logger.info("final output image size: %sx%s", output.width, output.height) + for source in sources.as_numpy(): + output, _ = upsampler.enhance(source, outscale=upscale.outscale) + logger.info("final output image size: %s", output.shape) outputs.append(output) - return outputs + return StageResult(arrays=outputs) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 7dd44200..7e939bd4 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -1,12 +1,13 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage): _server: ServerContext, _stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, method: str, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: if upscale.scale <= 1: logger.debug( "simple upscale stage run with scale of %s, skipping", upscale.scale @@ -32,18 +33,20 @@ class UpscaleSimpleStage(BaseStage): return sources outputs = [] - for source in sources: + for source in sources.as_image(): scaled_size = (source.width * upscale.scale, source.height * upscale.scale) if method == "bilinear": logger.debug("using bilinear interpolation for highres") - source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) + outputs.append( + source.resize(scaled_size, resample=Image.Resampling.BILINEAR) + ) elif method == "lanczos": logger.debug("using Lanczos interpolation for highres") - source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) + outputs.append( + source.resize(scaled_size, resample=Image.Resampling.LANCZOS) + ) else: logger.warning("unknown upscaling method: %s", 
method) - outputs.append(source) - - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 9d5a7b32..6c8a300e 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -1,8 +1,8 @@ from logging import getLogger from os import path -from typing import List, Optional +from typing import Optional -import torch +import numpy as np from PIL import Image from ..diffusers.load import load_pipeline @@ -10,7 +10,8 @@ from ..diffusers.utils import encode_prompt, parse_prompt from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) @@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage): server: ServerContext, _stage: StageParams, params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, - ) -> List[Image.Image]: + ) -> StageResult: params = params.with_args(**kwargs) upscale = upscale.with_args(**kwargs) logger.info( @@ -46,22 +47,23 @@ class UpscaleStableDiffusionStage(BaseStage): worker.get_device(), model=path.join(server.model_path, upscale.upscale_model), ) - generator = torch.manual_seed(params.seed) + rng = np.random.RandomState(params.seed) - prompt_embeds = encode_prompt( - pipeline, - prompt_pairs, - num_images_per_prompt=params.batch, - do_classifier_free_guidance=params.do_cfg(), - ) - pipeline.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipeline, + prompt_pairs, + num_images_per_prompt=params.batch, + do_classifier_free_guidance=params.do_cfg(), + ) + pipeline.unet.set_prompts(prompt_embeds) outputs = [] - for source in sources: + for source in sources.as_image(): result = pipeline( prompt, source, - generator=generator, + generator=rng, guidance_scale=params.cfg, negative_prompt=negative_prompt, num_inference_steps=params.steps, @@ -71,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage): ) outputs.extend(result.images) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index a49b99e5..ef7d421f 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -1,22 +1,23 @@ from logging import getLogger from os import path -from typing import List, Optional +from typing import Optional import numpy as np from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext -from .stage import BaseStage +from .base import BaseStage +from .result import StageResult logger = getLogger(__name__) class UpscaleSwinIRStage(BaseStage): - max_tile = 64 + max_tile = SizeChart.micro def load( self, @@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage): server: ServerContext, stage: StageParams, _params: ImageParams, - sources: List[Image.Image], + sources: StageResult, *, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, 
**kwargs, - ) -> List[Image.Image]: + ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.upscale_model is None: @@ -71,31 +72,30 @@ class UpscaleSwinIRStage(BaseStage): swinir = self.load(server, stage, upscale, device) outputs = [] - for source in sources: + for source in sources.as_numpy(): # TODO: add support for grayscale (1-channel) images - image = np.array(source) / 255.0 + image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) logger.trace("SwinIR input shape: %s", image.shape) scale = upscale.outscale - dest = np.zeros( + logger.trace( + "SwinIR output shape: %s", ( image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale, - ) + ), ) - logger.trace("SwinIR output shape: %s", dest.shape) - dest = swinir(image) - dest = np.clip(np.squeeze(dest, axis=0), 0, 1) - dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) - dest = (dest * 255.0).round().astype(np.uint8) + output = swinir(image) + output = np.clip(np.squeeze(output, axis=0), 0, 1) + output = output[[2, 1, 0], :, :].transpose((1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) - output = Image.fromarray(dest, "RGB") - logger.info("output image size: %s x %s", output.width, output.height) + logger.info("output image size: %s", output.shape) outputs.append(output) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/constants.py b/api/onnx_web/constants.py index 4fe47f98..0eb6c039 100644 --- a/api/onnx_web/constants.py +++ b/api/onnx_web/constants.py @@ -1,2 +1,5 @@ ONNX_MODEL = "model.onnx" ONNX_WEIGHTS = "weights.pb" + +LATENT_FACTOR = 8 +LATENT_CHANNELS = 4 diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 36c97429..5cbe7f07 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -15,7 +15,8 @@ from ..constants import ONNX_MODEL, ONNX_WEIGHTS from ..utils import load_config from .correction.gfpgan import convert_correction_gfpgan from .diffusion.control import convert_diffusion_control -from .diffusion.diffusers import convert_diffusion_diffusers +from .diffusion.diffusion import convert_diffusion_diffusers +from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl from .diffusion.lora import blend_loras from .diffusion.textual_inversion import blend_textual_inversions from .upscaling.bsrgan import convert_upscaling_bsrgan @@ -357,13 +358,23 @@ def convert_models(conversion: ConversionContext, args, models: Models): conversion, name, model["source"], format=model_format ) - converted, dest = convert_diffusion_diffusers( - conversion, - model, - source, - model_format, - hf=hf, - ) + pipeline = model.get("pipeline", "txt2img") + if pipeline.endswith("-sdxl"): + converted, dest = convert_diffusion_diffusers_xl( + conversion, + model, + source, + model_format, + hf=hf, + ) + else: + converted, dest = convert_diffusion_diffusers( + conversion, + model, + source, + model_format, + hf=hf, + ) # make sure blending only happens once, not every run if converted: @@ -588,7 +599,7 @@ def main(args=None) -> int: logger.info("CLI arguments: %s", args) server = ConversionContext.from_environ() - server.half = args.half or "onnx-fp16" in server.optimizations + server.half = args.half or server.has_optimization("onnx-fp16") server.opset = args.opset server.token = args.token logger.info( diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusion.py similarity index 97% rename 
from api/onnx_web/convert/diffusion/diffusers.py rename to api/onnx_web/convert/diffusion/diffusion.py index 9b90e5ef..a8ecbbf7 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -20,7 +20,6 @@ from diffusers import ( AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline, - StableDiffusionControlNetPipeline, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, @@ -32,17 +31,25 @@ from onnx import load_model, save_model from ...constants import ONNX_MODEL, ONNX_WEIGHTS from ...diffusers.load import optimize_pipeline +from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ...utils import run_gc -from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export +from ..utils import ( + RESOLVE_FORMATS, + ConversionContext, + check_ext, + is_torch_2_0, + load_tensor, + onnx_export, +) from .checkpoint import convert_extract_checkpoint logger = getLogger(__name__) -available_pipelines = { - "controlnet": StableDiffusionControlNetPipeline, +CONVERT_PIPELINES = { + "controlnet": OnnxStableDiffusionControlNetPipeline, "img2img": StableDiffusionPipeline, "inpaint": StableDiffusionPipeline, "lpw": StableDiffusionPipeline, @@ -96,7 +103,6 @@ def get_model_version( opts["prediction_type"] = "epsilon" except Exception: logger.debug("unable to load tensor for version check") - pass return (v2, opts) @@ -314,7 +320,7 @@ def convert_diffusion_diffusers( logger.info("ONNX model already exists, skipping") return (False, dest_path) - pipe_class = available_pipelines.get(pipe_type) + pipe_class = CONVERT_PIPELINES.get(pipe_type) v2, pipe_args = get_model_version( source, conversion.map_location, size=image_size, version=version ) @@ -360,7 +366,6 @@ def convert_diffusion_diffusers( source, original_config_file=config_path, pipeline_class=pipe_class, - vae_path=replace_vae, **pipe_args, ).to(device, torch_dtype=dtype) elif hf: @@ -374,6 +379,17 @@ def convert_diffusion_diffusers( logger.warning("pipeline source not found or not recognized: %s", source) raise ValueError(f"pipeline source not found or not recognized: {source}") + if replace_vae is not None: + vae_path = path.join(conversion.model_path, replace_vae) + if check_ext(replace_vae, RESOLVE_FORMATS): + pipeline.vae = AutoencoderKL.from_single_file(vae_path) + else: + pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + + if is_torch_2_0: + pipeline.unet.set_attn_processor(AttnProcessor()) + pipeline.vae.set_attn_processor(AttnProcessor()) + optimize_pipeline(conversion, pipeline) output_path = Path(dest_path) @@ -424,9 +440,6 @@ def convert_diffusion_diffusers( unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool) - if is_torch_2_0: - pipeline.unet.set_attn_processor(AttnProcessor()) - unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size unet_path = output_path / "unet" / ONNX_MODEL @@ -526,19 +539,6 @@ def convert_diffusion_diffusers( del unet run_gc() - # VAE - if replace_vae is not None: - if replace_vae.startswith("."): - logger.debug( - "custom VAE appears to be a local path, making it relative to the model path" - ) - replace_vae = 
path.join(conversion.model_path, replace_vae) - - logger.info("loading custom VAE: %s", replace_vae) - vae = AutoencoderKL.from_pretrained(replace_vae) - pipeline.vae = vae - run_gc() - if single_vae: logger.debug("VAE config: %s", pipeline.vae.config) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py new file mode 100644 index 00000000..7a03b700 --- /dev/null +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -0,0 +1,116 @@ +from logging import getLogger +from os import path +from typing import Dict, Optional, Tuple + +import onnx +import torch +from diffusers import AutoencoderKL, StableDiffusionXLPipeline +from onnx.shape_inference import infer_shapes_path +from onnxruntime.transformers.float16 import convert_float_to_float16 +from optimum.exporters.onnx import main_export + +from ...constants import ONNX_MODEL +from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext + +logger = getLogger(__name__) + + +@torch.no_grad() +def convert_diffusion_diffusers_xl( + conversion: ConversionContext, + model: Dict, + source: str, + format: Optional[str], + hf: bool = False, +) -> Tuple[bool, str]: + """ + From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py + """ + name = model.get("name") + replace_vae = model.get("vae", None) + + device = conversion.training_device + dtype = conversion.torch_dtype() + logger.debug("using Torch dtype %s for pipeline", dtype) + + dest_path = path.join(conversion.model_path, name) + model_index = path.join(dest_path, "model_index.json") + model_hash = path.join(dest_path, "hash.txt") + + # diffusers go into a directory rather than .onnx file + logger.info( + "converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path + ) + + if path.exists(dest_path) and path.exists(model_index): + logger.info("ONNX model already exists, skipping conversion") + + if "hash" in model and not path.exists(model_hash): + logger.info("ONNX model does not have hash file, adding one") + with open(model_hash, "w") as f: + f.write(model["hash"]) + + return (False, dest_path) + + # safetensors -> diffusers directory with torch models + temp_path = path.join(conversion.cache_path, f"{name}-torch") + + if format == "safetensors": + pipeline = StableDiffusionXLPipeline.from_single_file( + source, use_safetensors=True + ) + else: + pipeline = StableDiffusionXLPipeline.from_pretrained(source) + + if replace_vae is not None: + vae_path = path.join(conversion.model_path, replace_vae) + if check_ext(replace_vae, RESOLVE_FORMATS): + logger.debug("loading VAE from single tensor file: %s", vae_path) + pipeline.vae = AutoencoderKL.from_single_file(vae_path) + else: + logger.debug("loading pretrained VAE from path: %s", vae_path) + pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + + if path.exists(temp_path): + logger.debug("torch model already exists for %s: %s", source, temp_path) + else: + logger.debug("exporting torch model for %s: %s", source, temp_path) + pipeline.save_pretrained(temp_path) + + # directory -> onnx using optimum exporters + main_export( + temp_path, + output=dest_path, + task="stable-diffusion-xl", + device=device, + fp16=conversion.has_optimization( + "torch-fp16" + ), # optimum's fp16 mode only works on CUDA or ROCm + framework="pt", + ) + + if "hash" in model: + logger.debug("adding hash file to ONNX model") + with open(model_hash, "w") as f: + f.write(model["hash"]) + + if conversion.half: + unet_path = path.join(dest_path, 
"unet", ONNX_MODEL) + infer_shapes_path(unet_path) + unet = onnx.load(unet_path) + opt_model = convert_float_to_float16( + unet, + disable_shape_infer=True, + force_fp16_initializers=True, + keep_io_types=True, + op_block_list=["Attention", "MultiHeadAttention"], + ) + onnx.save_model( + opt_model, + unet_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + ) + + return False, dest_path diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index ae5b910a..c4ed7fa8 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -1,22 +1,15 @@ -from argparse import ArgumentParser from logging import getLogger -from os import path from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch -from onnx import ModelProto, load, numpy_helper -from onnx.checker import check_model -from onnx.external_data_helper import ( - convert_model_to_external_data, - set_external_data, - write_external_data_tensors, -) -from onnxruntime import InferenceSession, OrtValue, SessionOptions +from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper +from onnx.external_data_helper import set_external_data +from onnxruntime import OrtValue from scipy import interpolate from ...server.context import ServerContext -from ..utils import ConversionContext, load_tensor +from ..utils import load_tensor logger = getLogger(__name__) @@ -39,7 +32,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray: lr = a if kernel == (1, 1): - lr = np.expand_dims(lr, axis=(2, 3)) + lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis return hr + lr @@ -78,13 +71,15 @@ def fix_node_name(key: str): return fixed_name -def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): +def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]: fixed = {} + names = [fix_node_name(node.name) for node in nodes] for key, value in keys.items(): - root, *rest = key.split(".") - logger.debug("fixing XL node name: %s -> %s", key, root) # TODO: move to trace + root, *_rest = key.split(".") + logger.trace("fixing XL node name: %s -> %s", key, root) + simple = False if root.startswith("input"): block = "down_blocks" elif root.startswith("middle"): @@ -93,6 +88,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): block = "up_blocks" elif root.startswith("text_model"): block = "text_model" + elif root.startswith("down_blocks"): + block = "down_blocks" + simple = True + elif root.startswith("mid_block"): + block = "mid_block" + simple = True + elif root.startswith("up_blocks"): + block = "up_blocks" + simple = True else: logger.warning("unknown XL key name: %s", key) fixed[key] = value @@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): suffix = None for s in [ + "conv", + "conv_shortcut", + "conv1", + "conv2", "fc1", "fc2", "ff_net_0_proj", @@ -119,18 +127,21 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): logger.warning("new XL key type: %s", root) continue - logger.debug("searching for XL node: /%s/*/%s", block, suffix) - match = None - if block == "text_model": - match = next( - node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul" - ) + logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix) + match: Optional[str] = None + if "conv" in suffix: + match = next(node for node in names if node == f"{root}_Conv") + elif "time_emb_proj" in root: + match = next(node for node 
in names if node == f"{root}_Gemm") + elif block == "text_model" or simple: + match = next(node for node in names if node == f"{root}_MatMul") else: + # search in order. one side has sparse indices, so they will not match. match = next( node - for node in nodes - if node.name.startswith(f"/{block}") - and fix_node_name(node.name).endswith( + for node in names + if node.startswith(block) + and node.endswith( f"{suffix}_MatMul" ) # needs to be fixed because some places use to_out.0 ) @@ -138,18 +149,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): if match is None: logger.warning("no matches for XL key: %s", root) continue + else: + logger.trace("matched key: %s -> %s", key, match) - name: str = match.name - name = fix_node_name(name.rstrip("/MatMul")) + name = match + if name.endswith("_MatMul"): + name = name[:-7] + elif name.endswith("_Gemm"): + name = name[:-5] + elif name.endswith("_Conv"): + name = name[:-5] - if name.endswith("proj_o"): - # wtf - name = f"{name}ut" - - logger.debug("matching XL key with node: %s -> %s", key, match.name) + logger.trace("matching XL key with node: %s -> %s, %s", key, match, name) fixed[name] = value - nodes.remove(match) + names.remove(match) + + logger.debug( + "SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining", + len(fixed.keys()), + len(keys.keys()), + len(names), + ) return fixed @@ -161,6 +182,245 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, ) +def blend_weights_loha( + key: str, lora_prefix: str, lora_model: Dict, dtype +) -> Tuple[str, np.ndarray]: + base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "") + + t1_key = key.replace("hada_w1_a", "hada_t1") + t2_key = key.replace("hada_w1_a", "hada_t2") + w1b_key = key.replace("hada_w1_a", "hada_w1_b") + w2a_key = key.replace("hada_w1_a", "hada_w2_a") + w2b_key = key.replace("hada_w1_a", "hada_w2_b") + alpha_key = key[: key.index("hada_w1_a")] + "alpha" + logger.trace( + "blending weights for LoHA keys: %s, %s, %s, %s, %s", + key, + w1b_key, + w2a_key, + w2b_key, + alpha_key, + ) + + w1a_weight = lora_model[key].to(dtype=dtype) + w1b_weight = lora_model[w1b_key].to(dtype=dtype) + w2a_weight = lora_model[w2a_key].to(dtype=dtype) + w2b_weight = lora_model[w2b_key].to(dtype=dtype) + + t1_weight = lora_model.get(t1_key, None) + t2_weight = lora_model.get(t2_key, None) + + dim = w1b_weight.size()[0] + alpha = lora_model.get(alpha_key, dim).to(dtype).numpy() + + if t1_weight is not None and t2_weight is not None: + t1_weight = t1_weight.to(dtype=dtype) + t2_weight = t2_weight.to(dtype=dtype) + + logger.trace( + "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)", + t1_weight.shape, + w1a_weight.shape, + w1b_weight.shape, + t2_weight.shape, + w2a_weight.shape, + w2b_weight.shape, + ) + weights_1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + t1_weight, + w1b_weight, + w1a_weight, + ) + weights_2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + t2_weight, + w2b_weight, + w2a_weight, + ) + weights = weights_1 * weights_2 + np_weights = weights.numpy() * (alpha / dim) + else: + logger.trace( + "blending weights for LoHA node: (%s @ %s) * (%s @ %s)", + w1a_weight.shape, + w1b_weight.shape, + w2a_weight.shape, + w2b_weight.shape, + ) + weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight) + np_weights = weights.numpy() * (alpha / dim) + + return base_key, np_weights + + +def blend_weights_lora( + key: str, lora_prefix: str, lora_model: Dict, dtype +) -> Tuple[str, np.ndarray]: + base_key = key[: 
key.index(".lora_down")].replace(lora_prefix, "") + + mid_key = key.replace("lora_down", "lora_mid") + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key) + + down_weight = lora_model[key].to(dtype=dtype) + up_weight = lora_model[up_key].to(dtype=dtype) + + mid_weight = None + if mid_key in lora_model: + mid_weight = lora_model[mid_key].to(dtype=dtype) + + dim = down_weight.size()[0] + alpha = lora_model.get(alpha_key, dim) + + if not isinstance(alpha, int): + alpha = alpha.to(dtype).numpy() + + kernel = down_weight.shape[-2:] + if mid_weight is not None: + kernel = mid_weight.shape[-2:] + + if len(down_weight.size()) == 2: + # blend for nn.Linear + logger.trace( + "blending weights for Linear node: (%s @ %s) * %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = up_weight @ down_weight + np_weights = weights.numpy() * (alpha / dim) + elif len(down_weight.size()) == 4 and kernel == ( + 1, + 1, + ): + # blend for nn.Conv2d 1x1 + logger.trace( + "blending weights for Conv 1x1 node: %s, %s, %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = ( + (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)) + .unsqueeze(2) + .unsqueeze(3) + ) + np_weights = weights.numpy() * (alpha / dim) + elif len(down_weight.size()) == 4 and kernel == ( + 3, + 3, + ): + if mid_weight is not None: + # blend for nn.Conv2d 3x3 with CP decomp + logger.trace( + "composing weights for Conv 3x3 node: %s, %s, %s, %s", + down_weight.shape, + up_weight.shape, + mid_weight.shape, + alpha, + ) + weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel)) + + for w in range(kernel[0]): + for h in range(kernel[1]): + weights[:, :, w, h] = ( + up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h] + ) @ down_weight.squeeze(3).squeeze(2) + + np_weights = weights.numpy() * (alpha / dim) + else: + # blend for nn.Conv2d 3x3 + logger.trace( + "blending weights for Conv 3x3 node: %s, %s, %s", + down_weight.shape, + up_weight.shape, + alpha, + ) + weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel)) + + for w in range(kernel[0]): + for h in range(kernel[1]): + down_w, down_h = kernel_slice(w, h, down_weight.shape) + up_w, up_h = kernel_slice(w, h, up_weight.shape) + + weights[:, :, w, h] = ( + up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h] + ) + + np_weights = weights.numpy() * (alpha / dim) + else: + logger.warning( + "unknown LoRA node type at %s: %s", + base_key, + up_weight.shape[-2:], + ) + # TODO: should this be None? 
+ np_weights = np.zeros((1, 1, 1, 1)) + + return base_key, np_weights + + +def blend_node_conv_gemm(weight_node, weights) -> TensorProto: + # blending + onnx_weights = numpy_helper.to_array(weight_node) + logger.trace( + "found blended weights for conv: %s, %s", + onnx_weights.shape, + weights.shape, + ) + + if onnx_weights.shape[-2:] == (1, 1): + if weights.shape[-2:] == (1, 1): + blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) + else: + blended = onnx_weights.squeeze((3, 2)) + weights + + blended = np.expand_dims(blended, (2, 3)) + else: + if onnx_weights.shape != weights.shape: + logger.warning( + "reshaping weights for mismatched Conv node: %s, %s", + onnx_weights.shape, + weights.shape, + ) + # TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus + blended = onnx_weights + weights.reshape(onnx_weights.shape) + else: + blended = onnx_weights + weights + + logger.trace("blended weight shape: %s", blended.shape) + + # replace the original initializer + return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name) + + +def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto: + onnx_weights = numpy_helper.to_array(matmul_node) + logger.trace( + "found blended weights for matmul: %s, %s", + weights.shape, + onnx_weights.shape, + ) + + t_weights = weights.transpose() + if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape: + logger.warning( + "weight shapes do not match for %s: %s vs %s", + matmul_key, + weights.shape, + onnx_weights.shape, + ) + t_weights = interp_to_match(weights, onnx_weights).transpose() + + blended = onnx_weights + t_weights + logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype) + + # replace the original initializer + return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name) + + def blend_loras( _conversion: ServerContext, base_name: Union[str, ModelProto], @@ -184,246 +444,77 @@ def blend_loras( else: lora_prefix = f"lora_{model_type}_" - blended: Dict[str, np.ndarray] = {} + layers = [] for (lora_name, lora_weight), lora_model in zip(loras, lora_models): logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight) if lora_model is None: logger.warning("unable to load tensor for LoRA") continue + blended: Dict[str, np.ndarray] = {} + layers.append(blended) + for key in lora_model.keys(): if ".hada_w1_a" in key and lora_prefix in key: # LoHA - base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "") - - t1_key = key.replace("hada_w1_a", "hada_t1") - t2_key = key.replace("hada_w1_a", "hada_t2") - w1b_key = key.replace("hada_w1_a", "hada_w1_b") - w2a_key = key.replace("hada_w1_a", "hada_w2_a") - w2b_key = key.replace("hada_w1_a", "hada_w2_b") - alpha_key = key[: key.index("hada_w1_a")] + "alpha" - logger.trace( - "blending weights for LoHA keys: %s, %s, %s, %s, %s", - key, - w1b_key, - w2a_key, - w2b_key, - alpha_key, + base_key, np_weights = blend_weights_loha( + key, lora_prefix, lora_model, dtype ) - - w1a_weight = lora_model[key].to(dtype=dtype) - w1b_weight = lora_model[w1b_key].to(dtype=dtype) - w2a_weight = lora_model[w2a_key].to(dtype=dtype) - w2b_weight = lora_model[w2b_key].to(dtype=dtype) - - t1_weight = lora_model.get(t1_key, None) - t2_weight = lora_model.get(t2_key, None) - - dim = w1b_weight.size()[0] - alpha = lora_model.get(alpha_key, dim).to(dtype).numpy() - - if t1_weight is not None and t2_weight is not None: - t1_weight = t1_weight.to(dtype=dtype) - 
t2_weight = t2_weight.to(dtype=dtype) - - logger.trace( - "composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)", - t1_weight.shape, - w1a_weight.shape, - w1b_weight.shape, - t2_weight.shape, - w2a_weight.shape, - w2b_weight.shape, - ) - weights_1 = torch.einsum( - "i j k l, j r, i p -> p r k l", - t1_weight, - w1b_weight, - w1a_weight, - ) - weights_2 = torch.einsum( - "i j k l, j r, i p -> p r k l", - t2_weight, - w2b_weight, - w2a_weight, - ) - weights = weights_1 * weights_2 - np_weights = weights.numpy() * (alpha / dim) - else: - logger.trace( - "blending weights for LoHA node: (%s @ %s) * (%s @ %s)", - w1a_weight.shape, - w1b_weight.shape, - w2a_weight.shape, - w2b_weight.shape, - ) - weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight) - np_weights = weights.numpy() * (alpha / dim) - - np_weights *= lora_weight - if base_key in blended: - logger.trace( - "summing LoHA weights: %s + %s", - blended[base_key].shape, - np_weights.shape, - ) - blended[base_key] += sum_weights(blended[base_key], np_weights) - else: - blended[base_key] = np_weights + np_weights = np_weights * lora_weight + logger.trace( + "adding LoHA weights: %s", + np_weights.shape, + ) + blended[base_key] = np_weights elif ".lora_down" in key and lora_prefix in key: # LoRA or LoCON - base_key = key[: key.index(".lora_down")].replace(lora_prefix, "") - - mid_key = key.replace("lora_down", "lora_mid") - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - logger.trace( - "blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key + base_key, np_weights = blend_weights_lora( + key, lora_prefix, lora_model, dtype ) + np_weights = np_weights * lora_weight + logger.trace( + "adding LoRA weights: %s", + np_weights.shape, + ) + blended[base_key] = np_weights - down_weight = lora_model[key].to(dtype=dtype) - up_weight = lora_model[up_key].to(dtype=dtype) + # rewrite node names for XL and flatten layers + weights: Dict[str, np.ndarray] = {} - mid_weight = None - if mid_key in lora_model: - mid_weight = lora_model[mid_key].to(dtype=dtype) + for blended in layers: + if xl: + nodes = list(base_model.graph.node) + blended = fix_xl_names(blended, nodes) - dim = down_weight.size()[0] - alpha = lora_model.get(alpha_key, dim) - - if not isinstance(alpha, int): - alpha = alpha.to(dtype).numpy() - - kernel = down_weight.shape[-2:] - if mid_weight is not None: - kernel = mid_weight.shape[-2:] - - if len(down_weight.size()) == 2: - # blend for nn.Linear - logger.trace( - "blending weights for Linear node: (%s @ %s) * %s", - down_weight.shape, - up_weight.shape, - alpha, - ) - weights = up_weight @ down_weight - np_weights = weights.numpy() * (alpha / dim) - elif len(down_weight.size()) == 4 and kernel == ( - 1, - 1, - ): - # blend for nn.Conv2d 1x1 - logger.trace( - "blending weights for Conv 1x1 node: %s, %s, %s", - down_weight.shape, - up_weight.shape, - alpha, - ) - weights = ( - ( - up_weight.squeeze(3).squeeze(2) - @ down_weight.squeeze(3).squeeze(2) - ) - .unsqueeze(2) - .unsqueeze(3) - ) - np_weights = weights.numpy() * (alpha / dim) - elif len(down_weight.size()) == 4 and kernel == ( - 3, - 3, - ): - if mid_weight is not None: - # blend for nn.Conv2d 3x3 with CP decomp - logger.trace( - "composing weights for Conv 3x3 node: %s, %s, %s, %s", - down_weight.shape, - up_weight.shape, - mid_weight.shape, - alpha, - ) - weights = torch.zeros( - (up_weight.shape[0], down_weight.shape[1], *kernel) - ) - - for w in range(kernel[0]): - for h in range(kernel[1]): - 
weights[:, :, w, h] = ( - up_weight.squeeze(3).squeeze(2) - @ mid_weight[:, :, w, h] - ) @ down_weight.squeeze(3).squeeze(2) - - np_weights = weights.numpy() * (alpha / dim) - else: - # blend for nn.Conv2d 3x3 - logger.trace( - "blending weights for Conv 3x3 node: %s, %s, %s", - down_weight.shape, - up_weight.shape, - alpha, - ) - weights = torch.zeros( - (up_weight.shape[0], down_weight.shape[1], *kernel) - ) - - for w in range(kernel[0]): - for h in range(kernel[1]): - down_w, down_h = kernel_slice(w, h, down_weight.shape) - up_w, up_h = kernel_slice(w, h, up_weight.shape) - - weights[:, :, w, h] = ( - up_weight[:, :, up_w, up_h] - @ down_weight[:, :, down_w, down_h] - ) - - np_weights = weights.numpy() * (alpha / dim) - else: - logger.warning( - "unknown LoRA node type at %s: %s", - base_key, - up_weight.shape[-2:], - ) - continue - - np_weights *= lora_weight - if base_key in blended: - logger.trace( - "summing weights: %s + %s", - blended[base_key].shape, - np_weights.shape, - ) - blended[base_key] = sum_weights(blended[base_key], np_weights) - else: - blended[base_key] = np_weights - - # rewrite node names for XL - if xl: - nodes = list(base_model.graph.node) - blended = fix_xl_names(blended, nodes) - - logger.trace( - "updating %s of %s initializers", - len(blended.keys()), - len(base_model.graph.initializer), - ) + for key, value in blended.items(): + if key in weights: + weights[key] = sum_weights(weights[key], value) + else: + weights[key] = value + # fix node names once fixed_initializer_names = [ fix_initializer_name(node.name) for node in base_model.graph.initializer ] - logger.trace("fixed initializer names: %s", fixed_initializer_names) - fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node] - logger.trace("fixed node names: %s", fixed_node_names) + + logger.debug( + "updating %s of %s initializers", + len(weights.keys()), + len(base_model.graph.initializer), + ) unmatched_keys = [] - for base_key, weights in blended.items(): + for base_key, weights in weights.items(): conv_key = base_key + "_Conv" gemm_key = base_key + "_Gemm" matmul_key = base_key + "_MatMul" logger.trace( - "key %s has conv: %s, matmul: %s", + "key %s has conv: %s, gemm: %s, matmul: %s", base_key, conv_key in fixed_node_names, + gemm_key in fixed_node_names, matmul_key in fixed_node_names, ) @@ -449,38 +540,9 @@ def blend_loras( weight_node = base_model.graph.initializer[weight_idx] logger.trace("found weight initializer: %s", weight_node.name) - # blending - onnx_weights = numpy_helper.to_array(weight_node) - logger.trace( - "found blended weights for conv: %s, %s", - onnx_weights.shape, - weights.shape, - ) + # replace the previous node + updated_node = blend_node_conv_gemm(weight_node, weights) - if onnx_weights.shape[-2:] == (1, 1): - if weights.shape[-2:] == (1, 1): - blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) - else: - blended = onnx_weights.squeeze((3, 2)) + weights - - blended = np.expand_dims(blended, (2, 3)) - else: - if onnx_weights.shape != weights.shape: - logger.warning( - "reshaping weights for mismatched Conv node: %s, %s", - onnx_weights.shape, - weights.shape, - ) - blended = onnx_weights + weights.reshape(onnx_weights.shape) - else: - blended = onnx_weights + weights - - logger.trace("blended weight shape: %s", blended.shape) - - # replace the original initializer - updated_node = numpy_helper.from_array( - blended.astype(onnx_weights.dtype), weight_node.name - ) del base_model.graph.initializer[weight_idx] 
base_model.graph.initializer.insert(weight_idx, updated_node) elif matmul_key in fixed_node_names: @@ -497,42 +559,15 @@ def blend_loras( matmul_node = base_model.graph.initializer[matmul_idx] logger.trace("found matmul initializer: %s", matmul_node.name) - # blending - onnx_weights = numpy_helper.to_array(matmul_node) - logger.trace( - "found blended weights for matmul: %s, %s", - weights.shape, - onnx_weights.shape, - ) + # replace the previous node + updated_node = blend_node_matmul(matmul_node, weights, matmul_key) - t_weights = weights.transpose() - if ( - weights.shape != onnx_weights.shape - and t_weights.shape != onnx_weights.shape - ): - logger.warning( - "weight shapes do not match for %s: %s vs %s", - matmul_key, - weights.shape, - onnx_weights.shape, - ) - t_weights = interp_to_match(weights, onnx_weights).transpose() - - blended = onnx_weights + t_weights - logger.debug( - "blended weight shape: %s, %s", blended.shape, onnx_weights.dtype - ) - - # replace the original initializer - updated_node = numpy_helper.from_array( - blended.astype(onnx_weights.dtype), matmul_node.name - ) del base_model.graph.initializer[matmul_idx] base_model.graph.initializer.insert(matmul_idx, updated_node) else: unmatched_keys.append(base_key) - logger.debug( + logger.trace( "node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), @@ -541,10 +576,7 @@ def blend_loras( ) if len(unmatched_keys) > 0: - logger.warning("could not find nodes for some keys: %s", unmatched_keys) - - # if model_type == "unet": - # save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb") + logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys) return base_model @@ -568,63 +600,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray: logger.debug("weights after interpolation: %s", output.shape) return output - - -if __name__ == "__main__": - context = ConversionContext.from_environ() - parser = ArgumentParser() - parser.add_argument("--base", type=str) - parser.add_argument("--dest", type=str) - parser.add_argument("--type", type=str, choices=["text_encoder", "unet"]) - parser.add_argument("--lora_models", nargs="+", type=str, default=[]) - parser.add_argument("--lora_weights", nargs="+", type=float, default=[]) - - args = parser.parse_args() - logger.info( - "merging %s with %s with weights: %s", - args.lora_models, - args.base, - args.lora_weights, - ) - - default_weight = 1.0 / len(args.lora_models) - while len(args.lora_weights) < len(args.lora_models): - args.lora_weights.append(default_weight) - - blend_model = blend_loras( - context, - args.base, - list(zip(args.lora_models, args.lora_weights)), - args.type, - ) - if args.dest is None or args.dest == "" or args.dest == ":load": - # convert to external data and save to memory - (bare_model, external_data) = buffer_external_data_tensors(blend_model) - logger.info("saved external data for %s nodes", len(external_data)) - - external_names, external_values = zip(*external_data) - opts = SessionOptions() - opts.add_external_initializers(list(external_names), list(external_values)) - sess = InferenceSession( - bare_model.SerializeToString(), - sess_options=opts, - providers=["CPUExecutionProvider"], - ) - logger.info( - "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()] - ) - else: - convert_model_to_external_data( - blend_model, all_tensors_to_one_file=True, 
location=f"lora-{args.type}.pb" - ) - bare_model = write_external_data_tensors(blend_model, args.dest) - dest_file = path.join(args.dest, f"lora-{args.type}.onnx") - - with open(dest_file, "w+b") as model_file: - model_file.write(bare_model.SerializeToString()) - - logger.info("successfully saved blended model: %s", dest_file) - - check_model(dest_file) - - logger.info("checked blended model") diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 3eece453..02d72eab 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -14,19 +14,155 @@ from ..utils import ConversionContext, load_tensor logger = getLogger(__name__) +def detect_embedding_format(loaded_embeds) -> str: + keys: List[str] = list(loaded_embeds.keys()) + if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"): + logger.debug("detected Textual Inversion concept: %s", keys) + return "concept" + elif "emb_params" in keys: + logger.debug("detected Textual Inversion parameter embeddings: %s", keys) + return "parameters" + elif "string_to_token" in keys and "string_to_param" in keys: + logger.debug("detected Textual Inversion token embeddings: %s", keys) + return "embeddings" + else: + logger.error("unknown Textual Inversion format, no recognized keys: %s", keys) + return None + + +def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight): + # separate token and the embeds + token = list(loaded_embeds.keys())[0] + + layer = loaded_embeds[token].numpy().astype(dtype) + layer *= weight + + if base_token in embeds: + embeds[base_token] += layer + else: + embeds[base_token] = layer + + if token in embeds: + embeds[token] += layer + else: + embeds[token] = layer + + +def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight): + emb_params = loaded_embeds["emb_params"] + + num_tokens = emb_params.shape[0] + logger.debug("generating %s layer tokens for %s", num_tokens, base_token) + + sum_layer = np.zeros(emb_params[0, :].shape) + + for i in range(num_tokens): + token = f"{base_token}-{i}" + layer = emb_params[i, :].numpy().astype(dtype) + layer *= weight + + sum_layer += layer + if token in embeds: + embeds[token] += layer + else: + embeds[token] = layer + + # add base and sum tokens to embeds + if base_token in embeds: + embeds[base_token] += sum_layer + else: + embeds[base_token] = sum_layer + + sum_token = f"{base_token}-all" + if sum_token in embeds: + embeds[sum_token] += sum_layer + else: + embeds[sum_token] = sum_layer + + +def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight): + string_to_token = loaded_embeds["string_to_token"] + string_to_param = loaded_embeds["string_to_param"] + + # separate token and embeds + token = list(string_to_token.keys())[0] + trained_embeds = string_to_param[token] + + num_tokens = trained_embeds.shape[0] + logger.debug("generating %s layer tokens for %s", num_tokens, base_token) + + sum_layer = np.zeros(trained_embeds[0, :].shape) + + for i in range(num_tokens): + token = f"{base_token}-{i}" + layer = trained_embeds[i, :].numpy().astype(dtype) + layer *= weight + + sum_layer += layer + if token in embeds: + embeds[token] += layer + else: + embeds[token] = layer + + # add base and sum tokens to embeds + if base_token in embeds: + embeds[base_token] += sum_layer + else: + embeds[base_token] = sum_layer + + sum_token = f"{base_token}-all" + if sum_token in embeds: + embeds[sum_token] 
+= sum_layer + else: + embeds[sum_token] = sum_layer + + +def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens): + # resize the token embeddings + # text_encoder.resize_token_embeddings(len(tokenizer)) + embedding_node = [ + n + for n in text_encoder.graph.initializer + if n.name == "text_model.embeddings.token_embedding.weight" + ][0] + base_weights = numpy_helper.to_array(embedding_node) + + weights_dim = base_weights.shape[1] + zero_weights = np.zeros((num_added_tokens, weights_dim)) + embedding_weights = np.concatenate((base_weights, zero_weights), axis=0) + + for token, weights in embeds.items(): + token_id = tokenizer.convert_tokens_to_ids(token) + logger.trace("embedding %s weights for token %s", weights.shape, token) + embedding_weights[token_id] = weights + + # replace embedding_node + for i in range(len(text_encoder.graph.initializer)): + if ( + text_encoder.graph.initializer[i].name + == "text_model.embeddings.token_embedding.weight" + ): + new_initializer = numpy_helper.from_array( + embedding_weights.astype(base_weights.dtype), embedding_node.name + ) + logger.trace("new initializer data type: %s", new_initializer.data_type) + del text_encoder.graph.initializer[i] + text_encoder.graph.initializer.insert(i, new_initializer) + + @torch.no_grad() def blend_textual_inversions( server: ServerContext, text_encoder: ModelProto, tokenizer: CLIPTokenizer, - inversions: List[Tuple[str, float, Optional[str], Optional[str]]], + embeddings: List[Tuple[str, float, Optional[str], Optional[str]]], ) -> Tuple[ModelProto, CLIPTokenizer]: # always load to CPU for blending device = torch.device("cpu") dtype = np.float32 embeds = {} - for name, weight, base_token, inversion_format in inversions: + for name, weight, base_token, format in embeddings: if base_token is None: logger.debug("no base token provided, using name: %s", name) base_token = name @@ -43,153 +179,28 @@ def blend_textual_inversions( logger.warning("unable to load tensor") continue - if inversion_format is None: - keys: List[str] = list(loaded_embeds.keys()) - if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"): - logger.debug("detected Textual Inversion concept: %s", keys) - inversion_format = "concept" - elif "emb_params" in keys: - logger.debug( - "detected Textual Inversion parameter embeddings: %s", keys - ) - inversion_format = "parameters" - elif "string_to_token" in keys and "string_to_param" in keys: - logger.debug("detected Textual Inversion token embeddings: %s", keys) - inversion_format = "embeddings" - else: - logger.error( - "unknown Textual Inversion format, no recognized keys: %s", keys - ) - continue + if format is None: + format = detect_embedding_format(loaded_embeds) - if inversion_format == "concept": - # separate token and the embeds - token = list(loaded_embeds.keys())[0] - - layer = loaded_embeds[token].numpy().astype(dtype) - layer *= weight - - if base_token in embeds: - embeds[base_token] += layer - else: - embeds[base_token] = layer - - if token in embeds: - embeds[token] += layer - else: - embeds[token] = layer - elif inversion_format == "parameters": - emb_params = loaded_embeds["emb_params"] - - num_tokens = emb_params.shape[0] - logger.debug("generating %s layer tokens for %s", num_tokens, name) - - sum_layer = np.zeros(emb_params[0, :].shape) - - for i in range(num_tokens): - token = f"{base_token}-{i}" - layer = emb_params[i, :].numpy().astype(dtype) - layer *= weight - - sum_layer += layer - if token in embeds: - embeds[token] += layer - else: - 
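blend_embedding_node above rewrites the token-embedding initializer inside the ONNX text encoder graph. The sketch below shows that onnx.numpy_helper pattern in isolation; the function name append_token_rows and the dummy weights are illustrative, only the initializer name comes from the code above.

import numpy as np
from onnx import numpy_helper

NODE_NAME = "text_model.embeddings.token_embedding.weight"

def append_token_rows(text_encoder, new_rows: np.ndarray):
    # find the existing embedding initializer and read it as a numpy array
    for i, node in enumerate(text_encoder.graph.initializer):
        if node.name == NODE_NAME:
            base_weights = numpy_helper.to_array(node)
            # grow the table by the new rows, keeping the original dtype
            merged = np.concatenate(
                (base_weights, new_rows.astype(base_weights.dtype)), axis=0
            )
            # replace the initializer in place, keeping its position and name
            del text_encoder.graph.initializer[i]
            text_encoder.graph.initializer.insert(
                i, numpy_helper.from_array(merged, NODE_NAME)
            )
            return merged.shape
    raise ValueError("token embedding initializer not found")
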
embeds[token] = layer - - # add base and sum tokens to embeds - if base_token in embeds: - embeds[base_token] += sum_layer - else: - embeds[base_token] = sum_layer - - sum_token = f"{base_token}-all" - if sum_token in embeds: - embeds[sum_token] += sum_layer - else: - embeds[sum_token] = sum_layer - elif inversion_format == "embeddings": - string_to_token = loaded_embeds["string_to_token"] - string_to_param = loaded_embeds["string_to_param"] - - # separate token and embeds - token = list(string_to_token.keys())[0] - trained_embeds = string_to_param[token] - - num_tokens = trained_embeds.shape[0] - logger.debug("generating %s layer tokens for %s", num_tokens, name) - - sum_layer = np.zeros(trained_embeds[0, :].shape) - - for i in range(num_tokens): - token = f"{base_token}-{i}" - layer = trained_embeds[i, :].numpy().astype(dtype) - layer *= weight - - sum_layer += layer - if token in embeds: - embeds[token] += layer - else: - embeds[token] = layer - - # add base and sum tokens to embeds - if base_token in embeds: - embeds[base_token] += sum_layer - else: - embeds[base_token] = sum_layer - - sum_token = f"{base_token}-all" - if sum_token in embeds: - embeds[sum_token] += sum_layer - else: - embeds[sum_token] = sum_layer + if format == "concept": + blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight) + elif format == "parameters": + blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight) + elif format == "embeddings": + blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight) else: - raise ValueError(f"unknown Textual Inversion format: {inversion_format}") + raise ValueError(f"unknown Textual Inversion format: {format}") - # add the tokens to the tokenizer - logger.debug( - "found embeddings for %s tokens: %s", - len(embeds.keys()), - list(embeds.keys()), + # add the tokens to the tokenizer + num_added_tokens = tokenizer.add_tokens(list(embeds.keys())) + if num_added_tokens == 0: + raise ValueError( + "The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer." ) - num_added_tokens = tokenizer.add_tokens(list(embeds.keys())) - if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer." 
- ) - logger.trace("added %s tokens", num_added_tokens) + logger.trace("added %s tokens", num_added_tokens) - # resize the token embeddings - # text_encoder.resize_token_embeddings(len(tokenizer)) - embedding_node = [ - n - for n in text_encoder.graph.initializer - if n.name == "text_model.embeddings.token_embedding.weight" - ][0] - base_weights = numpy_helper.to_array(embedding_node) - - weights_dim = base_weights.shape[1] - zero_weights = np.zeros((num_added_tokens, weights_dim)) - embedding_weights = np.concatenate((base_weights, zero_weights), axis=0) - - for token, weights in embeds.items(): - token_id = tokenizer.convert_tokens_to_ids(token) - logger.trace("embedding %s weights for token %s", weights.shape, token) - embedding_weights[token_id] = weights - - # replace embedding_node - for i in range(len(text_encoder.graph.initializer)): - if ( - text_encoder.graph.initializer[i].name - == "text_model.embeddings.token_embedding.weight" - ): - new_initializer = numpy_helper.from_array( - embedding_weights.astype(base_weights.dtype), embedding_node.name - ) - logger.trace("new initializer data type: %s", new_initializer.data_type) - del text_encoder.graph.initializer[i] - text_encoder.graph.initializer.insert(i, new_initializer) + blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens) return (text_encoder, tokenizer) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 9ed7424f..ef44ba20 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -36,7 +36,7 @@ DEFAULT_OPSET = 14 class ConversionContext(ServerContext): def __init__( self, - model_path: Optional[str] = None, + model_path: str = ".", cache_path: Optional[str] = None, device: Optional[str] = None, half: bool = False, @@ -69,7 +69,7 @@ class ConversionContext(ServerContext): def from_environ(cls): context = super().from_environ() context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True) - context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", True) + context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False) context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True) context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True) context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET)) @@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]): def tuple_to_source(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): - name, source, *rest = model + name, source, *_rest = model return { "name": name, @@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]): def tuple_to_correction(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model - scale = rest[0] if len(rest) > 0 else 1 - half = rest[0] if len(rest) > 0 else False - opset = rest[0] if len(rest) > 0 else None + scale = rest.pop(0) if len(rest) > 0 else 1 + half = rest.pop(0) if len(rest) > 0 else False + opset = rest.pop(0) if len(rest) > 0 else None return { "name": name, @@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]): def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model - single_vae = rest[0] if len(rest) > 0 else False - half = rest[0] if len(rest) > 0 else False - opset = rest[0] if len(rest) > 0 else None + single_vae = rest.pop(0) if len(rest) > 0 
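The switch from rest[0] to rest.pop(0) in the tuple_to_* helpers fixes a subtle bug: every optional field used to read the same first extra element. A quick sketch with a hypothetical legacy tuple:

legacy = ("some-upscaler", "upscaler.pth", 4, True, 17)  # name, source, scale, half, opset (illustrative)
name, source, *rest = legacy

# old behaviour: every field read rest[0], so half and opset got the scale value
scale_old = rest[0] if len(rest) > 0 else 1     # 4
half_old = rest[0] if len(rest) > 0 else False  # 4 (wrong)
opset_old = rest[0] if len(rest) > 0 else None  # 4 (wrong)

# new behaviour: pop(0) consumes the extras in order
scale = rest.pop(0) if len(rest) > 0 else 1     # 4
half = rest.pop(0) if len(rest) > 0 else False  # True
opset = rest.pop(0) if len(rest) > 0 else None  # 17
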
else False + half = rest.pop(0) if len(rest) > 0 else False + opset = rest.pop(0) if len(rest) > 0 else None return { "name": name, @@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]): def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model - scale = rest[0] if len(rest) > 0 else 1 - half = rest[0] if len(rest) > 0 else False - opset = rest[0] if len(rest) > 0 else None + scale = rest.pop(0) if len(rest) > 0 else 1 + half = rest.pop(0) if len(rest) > 0 else False + opset = rest.pop(0) if len(rest) > 0 else None return { "name": name, @@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] -RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"] +RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"] + + +def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]: + _name, ext = path.splitext(name) + ext = ext.strip(".") + + return (ext in exts, ext) def source_format(model: Dict) -> Optional[str]: @@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]: return model["format"] if "source" in model: - _name, ext = path.splitext(model["source"]) - if ext in MODEL_FORMATS: + valid, ext = check_ext(model["source"], MODEL_FORMATS) + if valid: return ext return None @@ -298,6 +305,7 @@ def onnx_export( half=False, external_data=False, v2=False, + op_block_list=None, ): """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -316,8 +324,7 @@ def onnx_export( opset_version=opset, ) - op_block_list = None - if v2: + if v2 and op_block_list is None: op_block_list = ["Attention", "MultiHeadAttention"] if half: diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 78123449..30a87863 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -1,16 +1,15 @@ from logging import getLogger from os import path -from typing import Any, List, Optional, Tuple +from typing import Any, List, Literal, Optional, Tuple from onnx import load_model from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline, ORTStableDiffusionXLImg2ImgPipeline, ORTStableDiffusionXLPipeline, ) -from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet from transformers import CLIPTokenizer -from ..constants import ONNX_MODEL +from ..constants import LATENT_FACTOR, ONNX_MODEL from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors from ..convert.diffusion.textual_inversion import blend_textual_inversions from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline @@ -24,6 +23,7 @@ from .patches.vae import VAEWrapper from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline +from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline from .version_safe_diffusers import ( DDIMScheduler, @@ -38,6 +38,7 @@ from .version_safe_diffusers import ( KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, + LCMScheduler, LMSDiscreteScheduler, OnnxRuntimeModel, OnnxStableDiffusionImg2ImgPipeline, @@ -58,6 +59,7 @@ available_pipelines = { # "inpaint-sdxl": 
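check_ext centralizes the extension test that source_format previously did inline, and it strips the leading dot that path.splitext keeps, which the old comparison against MODEL_FORMATS could never satisfy. A small usage sketch; the file names are illustrative:

from onnx_web.convert.utils import MODEL_FORMATS, check_ext, source_format

print(check_ext("some-model.safetensors", MODEL_FORMATS))  # (True, "safetensors")
print(check_ext("notes.txt", MODEL_FORMATS))               # (False, "txt")

# source_format only reports a format when the extension is recognized
print(source_format({"source": "some-model.ckpt"}))        # "ckpt"
print(source_format({"source": "some-model.zip"}))         # None
print(source_format({"format": "onnx", "source": "x"}))    # explicit format wins: "onnx"
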
ORTStableDiffusionXLInpaintPipeline, "lpw": OnnxStableDiffusionLongPromptWeightingPipeline, "panorama": OnnxStableDiffusionPanoramaPipeline, + "panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline, "pix2pix": OnnxStableDiffusionInstructPix2PixPipeline, "txt2img-sdxl": ORTStableDiffusionXLPipeline, "txt2img": OnnxStableDiffusionPipeline, @@ -77,12 +79,25 @@ pipeline_schedulers = { "k-dpm-2-a": KDPM2AncestralDiscreteScheduler, "k-dpm-2": KDPM2DiscreteScheduler, "karras-ve": KarrasVeScheduler, + "lcm": LCMScheduler, "lms-discrete": LMSDiscreteScheduler, "pndm": PNDMScheduler, "unipc-multi": UniPCMultistepScheduler, } +def add_pipeline(name: str, pipeline: Any) -> bool: + global available_pipelines + + if name in available_pipelines: + # TODO: decide if this should be allowed or not + logger.warning("cannot replace existing pipeline: %s", name) + return False + else: + available_pipelines[name] = pipeline + return True + + def get_available_pipelines() -> List[str]: return list(available_pipelines.keys()) @@ -99,16 +114,19 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]: return None +VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"] + + def load_pipeline( server: ServerContext, params: ImageParams, pipeline: str, device: DeviceParams, - inversions: Optional[List[Tuple[str, float]]] = None, + embeddings: Optional[List[Tuple[str, float]]] = None, loras: Optional[List[Tuple[str, float]]] = None, model: Optional[str] = None, ): - inversions = inversions or [] + embeddings = embeddings or [] loras = loras or [] model = model or params.model @@ -122,7 +140,7 @@ def load_pipeline( device.device, device.provider, control_key, - inversions, + embeddings, loras, ) scheduler_key = (params.scheduler, model) @@ -159,211 +177,376 @@ def load_pipeline( run_gc([device]) logger.debug("loading new diffusion pipeline from %s", model) + scheduler = scheduler_type.from_pretrained( + model, + provider=device.ort_provider(), + sess_options=device.sess_options(), + subfolder="scheduler", + torch_dtype=torch_dtype, + ) components = { - "scheduler": scheduler_type.from_pretrained( - model, - provider=device.ort_provider(), - sess_options=device.sess_options(), - subfolder="scheduler", - torch_dtype=torch_dtype, - ) + "scheduler": scheduler, } # shared components - text_encoder = None unet_type = "unet" # ControlNet component if params.is_control() and params.control is not None: - cnet_path = path.join( - server.model_path, "control", f"{params.control.name}.onnx" - ) - logger.debug("loading ControlNet weights from %s", cnet_path) - components["controlnet"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - cnet_path, - provider=device.ort_provider(), - sess_options=device.sess_options(), - ) - ) - + logger.debug("loading ControlNet components") + control_components = load_controlnet(server, device, params) + components.update(control_components) unet_type = "cnet" - # Textual Inversion blending - if inversions is not None and len(inversions) > 0: - logger.debug("blending Textual Inversions from %s", inversions) - inversion_names, inversion_weights = zip(*inversions) + # load various pipeline components + encoder_components = load_text_encoders( + server, device, model, embeddings, loras, torch_dtype, params + ) + components.update(encoder_components) - inversion_models = [ - path.join(server.model_path, "inversion", name) - for name in inversion_names - ] - text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL)) - tokenizer = CLIPTokenizer.from_pretrained( - model, - 
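add_pipeline gives plugins a way to register extra pipeline classes at runtime without editing the table above, and it refuses to replace an existing name. A hypothetical usage sketch; MyCustomPipeline stands in for a real pipeline class:

from onnx_web.diffusers.load import add_pipeline, get_available_pipelines

class MyCustomPipeline:  # stand-in for a real diffusers-style pipeline class
    pass

assert add_pipeline("my-custom", MyCustomPipeline) is True
assert add_pipeline("txt2img", MyCustomPipeline) is False  # existing names cannot be replaced
print(get_available_pipelines())  # now includes "my-custom"
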
subfolder="tokenizer", - torch_dtype=torch_dtype, + unet_components = load_unet(server, device, model, loras, unet_type, params) + components.update(unet_components) + + vae_components = load_vae(server, device, model, params) + components.update(vae_components) + + pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) + + if params.is_xl(): + logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__) + pipe = pipeline_class( + components["vae_decoder_session"], + components["text_encoder_session"], + components["unet_session"], + { + "force_zeros_for_empty_prompt": True, + "requires_aesthetics_score": False, + }, + components["tokenizer"], + scheduler, + vae_encoder_session=components.get("vae_encoder_session", None), + text_encoder_2_session=components.get("text_encoder_2_session", None), + tokenizer_2=components.get("tokenizer_2", None), ) - text_encoder, tokenizer = blend_textual_inversions( + else: + if "vae" in components: + # upscale uses a single VAE + logger.debug( + "assembling SD pipeline for %s with single VAE", + pipeline_class.__name__, + ) + pipe = pipeline_class( + components["vae"], + components["text_encoder"], + components["tokenizer"], + components["unet"], + scheduler, + scheduler, + ) + else: + logger.debug( + "assembling SD pipeline for %s with VAE codec", + pipeline_class.__name__, + ) + pipe = pipeline_class( + components["vae_encoder"], + components["vae_decoder"], + components["text_encoder"], + components["tokenizer"], + components["unet"], + scheduler, + None, + None, + requires_safety_checker=False, + ) + + if not server.show_progress: + pipe.set_progress_bar_config(disable=True) + + optimize_pipeline(server, pipe) + patch_pipeline(server, pipe, pipeline_class, params) + + server.cache.set(ModelTypes.diffusion, pipe_key, pipe) + server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler) + + for vae in VAE_COMPONENTS: + if hasattr(pipe, vae): + vae_model = getattr(pipe, vae) + if isinstance(vae_model, VAEWrapper): + vae_model.set_tiled(tiled=params.tiled_vae) + vae_model.set_window_size( + params.vae_tile // LATENT_FACTOR, params.vae_overlap + ) + + # update panorama params + if params.is_panorama(): + unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR + logger.debug( + "setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", + params.unet_tile, + unet_stride, + params.vae_tile, + params.vae_overlap, + ) + pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride) + + run_gc([device]) + + return pipe + + +def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams): + cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx") + logger.debug("loading ControlNet weights from %s", cnet_path) + components = {} + components["controlnet"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + cnet_path, + provider=device.ort_provider(), + sess_options=device.sess_options(), + ) + ) + return components + + +def load_text_encoders( + server: ServerContext, + device: DeviceParams, + model: str, + embeddings: Optional[List[Tuple[str, float]]], + loras: Optional[List[Tuple[str, float]]], + torch_dtype, + params: ImageParams, +): + text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL)) + tokenizer = CLIPTokenizer.from_pretrained( + model, + subfolder="tokenizer", + torch_dtype=torch_dtype, + ) + + components = { + "tokenizer": tokenizer, + } + + if params.is_xl(): + text_encoder_2 = load_model(path.join(model, 
"text_encoder_2", ONNX_MODEL)) + tokenizer_2 = CLIPTokenizer.from_pretrained( + model, + subfolder="tokenizer_2", + torch_dtype=torch_dtype, + ) + components["tokenizer_2"] = tokenizer_2 + + # blend embeddings, if any + if embeddings is not None and len(embeddings) > 0: + embedding_names, embedding_weights = zip(*embeddings) + embedding_models = [ + path.join(server.model_path, "inversion", name) for name in embedding_names + ] + logger.debug( + "blending base model %s with embeddings from %s", model, embedding_models + ) + + # TODO: blend text_encoder_2 as well + text_encoder, tokenizer = blend_textual_inversions( + server, + text_encoder, + tokenizer, + list( + zip( + embedding_models, + embedding_weights, + embedding_names, + [None] * len(embedding_models), + ) + ), + ) + components["tokenizer"] = tokenizer + + if params.is_xl(): + text_encoder_2, tokenizer_2 = blend_textual_inversions( server, - text_encoder, - tokenizer, + text_encoder_2, + tokenizer_2, list( zip( - inversion_models, - inversion_weights, - inversion_names, - [None] * len(inversion_models), + embedding_models, + embedding_weights, + embedding_names, + [None] * len(embedding_models), ) ), ) + components["tokenizer_2"] = tokenizer_2 - components["tokenizer"] = tokenizer + # blend LoRAs, if any + if loras is not None and len(loras) > 0: + lora_names, lora_weights = zip(*loras) + lora_models = [ + path.join(server.model_path, "lora", name) for name in lora_names + ] + logger.info("blending base model %s with LoRAs from %s", model, lora_models) - # should be pretty small and should not need external data - if loras is None or len(loras) == 0: - # TODO: handle XL encoders - components["text_encoder"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - text_encoder.SerializeToString(), - provider=device.ort_provider("text-encoder"), - sess_options=device.sess_options(), - ) - ) + # blend and load text encoder + text_encoder = blend_loras( + server, + text_encoder, + list(zip(lora_models, lora_weights)), + "text_encoder", + 1 if params.is_xl() else None, + params.is_xl(), + ) - # LoRA blending - if loras is not None and len(loras) > 0: - lora_names, lora_weights = zip(*loras) - lora_models = [ - path.join(server.model_path, "lora", name) for name in lora_names - ] - logger.info( - "blending base model %s with LoRA models: %s", model, lora_models - ) - - # blend and load text encoder - text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL) - text_encoder = blend_loras( + if params.is_xl(): + text_encoder_2 = blend_loras( server, - text_encoder, + text_encoder_2, list(zip(lora_models, lora_weights)), "text_encoder", - 1 if params.is_xl() else None, + 2, params.is_xl(), ) - (text_encoder, text_encoder_data) = buffer_external_data_tensors( - text_encoder + + # prepare external data for sessions + (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder) + text_encoder_names, text_encoder_values = zip(*text_encoder_data) + text_encoder_opts = device.sess_options(cache=False) + text_encoder_opts.add_external_initializers( + list(text_encoder_names), list(text_encoder_values) + ) + + if params.is_xl(): + # encoder 2 only exists in XL + (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( + text_encoder_2 + ) + text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) + text_encoder_2_opts = device.sess_options(cache=False) + text_encoder_2_opts.add_external_initializers( + list(text_encoder_2_names), list(text_encoder_2_values) + ) + + # session for te1 + 
text_encoder_session = InferenceSession( + text_encoder.SerializeToString(), + providers=[device.ort_provider("text-encoder")], + sess_options=text_encoder_opts, + ) + text_encoder_session._model_path = path.join(model, "text_encoder") + components["text_encoder_session"] = text_encoder_session + + # session for te2 + text_encoder_2_session = InferenceSession( + text_encoder_2.SerializeToString(), + providers=[device.ort_provider("text-encoder")], + sess_options=text_encoder_2_opts, + ) + text_encoder_2_session._model_path = path.join(model, "text_encoder_2") + components["text_encoder_2_session"] = text_encoder_2_session + else: + # session for te + components["text_encoder"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + text_encoder.SerializeToString(), + provider=device.ort_provider("text-encoder"), + sess_options=text_encoder_opts, ) - text_encoder_names, text_encoder_values = zip(*text_encoder_data) - text_encoder_opts = device.sess_options(cache=False) - text_encoder_opts.add_external_initializers( - list(text_encoder_names), list(text_encoder_values) + ) + + return components + + +def load_unet( + server: ServerContext, + device: DeviceParams, + model: str, + loras: List[Tuple[str, float]], + unet_type: Literal["cnet", "unet"], + params: ImageParams, +): + components = {} + unet = load_model(path.join(model, unet_type, ONNX_MODEL)) + + # LoRA blending + if loras is not None and len(loras) > 0: + lora_names, lora_weights = zip(*loras) + lora_models = [ + path.join(server.model_path, "lora", name) for name in lora_names + ] + logger.info("blending base model %s with LoRA models: %s", model, lora_models) + + # blend and load unet + unet = blend_loras( + server, + unet, + list(zip(lora_models, lora_weights)), + "unet", + xl=params.is_xl(), + ) + + (unet_model, unet_data) = buffer_external_data_tensors(unet) + unet_names, unet_values = zip(*unet_data) + unet_opts = device.sess_options(cache=False) + unet_opts.add_external_initializers(list(unet_names), list(unet_values)) + + if params.is_xl(): + unet_session = InferenceSession( + unet_model.SerializeToString(), + providers=[device.ort_provider("unet")], + sess_options=unet_opts, + ) + unet_session._model_path = path.join(model, "unet") + components["unet_session"] = unet_session + else: + components["unet"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + unet_model.SerializeToString(), + provider=device.ort_provider("unet"), + sess_options=unet_opts, ) + ) - if params.is_xl(): - text_encoder_session = InferenceSession( - text_encoder.SerializeToString(), - providers=[device.ort_provider("text-encoder")], - sess_options=text_encoder_opts, - ) - text_encoder_session._model_path = path.join(model, "text_encoder") - components["text_encoder_session"] = text_encoder_session - else: - components["text_encoder"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - text_encoder.SerializeToString(), - provider=device.ort_provider("text-encoder"), - sess_options=text_encoder_opts, - ) - ) + return components - if params.is_xl(): - text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL) - text_encoder_2 = blend_loras( - server, - text_encoder_2, - list(zip(lora_models, lora_weights)), - "text_encoder", - 2, - params.is_xl(), - ) - (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( - text_encoder_2 - ) - text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) - text_encoder_2_opts = device.sess_options(cache=False) - text_encoder_2_opts.add_external_initializers( - list(text_encoder_2_names), 
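After LoRA weights are merged, the blended text encoders and UNet are too large to serialize inline, so the big tensors are stripped out and handed to ONNX Runtime as external initializers. A rough sketch of that pattern, assuming buffer_external_data_tensors yields (name, OrtValue) pairs as the calls above suggest; make_session and the model argument are illustrative names:

import onnxruntime as ort

def make_session(model, external_data, providers):
    # external_data: list of (initializer_name, ort.OrtValue) pairs pulled out of the model
    names, values = zip(*external_data)

    opts = ort.SessionOptions()
    opts.add_external_initializers(list(names), list(values))

    # the stripped model is small enough to pass as bytes; the session resolves
    # the named initializers from the OrtValues instead of reading them from disk
    return ort.InferenceSession(
        model.SerializeToString(),
        providers=providers,
        sess_options=opts,
    )
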
list(text_encoder_2_values) - ) - text_encoder_2_session = InferenceSession( - text_encoder_2.SerializeToString(), - providers=[device.ort_provider("text-encoder")], - sess_options=text_encoder_2_opts, - ) - text_encoder_2_session._model_path = path.join(model, "text_encoder_2") - components["text_encoder_2_session"] = text_encoder_2_session +def load_vae( + _server: ServerContext, device: DeviceParams, model: str, params: ImageParams +): + # one or more VAE models need to be loaded + vae = path.join(model, "vae", ONNX_MODEL) + vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) + vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL) - # blend and load unet - unet = path.join(model, unet_type, ONNX_MODEL) - blended_unet = blend_loras( - server, - unet, - list(zip(lora_models, lora_weights)), - "unet", - xl=params.is_xl(), + components = {} + if not params.is_xl() and path.exists(vae): + logger.debug("loading VAE from %s", vae) + components["vae"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + vae, + provider=device.ort_provider("vae"), + sess_options=device.sess_options(), ) - (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) - unet_names, unet_values = zip(*unet_data) - unet_opts = device.sess_options(cache=False) - unet_opts.add_external_initializers(list(unet_names), list(unet_values)) - - if params.is_xl(): - unet_session = InferenceSession( - unet_model.SerializeToString(), - providers=[device.ort_provider("unet")], - sess_options=unet_opts, - ) - unet_session._model_path = path.join(model, "unet") - components["unet_session"] = unet_session - else: - components["unet"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - unet_model.SerializeToString(), - provider=device.ort_provider("unet"), - sess_options=unet_opts, - ) - ) - - # make sure a UNet has been loaded - if not params.is_xl() and "unet" not in components: - unet = path.join(model, unet_type, ONNX_MODEL) - logger.debug("loading UNet (%s) from %s", unet_type, unet) - components["unet"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - unet, - provider=device.ort_provider("unet"), - sess_options=device.sess_options(), - ) + ) + elif path.exists(vae_decoder) and path.exists(vae_encoder): + if params.is_xl(): + logger.debug("loading VAE decoder from %s", vae_decoder) + components["vae_decoder_session"] = OnnxRuntimeModel.load_model( + vae_decoder, + provider=device.ort_provider("vae"), + sess_options=device.sess_options(), ) + components["vae_decoder_session"]._model_path = vae_decoder - # one or more VAE models need to be loaded - vae = path.join(model, "vae", ONNX_MODEL) - vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) - vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL) - - if not params.is_xl() and path.exists(vae): - logger.debug("loading VAE from %s", vae) - components["vae"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - vae, - provider=device.ort_provider("vae"), - sess_options=device.sess_options(), - ) + logger.debug("loading VAE encoder from %s", vae_encoder) + components["vae_encoder_session"] = OnnxRuntimeModel.load_model( + vae_encoder, + provider=device.ort_provider("vae"), + sess_options=device.sess_options(), ) - elif ( - not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder) - ): + components["vae_encoder_session"]._model_path = vae_encoder + + else: logger.debug("loading VAE decoder from %s", vae_decoder) components["vae_decoder"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( @@ -382,119 +565,44 @@ def load_pipeline( ) ) - 
# additional options for panorama pipeline - if params.is_panorama(): - components["window"] = params.tiles // 8 - components["stride"] = params.stride // 8 - - pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) - logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__) - pipe = pipeline_class.from_pretrained( - model, - provider=device.ort_provider(), - sess_options=device.sess_options(), - safety_checker=None, - torch_dtype=torch_dtype, - **components, - ) - - # make sure XL models are actually being used - # TODO: why is this needed? - if "text_encoder_session" in components: - logger.info( - "text encoder matches: %s, %s", - pipe.text_encoder.session == components["text_encoder_session"], - type(pipe.text_encoder), - ) - pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder) - - if "text_encoder_2_session" in components: - logger.info( - "text encoder 2 matches: %s, %s", - pipe.text_encoder_2.session == components["text_encoder_2_session"], - type(pipe.text_encoder_2), - ) - pipe.text_encoder_2 = ORTModelTextEncoder( - text_encoder_2_session, text_encoder_2 - ) - - if "unet_session" in components: - logger.info( - "unet matches: %s, %s", - pipe.unet.session == components["unet_session"], - type(pipe.unet), - ) - pipe.unet = ORTModelUnet(unet_session, unet_model) - - if not server.show_progress: - pipe.set_progress_bar_config(disable=True) - - optimize_pipeline(server, pipe) - - if not params.is_xl(): - patch_pipeline(server, pipe, pipeline, pipeline_class, params) - - server.cache.set(ModelTypes.diffusion, pipe_key, pipe) - server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"]) - - if not params.is_xl() and hasattr(pipe, "vae_decoder"): - pipe.vae_decoder.set_tiled(tiled=params.tiled_vae) - - if not params.is_xl() and hasattr(pipe, "vae_encoder"): - pipe.vae_encoder.set_tiled(tiled=params.tiled_vae) - - # update panorama params - if params.is_panorama(): - latent_window = params.tiles // 8 - latent_stride = params.stride // 8 - - pipe.set_window_size(latent_window, latent_stride) - if hasattr(pipe, "vae_decoder"): - pipe.vae_decoder.set_window_size(latent_window, params.overlap) - if hasattr(pipe, "vae_encoder"): - pipe.vae_encoder.set_window_size(latent_window, params.overlap) - - run_gc([device]) - - return pipe + return components def optimize_pipeline( server: ServerContext, pipe: StableDiffusionPipeline, ) -> None: - if ( - "diffusers-attention-slicing" in server.optimizations - or "diffusers-attention-slicing-auto" in server.optimizations - ): + if server.has_optimization( + "diffusers-attention-slicing" + ) or server.has_optimization("diffusers-attention-slicing-auto"): logger.debug("enabling auto attention slicing on SD pipeline") try: pipe.enable_attention_slicing(slice_size="auto") except Exception as e: logger.warning("error while enabling auto attention slicing: %s", e) - if "diffusers-attention-slicing-max" in server.optimizations: + if server.has_optimization("diffusers-attention-slicing-max"): logger.debug("enabling max attention slicing on SD pipeline") try: pipe.enable_attention_slicing(slice_size="max") except Exception as e: logger.warning("error while enabling max attention slicing: %s", e) - if "diffusers-vae-slicing" in server.optimizations: + if server.has_optimization("diffusers-vae-slicing"): logger.debug("enabling VAE slicing on SD pipeline") try: pipe.enable_vae_slicing() except Exception as e: logger.warning("error while enabling VAE slicing: %s", e) - if 
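The optimization checks below move from testing membership in server.optimizations to a server.has_optimization(name) helper. The helper's body is not part of this diff; the sketch is only an assumption of what it likely wraps, based on the call sites it replaces:

# assumed shape of the helper, not taken from this diff
class ServerContext:
    optimizations: list  # e.g. ["diffusers-vae-slicing", ...]

    def has_optimization(self, name: str) -> bool:
        # the replaced call sites checked `name in server.optimizations`
        return name in self.optimizations
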
"diffusers-cpu-offload-sequential" in server.optimizations: + if server.has_optimization("diffusers-cpu-offload-sequential"): logger.debug("enabling sequential CPU offload on SD pipeline") try: pipe.enable_sequential_cpu_offload() except Exception as e: logger.warning("error while enabling sequential CPU offload: %s", e) - elif "diffusers-cpu-offload-model" in server.optimizations: + elif server.has_optimization("diffusers-cpu-offload-model"): # TODO: check for accelerate logger.debug("enabling model CPU offload on SD pipeline") try: @@ -502,7 +610,7 @@ def optimize_pipeline( except Exception as e: logger.warning("error while enabling model CPU offload: %s", e) - if "diffusers-memory-efficient-attention" in server.optimizations: + if server.has_optimization("diffusers-memory-efficient-attention"): # TODO: check for xformers logger.debug("enabling memory efficient attention for SD pipeline") try: @@ -514,17 +622,17 @@ def optimize_pipeline( def patch_pipeline( server: ServerContext, pipe: StableDiffusionPipeline, - pipe_type: str, pipeline: Any, params: ImageParams, ) -> None: logger.debug("patching SD pipeline") - if pipe_type != "lpw": + if not params.is_lpw() and not params.is_xl(): pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) original_unet = pipe.unet - pipe.unet = UNetWrapper(server, original_unet) + pipe.unet = UNetWrapper(server, original_unet, params.is_xl()) + logger.debug("patched UNet with wrapper") if hasattr(pipe, "vae_decoder"): original_decoder = pipe.vae_decoder @@ -532,18 +640,21 @@ def patch_pipeline( server, original_decoder, decoder=True, - window=params.tiles, - overlap=params.overlap, + window=params.unet_tile, + overlap=params.vae_overlap, ) + logger.debug("patched VAE decoder with wrapper") + + if hasattr(pipe, "vae_encoder"): original_encoder = pipe.vae_encoder pipe.vae_encoder = VAEWrapper( server, original_encoder, decoder=False, - window=params.tiles, - overlap=params.overlap, + window=params.unet_tile, + overlap=params.vae_overlap, ) - elif hasattr(pipe, "vae"): - pass # TODO: current wrapper does not work with upscaling VAE - else: - logger.debug("no VAE found to patch") + logger.debug("patched VAE encoder with wrapper") + + if hasattr(pipe, "vae"): + logger.warning("not patching single VAE, tiled VAE may not work") diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index 9ad49cfc..81065d97 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -14,20 +14,23 @@ class UNetWrapper(object): prompt_index: int = 0 server: ServerContext wrapped: OnnxRuntimeModel + xl: bool def __init__( self, server: ServerContext, wrapped: OnnxRuntimeModel, + xl: bool, ): self.server = server self.wrapped = wrapped + self.xl = xl def __call__( self, - sample: np.ndarray = None, - timestep: np.ndarray = None, - encoder_hidden_states: np.ndarray = None, + sample: Optional[np.ndarray] = None, + timestep: Optional[np.ndarray] = None, + encoder_hidden_states: Optional[np.ndarray] = None, **kwargs, ): logger.trace( @@ -43,13 +46,21 @@ class UNetWrapper(object): encoder_hidden_states = self.prompt_embeds[step_index] self.prompt_index += 1 - if sample.dtype != timestep.dtype: - logger.trace("converting UNet sample to timestep dtype") - sample = sample.astype(timestep.dtype) + if self.xl: + if sample.dtype != encoder_hidden_states.dtype: + logger.trace( + "converting UNet sample to hidden state dtype for XL: %s", + encoder_hidden_states.dtype, + ) + sample = 
sample.astype(encoder_hidden_states.dtype) + else: + if sample.dtype != timestep.dtype: + logger.trace("converting UNet sample to timestep dtype") + sample = sample.astype(timestep.dtype) - if encoder_hidden_states.dtype != timestep.dtype: - logger.trace("converting UNet hidden states to timestep dtype") - encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype) + if encoder_hidden_states.dtype != timestep.dtype: + logger.trace("converting UNet hidden states to timestep dtype") + encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype) return self.wrapped( sample=sample, diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py index c5fd6936..d7e3e1f7 100644 --- a/api/onnx_web/diffusers/patches/vae.py +++ b/api/onnx_web/diffusers/patches/vae.py @@ -12,8 +12,6 @@ from ...server import ServerContext logger = getLogger(__name__) -LATENT_CHANNELS = 4 - class VAEWrapper(object): def __init__( @@ -39,11 +37,17 @@ class VAEWrapper(object): self.tile_overlap_factor = overlap def __call__(self, latent_sample=None, sample=None, **kwargs): + model = ( + self.wrapped.model + if hasattr(self.wrapped, "model") + else self.wrapped.session + ) + # set timestep dtype to input type sample_dtype = next( ( input.type - for input in self.wrapped.model.get_inputs() + for input in model.get_inputs() if input.name == "sample" or input.name == "latent_sample" ), "tensor(float)", diff --git a/api/onnx_web/diffusers/pipelines/controlnet.py b/api/onnx_web/diffusers/pipelines/controlnet.py index 684b0c5e..e0515dc5 100644 --- a/api/onnx_web/diffusers/pipelines/controlnet.py +++ b/api/onnx_web/diffusers/pipelines/controlnet.py @@ -13,8 +13,8 @@ import numpy as np import PIL import torch from diffusers.configuration_utils import FrozenDict -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import PIL_INTERPOLATION, deprecate, logging diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 99e283a1..317a8515 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -13,25 +13,36 @@ # limitations under the License. 
import inspect -from typing import Callable, List, Optional, Union +from math import ceil +from typing import Callable, List, Optional, Tuple, Union import numpy as np import PIL import torch from diffusers.configuration_utils import FrozenDict -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from transformers import CLIPImageProcessor, CLIPTokenizer +from ...chain.tile import make_tile_mask +from ...constants import LATENT_CHANNELS, LATENT_FACTOR +from ...params import Size +from ..utils import ( + expand_latents, + parse_regions, + random_seed, + repair_nan, + resize_latent_shape, +) + logger = logging.get_logger(__name__) # inpaint constants NUM_UNET_INPUT_CHANNELS = 9 -NUM_LATENT_CHANNELS = 4 DEFAULT_WINDOW = 32 DEFAULT_STRIDE = 8 @@ -346,13 +357,16 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): f" {negative_prompt_embeds.shape}." ) - def get_views(self, panorama_height, panorama_width, window_size, stride): + def get_views( + self, panorama_height: int, panorama_width: int, window_size: int, stride: int + ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]: # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) panorama_height /= 8 panorama_width /= 8 - num_blocks_height = abs((panorama_height - window_size) // stride) + 1 - num_blocks_width = abs((panorama_width - window_size) // stride) + 1 + num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1 + num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) logger.debug( "panorama generated %s views, %s by %s blocks", @@ -369,7 +383,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): w_end = w_start + window_size views.append((h_start, h_end, w_start, w_end)) - return views + return (views, (h_end * 8, w_end * 8)) @torch.no_grad() def text2img( @@ -479,6 +493,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + prompt, regions = parse_regions(prompt) + prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, @@ -488,9 +504,30 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): negative_prompt_embeds=negative_prompt_embeds, ) + # 3.b. 
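The reworked get_views switches the block count from floor division to ceil, so partial windows at the right and bottom edges get their own view, and it now also returns the padded latent size those edge views cover. A worked example using a stand-alone helper that mirrors the same arithmetic (window and stride are in latent units after the divide by 8):

from math import ceil

def count_views(height, width, window=32, stride=8):
    lat_h, lat_w = height / 8, width / 8
    blocks_h = ceil(abs((lat_h - window) / stride)) + 1
    blocks_w = ceil(abs((lat_w - window) / stride)) + 1
    # the last view ends at start + window, which can extend past the image
    h_end = (blocks_h - 1) * stride + window
    w_end = (blocks_w - 1) * stride + window
    return blocks_h * blocks_w, (h_end * 8, w_end * 8)

print(count_views(512, 1024))  # (65, (512, 1024)) - evenly tiled
print(count_views(512, 1000))  # (65, (512, 1024)) - latents padded to 1024, margins trimmed later
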
Encode region prompts + region_embeds: List[np.ndarray] = [] + + for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions: + if region_prompt.endswith("+"): + region_prompt = region_prompt[:-1] + " " + prompt + + region_prompt_embeds = self._encode_prompt( + region_prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ) + + region_embeds.append(region_prompt_embeds) + # get the initial random noise unless the user supplied it latents_dtype = prompt_embeds.dtype - latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + latents_shape = ( + batch_size * num_images_per_prompt, + LATENT_CHANNELS, + height // LATENT_FACTOR, + width // LATENT_FACTOR, + ) if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) elif latents.shape != latents_shape: @@ -525,11 +562,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] # panorama additions - views = self.get_views(height, width, self.window, self.stride) - count = np.zeros_like(latents) - value = np.zeros_like(latents) + views, resize = self.get_views(height, width, self.window, self.stride) + logger.trace("panorama resized latents to %s", resize) + + count = np.zeros(resize_latent_shape(latents, resize)) + value = np.zeros(resize_latent_shape(latents, resize)) + + # adjust latents + latents = expand_latents( + latents, + random_seed(generator), + Size(resize[1], resize[0]), + sigma=self.scheduler.init_noise_sigma, + ) for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + last = i == (len(self.scheduler.timesteps) - 1) count.fill(0) value.fill(0) @@ -576,13 +624,115 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 + if not last: + for r, region in enumerate(regions): + top, left, bottom, right, weight, feather, prompt = region + logger.debug( + "running region prompt: %s, %s, %s, %s, %s, %s, %s", + top, + left, + bottom, + right, + weight, + feather, + prompt, + ) + + # convert coordinates to latent space + h_start = top // LATENT_FACTOR + h_end = bottom // LATENT_FACTOR + w_start = left // LATENT_FACTOR + w_end = right // LATENT_FACTOR + + # get the latents corresponding to the current view coordinates + latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] + logger.trace( + "region latent shape: [:,:,%s:%s,%s:%s] -> %s", + h_start, + h_end, + w_start, + w_end, + latents_for_region.shape, + ) + + # expand the latents if we are doing classifier free guidance + latent_region_input = ( + np.concatenate([latents_for_region] * 2) + if do_classifier_free_guidance + else latents_for_region + ) + latent_region_input = self.scheduler.scale_model_input( + torch.from_numpy(latent_region_input), t + ) + latent_region_input = latent_region_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + region_noise_pred = self.unet( + sample=latent_region_input, + timestep=timestep, + encoder_hidden_states=region_embeds[r], + ) + region_noise_pred = region_noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + region_noise_pred_uncond, region_noise_pred_text = np.split( + region_noise_pred, 2 + ) + region_noise_pred = ( + region_noise_pred_uncond + + guidance_scale + * (region_noise_pred_text - region_noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + 
scheduler_output = self.scheduler.step( + torch.from_numpy(region_noise_pred), + t, + torch.from_numpy(latents_for_region), + **extra_step_kwargs, + ) + latents_region_denoised = scheduler_output.prev_sample.numpy() + + if feather[0] > 0.0: + mask = make_tile_mask( + (h_end - h_start, w_end - w_start), + (h_end - h_start, w_end - w_start), + feather[0], + feather[1], + ) + mask = np.expand_dims(mask, axis=0) + mask = np.repeat(mask, 4, axis=0) + mask = np.expand_dims(mask, axis=0) + else: + mask = 1 + + if weight >= 100.0: + value[:, :, h_start:h_end, w_start:w_end] = ( + latents_region_denoised * mask + ) + count[:, :, h_start:h_end, w_start:w_end] = mask + else: + value[:, :, h_start:h_end, w_start:w_end] += ( + latents_region_denoised * weight * mask + ) + count[:, :, h_start:h_end, w_start:w_end] += weight * mask + # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = np.where(count > 0, value / count, value) + latents = repair_nan(latents) # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # remove extra margins + latents = latents[ + :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR) + ] + + latents = np.clip(latents, -4, +4) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 @@ -828,9 +978,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] # panorama additions - views = self.get_views(height, width, self.window, self.stride) - count = np.zeros_like(latents) - value = np.zeros_like(latents) + views, resize = self.get_views(height, width, self.window, self.stride) + logger.trace("panorama resized latents to %s", resize) + + count = np.zeros(resize_latent_shape(latents, resize)) + value = np.zeros(resize_latent_shape(latents, resize)) + + # adjust latents + latents = expand_latents( + latents, + random_seed(generator), + Size(resize[1], resize[0]), + sigma=self.scheduler.init_noise_sigma, + ) for i, t in enumerate(self.progress_bar(timesteps)): count.fill(0) @@ -886,6 +1046,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # remove extra margins + latents = latents[ + :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR) + ] + latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 @@ -1053,12 +1218,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): negative_prompt_embeds=negative_prompt_embeds, ) - num_channels_latents = NUM_LATENT_CHANNELS + num_channels_latents = LATENT_CHANNELS latents_shape = ( batch_size * num_images_per_prompt, num_channels_latents, - height // 8, - width // 8, + height // LATENT_FACTOR, + width // LATENT_FACTOR, ) latents_dtype = prompt_embeds.dtype if latents is None: @@ -1136,9 +1301,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] # panorama additions - views = self.get_views(height, width, self.window, self.stride) - count = np.zeros_like(latents) - value = np.zeros_like(latents) + views, resize = self.get_views(height, width, self.window, self.stride) + logger.trace("panorama resized latents to %s", resize) + + count = 
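The denoised views and region results are merged by weighted averaging, the MultiDiffusion step referenced above (Eq. 5): each view adds its prediction into value and its weight into count, then the two are divided. A tiny numpy sketch of that accumulation, with a region whose weight is 100 or more overriding the average inside its box; all shapes and values are illustrative:

import numpy as np

latents_shape = (1, 4, 8, 8)
value = np.zeros(latents_shape)
count = np.zeros(latents_shape)

# two views covering the left half contribute equally weighted predictions
view_a = np.full(latents_shape, 1.0)
view_b = np.full(latents_shape, 3.0)
for denoised in (view_a, view_b):
    value[:, :, 0:8, 0:4] += denoised[:, :, 0:8, 0:4]
    count[:, :, 0:8, 0:4] += 1

# a high-weight region replaces the running average inside its box
value[:, :, 0:4, 0:4] = 5.0
count[:, :, 0:4, 0:4] = 1

latents = np.where(count > 0, value / count, value)
print(latents[0, 0, 0, 0], latents[0, 0, 6, 0], latents[0, 0, 0, 6])
# 5.0 (region override), 2.0 (average of the two views), 0.0 (no view covered it)
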
np.zeros(resize_latent_shape(latents, resize)) + value = np.zeros(resize_latent_shape(latents, resize)) + + # adjust latents + latents = expand_latents( + latents, + random_seed(generator), + Size(resize[1], resize[0]), + sigma=self.scheduler.init_noise_sigma, + ) for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): count.fill(0) @@ -1201,6 +1376,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # remove extra margins + latents = latents[ + :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR) + ] + latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py new file mode 100644 index 00000000..5551267b --- /dev/null +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -0,0 +1,955 @@ +import inspect +import logging +from math import ceil +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase +from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import ( + StableDiffusionXLImg2ImgPipelineMixin, +) +from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg + +from ...chain.tile import make_tile_mask +from ...constants import LATENT_FACTOR +from ...params import Size +from ..utils import ( + expand_latents, + parse_regions, + random_seed, + repair_nan, + resize_latent_shape, +) + +logger = logging.getLogger(__name__) + + +DEFAULT_WINDOW = 64 +DEFAULT_STRIDE = 16 + + +class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin): + def __init__( + self, + *args, + window: int = DEFAULT_WINDOW, + stride: int = DEFAULT_STRIDE, + **kwargs, + ): + super().__init__(self, *args, **kwargs) + + self.window = window + self.stride = stride + + def set_window_size(self, window: int, stride: int): + self.window = window + self.stride = stride + + def get_views( + self, panorama_height: int, panorama_width: int, window_size: int, stride: int + ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]: + # Here, we define the mappings F_i (see Eq. 
7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) + panorama_height /= 8 + panorama_width /= 8 + + num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1 + num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1 + + total_num_blocks = int(num_blocks_height * num_blocks_width) + logger.debug( + "panorama generated %s views, %s by %s blocks", + total_num_blocks, + num_blocks_height, + num_blocks_width, + ) + + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + views.append((h_start, h_end, w_start, w_end)) + + return (views, (h_end * 8, w_end * 8)) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents_img2img( + self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None + ): + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + else: + init_latents = self.vae_encoder(sample=image)[ + 0 + ] * self.vae_decoder.config.get("scaling_factor", 0.18215) + + if ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] == 0 + ): + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = np.concatenate( + [init_latents] * additional_image_per_prompt, axis=0 + ) + elif ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] != 0 + ): + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = np.concatenate([init_latents], axis=0) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), + torch.from_numpy(noise), + torch.from_numpy(timestep), + ) + return init_latents.numpy() + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents_text2img( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected {shape}" + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + return latents + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + extra_step_kwargs = {} + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_eta: + extra_step_kwargs["eta"] = eta + + return extra_step_kwargs + + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ + def text2img( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`Optional[int]`, defaults to None): + The height in pixels of the generated image. + width (`Optional[int]`, defaults to None): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 0. Default height and width to unet + height = height or self.unet.config["sample_size"] * self.vae_scale_factor + width = width or self.unet.config["sample_size"] * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + 1.0, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt, regions = parse_regions(prompt) + + # 3.
Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 3.b. Encode region prompts + region_embeds: List[np.ndarray] = [] + add_region_embeds: List[np.ndarray] = [] + + for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions: + if region_prompt.endswith("+"): + region_prompt = region_prompt[:-1] + " " + prompt + + ( + region_prompt_embeds, + region_negative_prompt_embeds, + region_pooled_prompt_embeds, + region_negative_pooled_prompt_embeds, + ) = self._encode_prompt( + region_prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ) + + if do_classifier_free_guidance: + region_prompt_embeds = np.concatenate( + (region_negative_prompt_embeds, region_prompt_embeds), axis=0 + ) + add_region_embeds.append( + np.concatenate( + ( + region_negative_pooled_prompt_embeds, + region_pooled_prompt_embeds, + ), + axis=0, + ) + ) + + region_embeds.append(region_prompt_embeds) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + latents = self.prepare_latents_text2img( + batch_size * num_images_per_prompt, + self.unet.config.get("in_channels", 4), + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = (original_size + crops_coords_top_left + target_size,) + add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype) + + if do_classifier_free_guidance: + prompt_embeds = np.concatenate( + (negative_prompt_embeds, prompt_embeds), axis=0 + ) + add_text_embeds = np.concatenate( + (negative_pooled_prompt_embeds, add_text_embeds), axis=0 + ) + add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) + + add_time_ids = np.repeat( + add_time_ids, batch_size * num_images_per_prompt, axis=0 + ) + + # Adapted from diffusers to extend it for other runtimes than ORT + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + # 8. Panorama additions + views, resize = self.get_views(height, width, self.window, self.stride) + logger.trace("panorama resized latents to %s", resize) + + count = np.zeros(resize_latent_shape(latents, resize)) + value = np.zeros(resize_latent_shape(latents, resize)) + + # adjust latents + latents = expand_latents( + latents, + random_seed(generator), + Size(resize[1], resize[0]), + sigma=self.scheduler.init_noise_sigma, + ) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + last = i == (len(timesteps) - 1) + count.fill(0) + value.fill(0) + + for h_start, h_end, w_start, w_end in views: + # get the latents corresponding to the current view coordinates + latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + np.concatenate([latents_for_view] * 2) + if do_classifier_free_guidance + else latents_for_view + ) + latent_model_input = self.scheduler.scale_model_input( + torch.from_numpy(latent_model_input), t + ) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_embeds=add_text_embeds, + time_ids=add_time_ids, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), + t, + torch.from_numpy(latents_for_view), + **extra_step_kwargs, + ) + latents_view_denoised = scheduler_output.prev_sample.numpy() + + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 + + if not last: + for r, region in enumerate(regions): + top, left, bottom, right, weight, feather, prompt = region + logger.debug( + "running region prompt: %s, %s, %s, %s, %s, %s, %s", + top, + left, + bottom, + right, + weight, + feather, + prompt, + ) + + # convert coordinates to latent space + h_start = top // LATENT_FACTOR + h_end = bottom // LATENT_FACTOR + w_start = left // LATENT_FACTOR + w_end = right // LATENT_FACTOR + + # get the latents corresponding to the current view coordinates + latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] + logger.trace( + "region latent shape: [:,:,%s:%s,%s:%s] -> %s", + h_start, + h_end, + w_start, + w_end, + latents_for_region.shape, + ) + + # expand the latents if we are doing classifier free guidance + latent_region_input = ( + np.concatenate([latents_for_region] * 2) + if do_classifier_free_guidance + else latents_for_region + ) + latent_region_input = self.scheduler.scale_model_input( + torch.from_numpy(latent_region_input), t + ) + latent_region_input = latent_region_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + region_noise_pred = self.unet( + sample=latent_region_input, + timestep=timestep, + encoder_hidden_states=region_embeds[r], + text_embeds=add_region_embeds[r], + time_ids=add_time_ids, + ) + region_noise_pred = region_noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + region_noise_pred_uncond, region_noise_pred_text = np.split( + region_noise_pred, 2 + ) + region_noise_pred = ( + region_noise_pred_uncond + + guidance_scale + * (region_noise_pred_text - region_noise_pred_uncond) + ) + if guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + region_noise_pred = rescale_noise_cfg( + region_noise_pred, + region_noise_pred_text, + guidance_rescale=guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(region_noise_pred), + t, + torch.from_numpy(latents_for_region), + **extra_step_kwargs, + ) + latents_region_denoised = scheduler_output.prev_sample.numpy() + + if feather[0] > 0.0: + mask = make_tile_mask( + (h_end - h_start, w_end - w_start), + (h_end - h_start, w_end - w_start), + feather[0], + feather[1], + ) + mask = np.expand_dims(mask, axis=0) + mask = np.repeat(mask, 4, axis=0) + mask = np.expand_dims(mask, axis=0) + else: + mask = 1 + + if weight >= 100.0: + value[:, :, h_start:h_end, w_start:w_end] = ( + latents_region_denoised * mask + ) + count[:, :, h_start:h_end, w_start:w_end] = mask + else: + value[:, :, h_start:h_end, w_start:w_end] += ( + latents_region_denoised * weight * mask + ) + count[:, :, h_start:h_end, w_start:w_end] += weight * mask + + # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 + latents = np.where(count > 0, value / count, value) + latents = repair_nan(latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # remove extra margins + latents = latents[ + :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR) + ] + + if output_type == "latent": + image = latents + else: + latents = np.clip(latents, -4, +4) + latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [ + self.vae_decoder(latent_sample=latents[i : i + 1])[0] + for i in range(latents.shape[0]) + ] + ) + image = self.watermark.apply_watermark(image) + + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ + def img2img( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.3, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when 
calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`Union[np.ndarray, PIL.Image.Image]`): + `Image`, or tensor representing an image batch which will be upscaled. + strength (`float`, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. 
The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 3. Preprocess image + processor = VaeImageProcessor() + image = processor.preprocess(image).cpu().numpy() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength + ) + latent_timestep = np.repeat( + timesteps[:1], batch_size * num_images_per_prompt, axis=0 + ) + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # 5. Prepare latent variables + latents = self.prepare_latents_img2img( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + latents_dtype, + generator, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_eta: + extra_step_kwargs["eta"] = eta + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = np.concatenate( + (negative_prompt_embeds, prompt_embeds), axis=0 + ) + add_text_embeds = np.concatenate( + (negative_pooled_prompt_embeds, add_text_embeds), axis=0 + ) + add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) + add_time_ids = np.repeat( + add_time_ids, batch_size * num_images_per_prompt, axis=0 + ) + + # 8. Panorama additions + views, resize = self.get_views(height, width, self.window, self.stride) + logger.trace("panorama resized latents to %s", resize) + + count = np.zeros(resize_latent_shape(latents, resize)) + value = np.zeros(resize_latent_shape(latents, resize)) + + latents = expand_latents( + latents, + random_seed(generator), + Size(resize[1], resize[0]), + sigma=self.scheduler.init_noise_sigma, + ) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + count.fill(0) + value.fill(0) + + for h_start, h_end, w_start, w_end in views: + # get the latents corresponding to the current view coordinates + latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + np.concatenate([latents_for_view] * 2) + if do_classifier_free_guidance + else latents_for_view + ) + latent_model_input = self.scheduler.scale_model_input( + torch.from_numpy(latent_model_input), t + ) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_embeds=add_text_embeds, + time_ids=add_time_ids, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), + t, + torch.from_numpy(latents_for_view), + **extra_step_kwargs, + ) + latents_view_denoised = scheduler_output.prev_sample.numpy() + + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 + + # take the MultiDiffusion step. Eq. 
5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 + latents = np.where(count > 0, value / count, value) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # remove extra margins + latents = latents[ + :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR) + ] + + if output_type == "latent": + image = latents + else: + latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [ + self.vae_decoder(latent_sample=latents[i : i + 1])[0] + for i in range(latents.shape[0]) + ] + ) + image = self.watermark.apply_watermark(image) + + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) + + def __call__( + self, + *args, + **kwargs, + ): + if "image" in kwargs or ( + len(args) > 1 + and ( + isinstance(args[1], np.ndarray) or isinstance(args[1], PIL.Image.Image) + ) + ): + logger.debug("running img2img panorama XL pipeline") + return self.img2img(*args, **kwargs) + else: + logger.debug("running txt2img panorama XL pipeline") + return self.text2img(*args, **kwargs) + + +class ORTStableDiffusionXLPanoramaPipeline( + ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin +): + def __call__(self, *args, **kwargs): + return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs) diff --git a/api/onnx_web/diffusers/pipelines/pix2pix.py b/api/onnx_web/diffusers/pipelines/pix2pix.py index ffcd7e9b..2689fa4f 100644 --- a/api/onnx_web/diffusers/pipelines/pix2pix.py +++ b/api/onnx_web/diffusers/pipelines/pix2pix.py @@ -32,7 +32,7 @@ except ImportError: } from diffusers import OnnxRuntimeModel -from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import ( DDIMScheduler, diff --git a/api/onnx_web/diffusers/pipelines/upscale.py b/api/onnx_web/diffusers/pipelines/upscale.py index deb2c8c3..aa07cc99 100644 --- a/api/onnx_web/diffusers/pipelines/upscale.py +++ b/api/onnx_web/diffusers/pipelines/upscale.py @@ -1,63 +1,25 @@ -### -# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline: -# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py -# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py -# See also: https://github.com/huggingface/diffusers/pull/2158 -### - from logging import getLogger -from typing import Any, Callable, List, Optional, Union +from typing import Any, List -import numpy as np -import PIL -import torch -from diffusers.pipeline_utils import ImagePipelineOutput -from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline +from diffusers.pipelines.onnx_utils import OnnxRuntimeModel +from diffusers.pipelines.stable_diffusion import ( + OnnxStableDiffusionUpscalePipeline 
as BasePipeline, +) from diffusers.schedulers import DDPMScheduler logger = getLogger(__name__) -NUM_LATENT_CHANNELS = 4 -NUM_UNET_INPUT_CHANNELS = 7 - -ORT_TO_PT_TYPE = { - "float16": torch.float16, - "float32": torch.float32, -} - - -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 - - image = [np.array(i.resize((w, h)))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - return image - - class FakeConfig: + block_out_channels: List[int] scaling_factor: float def __init__(self) -> None: + self.block_out_channels = [128, 256, 512] self.scaling_factor = 0.08333 -class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): +class OnnxStableDiffusionUpscalePipeline(BasePipeline): def __init__( self, vae: OnnxRuntimeModel, @@ -80,260 +42,3 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): scheduler, max_noise_level=max_noise_level, ) - - def __call__( - self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], - num_inference_steps: int = 75, - guidance_scale: float = 9.0, - noise_level: int = 20, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - ): - # 1. Check inputs - self.check_inputs(prompt, image, noise_level, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, - # device, device only needed for Torch pipelines - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - ) - - latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] - - # 4. Preprocess image - image = preprocess(image) - image = image.cpu() - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Add noise to image - noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) - noise = torch.randn( - image.shape, generator=generator, device=device, dtype=latents_dtype - ) - image = self.low_res_scheduler.add_noise(image, noise, noise_level) - - batch_multiplier = 2 if do_classifier_free_guidance else 1 - image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) - noise_level = np.concatenate([noise_level] * image.shape[0]) - - # 6. 
Prepare latent variables - height, width = image.shape[2:] - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - NUM_LATENT_CHANNELS, - height, - width, - latents_dtype, - device, - generator, - latents, - ) - - # 7. Check that sizes of image and latents match - num_channels_image = image.shape[1] - if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS: - raise ValueError( - "Incorrect configuration settings! The config of `pipeline.unet` expects" - f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +" - f" `num_channels_image`: {num_channels_image} " - f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of" - " `pipeline.unet` or your `image` input." - ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - timestep_dtype = next( - ( - input.type - for input in self.unet.model.get_inputs() - if input.name == "timestep" - ), - "tensor(float)", - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - - # 9. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - np.concatenate([latents] * 2) - if do_classifier_free_guidance - else latents - ) - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) - latent_model_input = np.concatenate([latent_model_input, image], axis=1) - - # timestep to tensor - timestep = np.array([t], dtype=timestep_dtype) - - # predict the noise residual - noise_pred = self.unet( - sample=latent_model_input, - timestep=timestep, - encoder_hidden_states=text_embeddings, - class_labels=noise_level.astype(np.int64), - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs - ).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 10. Post-processing - image = self.decode_latents(latents.float()) - - # 11. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) - - def decode_latents(self, latents): - latents = 1 / 0.08333 * latents - image = self.vae(latent_sample=latents)[0] - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - return image - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - ): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - # no positional arguments to text_encoder - text_embeddings = self.text_encoder( - input_ids=text_input_ids.int().to(device), - ) - text_embeddings = text_embeddings[0] - - bs_embed, seq_len, _ = text_embeddings.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) - text_embeddings = text_embeddings.reshape( - bs_embed * num_images_per_prompt, seq_len, -1 - ) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - uncond_embeddings = self.text_encoder( - input_ids=uncond_input.input_ids.int().to(device), - ) - uncond_embeddings = uncond_embeddings[0] - - seq_len = uncond_embeddings.shape[1] - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) - uncond_embeddings = uncond_embeddings.reshape( - batch_size * num_images_per_prompt, seq_len, -1 - ) - - # For classifier free guidance, we need to do two forward passes. 
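# --- illustrative aside, not part of the patch ---
# A small sketch of the classifier-free guidance batching used throughout these
# pipelines: unconditional and text embeddings are concatenated into one batch so the
# UNet runs once, the prediction is split back apart, and the two halves are blended
# with guidance_scale. Shapes and values below are toy examples, not real model output.
import numpy as np

guidance_scale = 7.5

# pretend UNet output for a batch laid out as [unconditional, conditional]
noise_pred = np.stack([np.zeros((4, 64, 64)), np.ones((4, 64, 64))])

noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.mean())  # 7.5 with these toy values
# --- end aside ---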
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) - - return text_embeddings diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index b1f14b1d..5b159019 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -4,15 +4,16 @@ from typing import Any, List, Optional from PIL import Image, ImageOps -from onnx_web.chain.highres import stage_highres - from ..chain import ( + BlendDenoiseStage, BlendImg2ImgStage, BlendMaskStage, ChainPipeline, SourceTxt2ImgStage, UpscaleOutpaintStage, ) +from ..chain.highres import stage_highres +from ..chain.result import StageResult from ..chain.upscale import split_upscale, stage_upscale_correction from ..image import expand_image from ..output import save_image @@ -33,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt logger = getLogger(__name__) +def get_base_tile(params: ImageParams, size: Size) -> int: + if params.is_panorama(): + tile = max(params.unet_tile, size.width, size.height) + logger.debug("adjusting tile size for panorama to %s", tile) + return tile + + return params.unet_tile + + +def get_highres_tile( + server: ServerContext, params: ImageParams, highres: HighresParams, tile: int +) -> int: + if params.is_panorama() and server.has_feature("panorama-highres"): + return tile * highres.scale + + return params.unet_tile + + def run_txt2img_pipeline( worker: WorkerContext, server: ServerContext, @@ -43,10 +62,7 @@ def run_txt2img_pipeline( highres: HighresParams, ) -> None: # if using panorama, the pipeline will tile itself (views) - if params.is_panorama() or params.is_xl(): - tile_size = max(params.tiles, size.width, size.height) - else: - tile_size = params.tiles + tile_size = get_base_tile(params, size) # prepare the chain pipeline and first stage chain = ChainPipeline() @@ -57,15 +73,21 @@ def run_txt2img_pipeline( ), size=size, prompt_index=0, - overlap=params.overlap, + overlap=params.vae_overlap, ) # apply upscaling and correction, before highres - stage = StageParams(tile_size=params.tiles) + highres_size = get_highres_tile(server, params, highres, tile_size) + if params.is_panorama(): + chain.stage( + BlendDenoiseStage(), + StageParams(tile_size=highres_size), + ) + first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams(outscale=first_upscale.outscale, tile_size=highres_size), params, chain=chain, upscale=first_upscale, @@ -73,7 +95,7 @@ def run_txt2img_pipeline( # apply highres stage_highres( - stage, + StageParams(outscale=highres.scale, tile_size=highres_size), params, highres, upscale, @@ -83,7 +105,7 @@ def run_txt2img_pipeline( # apply upscaling and correction, after highres stage_upscale_correction( - stage, + StageParams(outscale=after_upscale.outscale, tile_size=highres_size), params, chain=chain, upscale=after_upscale, @@ -92,11 +114,14 @@ def run_txt2img_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback() - images = chain.run(worker, server, params, [], callback=progress, latents=latents) + images = chain.run( + worker, server, params, StageResult.empty(), callback=progress, latents=latents + ) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): + logger.trace("saving output image %s: %s", output, image.size) dest = save_image( 
server, output, @@ -136,23 +161,26 @@ def run_img2img_pipeline( source = f(server, source) # prepare the chain pipeline and first stage + tile_size = get_base_tile(params, Size(*source.size)) chain = ChainPipeline() - stage = StageParams( - tile_size=params.tiles, - ) chain.stage( BlendImg2ImgStage(), - stage, + StageParams( + tile_size=tile_size, + ), prompt_index=0, strength=strength, - overlap=params.overlap, + overlap=params.vae_overlap, ) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams( + outscale=first_upscale.outscale, + tile_size=tile_size, + ), params, upscale=first_upscale, chain=chain, @@ -162,13 +190,16 @@ def run_img2img_pipeline( for _i in range(params.loopback): chain.stage( BlendImg2ImgStage(), - stage, + StageParams( + tile_size=tile_size, + ), strength=strength, ) # highres, if selected + highres_size = get_highres_tile(server, params, highres, tile_size) stage_highres( - stage, + StageParams(tile_size=highres_size, outscale=highres.scale), params, highres, upscale, @@ -178,7 +209,7 @@ def run_img2img_pipeline( # apply upscaling and correction, after highres stage_upscale_correction( - stage, + StageParams(tile_size=tile_size, outscale=after_upscale.scale), params, upscale=after_upscale, chain=chain, @@ -186,7 +217,9 @@ def run_img2img_pipeline( # run and append the filtered source progress = worker.get_progress_callback() - images = chain(worker, server, params, [source], callback=progress) + images = chain.run( + worker, server, params, StageResult(images=[source]), callback=progress + ) if source_filter is not None and source_filter != "none": images.append(source) @@ -235,7 +268,7 @@ def run_inpaint_pipeline( full_res_inpaint_padding: float, ) -> None: logger.debug("building inpaint pipeline") - tile_size = params.tiles + tile_size = get_base_tile(params, size) if mask is None: # if no mask was provided, keep the full source image @@ -264,8 +297,12 @@ def run_inpaint_pipeline( logger.debug("border zero: %s", border.isZero()) full_res_inpaint = full_res_inpaint and border.isZero() if full_res_inpaint: - mask_left, mask_top, mask_right, mask_bottom = mask.getbbox() - logger.debug("mask bbox: %s", mask.getbbox()) + bbox = mask.getbbox() + if bbox is None: + bbox = (0, 0, source.width, source.height) + + logger.debug("mask bounding box: %s", bbox) + mask_left, mask_top, mask_right, mask_bottom = bbox mask_width = mask_right - mask_left mask_height = mask_bottom - mask_top # ensure we have some padding around the mask when we do the inpaint (and that the region size is even) @@ -322,16 +359,15 @@ def run_inpaint_pipeline( # set up the chain pipeline and base stage chain = ChainPipeline() - stage = StageParams(tile_order=tile_order, tile_size=tile_size) chain.stage( UpscaleOutpaintStage(), - stage, + StageParams(tile_order=tile_order, tile_size=tile_size), border=border, mask=mask, fill_color=fill_color, mask_filter=mask_filter, noise_source=noise_source, - overlap=params.overlap, + overlap=params.vae_overlap, prompt_index=0, ) @@ -339,15 +375,16 @@ def run_inpaint_pipeline( first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams(outscale=first_upscale.outscale, tile_size=tile_size), params, upscale=first_upscale, chain=chain, ) # apply highres + highres_size = get_highres_tile(server, params, highres, tile_size) stage_highres( - stage, + StageParams(outscale=highres.scale, 
tile_size=highres_size), params, highres, upscale, @@ -357,7 +394,7 @@ def run_inpaint_pipeline( # apply upscaling and correction stage_upscale_correction( - stage, + StageParams(outscale=after_upscale.outscale), params, upscale=after_upscale, chain=chain, @@ -366,7 +403,14 @@ def run_inpaint_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback() - images = chain(worker, server, params, [source], callback=progress, latents=latents) + images = chain.run( + worker, + server, + params, + StageResult(images=[source]), + callback=progress, + latents=latents, + ) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): @@ -409,21 +453,22 @@ def run_upscale_pipeline( ) -> None: # set up the chain pipeline, no base stage for upscaling chain = ChainPipeline() - stage = StageParams(tile_size=params.tiles) + tile_size = get_base_tile(params, size) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams(outscale=first_upscale.outscale, tile_size=tile_size), params, upscale=first_upscale, chain=chain, ) # apply highres + highres_size = get_highres_tile(server, params, highres, tile_size) stage_highres( - stage, + StageParams(outscale=highres.scale, tile_size=highres_size), params, highres, upscale, @@ -433,7 +478,7 @@ def run_upscale_pipeline( # apply upscaling and correction, after highres stage_upscale_correction( - stage, + StageParams(outscale=after_upscale.outscale, tile_size=tile_size), params, upscale=after_upscale, chain=chain, @@ -441,7 +486,9 @@ def run_upscale_pipeline( # run and save progress = worker.get_progress_callback() - images = chain(worker, server, params, [source], callback=progress) + images = chain.run( + worker, server, params, StageResult(images=[source]), callback=progress + ) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): @@ -478,12 +525,18 @@ def run_blend_pipeline( ) -> None: # set up the chain pipeline and base stage chain = ChainPipeline() - stage = StageParams() - chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask) + tile_size = get_base_tile(params, size) + + chain.stage( + BlendMaskStage(), + StageParams(tile_size=tile_size), + stage_source=sources[1], + stage_mask=mask, + ) # apply upscaling and correction stage_upscale_correction( - stage, + StageParams(outscale=upscale.outscale), params, upscale=upscale, chain=chain, @@ -491,7 +544,9 @@ def run_blend_pipeline( # run and save progress = worker.get_progress_callback() - images = chain(worker, server, params, sources, callback=progress) + images = chain.run( + worker, server, params, StageResult(images=sources), callback=progress + ) for image, output in zip(images, outputs): dest = save_image(server, output, image, params, size, upscale=upscale) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index ab3c63c5..b9dc4394 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -3,23 +3,27 @@ from copy import deepcopy from logging import getLogger from math import ceil from re import Pattern, compile -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch from diffusers import OnnxStableDiffusionPipeline +from ..constants import LATENT_CHANNELS, LATENT_FACTOR from ..params 
import ImageParams, Size logger = getLogger(__name__) -LATENT_CHANNELS = 4 -LATENT_FACTOR = 8 MAX_TOKENS_PER_GROUP = 77 +ANY_TOKEN = compile(r"\<([^\>]*)\>") CLIP_TOKEN = compile(r"\<clip:([-\w]+):([\.|\d]+)\>") INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>") +REGION_TOKEN = compile( + r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d_TLBR]+):([^\>]+)\>" +) +RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>") WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") @@ -84,8 +88,8 @@ def expand_prompt( negative_prompt: Optional[str] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, - skip_clip_states: Optional[int] = 0, -) -> "np.NDArray": + skip_clip_states: int = 0, +) -> np.ndarray: # self provides: # tokenizer: CLIPTokenizer # encoder: OnnxRuntimeModel @@ -140,6 +144,7 @@ def expand_prompt( last_state, _pooled_output, *hidden_states = text_result if skip_clip_states > 0: + # TODO: why is this normalized? layer_norm = torch.nn.LayerNorm(last_state.shape[2]) norm_state = layer_norm( torch.from_numpy( @@ -219,20 +224,25 @@ def expand_prompt( return prompt_embeds +def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]: + name, weight = group + return (name, float(weight)) + + def get_tokens_from_prompt( - prompt: str, pattern: Pattern + prompt: str, + pattern: Pattern, + parser=parse_float_group, ) -> Tuple[str, List[Tuple[str, float]]]: - """ - TODO: replace with Arpeggio - """ remaining_prompt = prompt tokens = [] next_match = pattern.search(remaining_prompt) while next_match is not None: logger.debug("found token in prompt: %s", next_match) - name, weight = next_match.groups() - tokens.append((name, float(weight))) + group = next_match.groups() + tokens.append(parser(group)) + # remove this match and look for another remaining_prompt = ( remaining_prompt[: next_match.start()] @@ -251,6 +261,13 @@ def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float] return get_tokens_from_prompt(prompt, INVERSION_TOKEN) + +def random_seed(generator=None) -> int: + if generator is None: + generator = np.random + + return generator.randint(np.iinfo(np.int32).max) + + def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: """ From https://www.travelneil.com/stable-diffusion-updates.html.
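# --- illustrative aside, not part of the patch ---
# A rough sketch of the latent shape used by the helpers around here, assuming the
# (batch, LATENT_CHANNELS, height // LATENT_FACTOR, width // LATENT_FACTOR) layout of
# get_latents_from_seed; expand_latents (added in the next hunk) pads such a tensor with
# fresh seeded noise up to the panorama working size. sketch_latents is an approximation
# for illustration, not the project's implementation.
import numpy as np

LATENT_CHANNELS = 4
LATENT_FACTOR = 8


def sketch_latents(seed: int, width: int, height: int, batch: int = 1) -> np.ndarray:
    rng = np.random.RandomState(seed)
    shape = (batch, LATENT_CHANNELS, height // LATENT_FACTOR, width // LATENT_FACTOR)
    return rng.standard_normal(shape).astype(np.float32)


print(sketch_latents(42, 768, 512).shape)  # (1, 4, 64, 96)
# --- end aside ---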
@@ -266,6 +283,25 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: return image_latents +def expand_latents( + latents: np.ndarray, + seed: int, + size: Size, + sigma: float = 1.0, +) -> np.ndarray: + batch, _channels, height, width = latents.shape + extra_latents = get_latents_from_seed(seed, size, batch=batch) + extra_latents[:, :, 0:height, 0:width] = latents + return extra_latents * np.float64(sigma) + + +def resize_latent_shape( + latents: np.ndarray, + size: Tuple[int, int], +) -> Tuple[int, int, int, int]: + return (latents.shape[0], latents.shape[1], *size) + + def get_tile_latents( full_latents: np.ndarray, seed: int, @@ -290,14 +326,8 @@ def get_tile_latents( tile_latents = full_latents[:, :, y:yt, x:xt] - if tile_latents.shape != full_latents.shape and ( - tile_latents.shape[2] < t or tile_latents.shape[3] < t - ): - extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0]) - extra_latents[ - :, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3] - ] = tile_latents - tile_latents = extra_latents + if tile_latents.shape[2] < t or tile_latents.shape[3] < t: + tile_latents = expand_latents(tile_latents, seed, size) return tile_latents @@ -369,12 +399,15 @@ def encode_prompt( num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, ) -> List[np.ndarray]: + """ + TODO: does not work with SDXL, fix or turn into a pipeline patch + """ return [ pipe._encode_prompt( - prompt, + remove_tokens(prompt), num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=neg_prompt, + negative_prompt=remove_tokens(neg_prompt), ) for prompt, neg_prompt in prompt_pairs ] @@ -444,3 +477,71 @@ def slice_prompt(prompt: str, slice: int) -> str: return parts[min(slice, len(parts) - 1)] else: return prompt + + +Region = Tuple[ + int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str +] + + +def parse_region_group(group: Tuple[str, ...]) -> Region: + top, left, bottom, right, weight, feather, prompt = group + + # break down the feather section + feather_radius, *feather_edges = feather.split("_") + if len(feather_edges) == 0: + feather_edges = "TLBR" + else: + feather_edges = "".join(feather_edges) + + return ( + int(top), + int(left), + int(bottom), + int(right), + float(weight), + ( + float(feather_radius), + ( + "T" in feather_edges, + "L" in feather_edges, + "B" in feather_edges, + "R" in feather_edges, + ), + ), + prompt, + ) + + +def parse_regions(prompt: str) -> Tuple[str, List[Region]]: + return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group) + + +Reseed = Tuple[int, int, int, int, int] + + +def parse_reseed_group(group) -> Region: + top, left, bottom, right, seed = group + return ( + int(top), + int(left), + int(bottom), + int(right), + int(seed), + ) + + +def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]: + return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group) + + +def skip_group(group) -> Any: + return group + + +def remove_tokens(prompt: Optional[str]) -> Optional[str]: + if prompt is None: + return prompt + + remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN, parser=skip_group) + return remainder diff --git a/api/onnx_web/diffusers/version_safe_diffusers.py b/api/onnx_web/diffusers/version_safe_diffusers.py index 79198b81..d256d615 100644 --- a/api/onnx_web/diffusers/version_safe_diffusers.py +++ b/api/onnx_web/diffusers/version_safe_diffusers.py @@ -12,6 +12,11 @@ try: 
except ImportError: from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler +try: + from diffusers import LCMScheduler +except ImportError: + from ..diffusers.stub_scheduler import StubScheduler as LCMScheduler + try: from diffusers import UniPCMultistepScheduler except ImportError: diff --git a/api/onnx_web/image/mask_filter.py b/api/onnx_web/image/mask_filter.py index 82a19dfa..967fce1f 100644 --- a/api/onnx_web/image/mask_filter.py +++ b/api/onnx_web/image/mask_filter.py @@ -8,7 +8,7 @@ def mask_filter_none( ) -> Image.Image: width, height = dims - noise = Image.new("RGB", (width, height), fill) + noise = Image.new(mask.mode, (width, height), fill) noise.paste(mask, origin) return noise diff --git a/api/onnx_web/image/noise_source.py b/api/onnx_web/image/noise_source.py index c543aa6e..2c260f14 100644 --- a/api/onnx_web/image/noise_source.py +++ b/api/onnx_web/image/noise_source.py @@ -17,21 +17,21 @@ def noise_source_fill_edge( """ width, height = dims - noise = Image.new("RGB", (width, height), fill) + noise = Image.new(source.mode, (width, height), fill) noise.paste(source, origin) return noise def noise_source_fill_mask( - _source: Image.Image, dims: Point, _origin: Point, fill="white", **kw + source: Image.Image, dims: Point, _origin: Point, fill="white", **kw ) -> Image.Image: """ Fill the whole canvas, no source or noise. """ width, height = dims - noise = Image.new("RGB", (width, height), fill) + noise = Image.new(source.mode, (width, height), fill) return noise @@ -52,7 +52,7 @@ def noise_source_gaussian( def noise_source_uniform( - _source: Image.Image, dims: Point, _origin: Point, **kw + source: Image.Image, dims: Point, _origin: Point, **kw ) -> Image.Image: width, height = dims size = width * height @@ -61,6 +61,7 @@ def noise_source_uniform( noise_g = random.uniform(0, 256, size=size) noise_b = random.uniform(0, 256, size=size) + # needs to be RGB for pixel manipulation noise = Image.new("RGB", (width, height)) for x in range(width): @@ -68,11 +69,11 @@ def noise_source_uniform( i = get_pixel_index(x, y, width) noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) - return noise + return noise.convert(source.mode) def noise_source_normal( - _source: Image.Image, dims: Point, _origin: Point, **kw + source: Image.Image, dims: Point, _origin: Point, **kw ) -> Image.Image: width, height = dims size = width * height @@ -81,6 +82,7 @@ def noise_source_normal( noise_g = random.normal(128, 32, size=size) noise_b = random.normal(128, 32, size=size) + # needs to be RGB for pixel manipulation noise = Image.new("RGB", (width, height)) for x in range(width): @@ -88,13 +90,13 @@ def noise_source_normal( i = get_pixel_index(x, y, width) noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) - return noise + return noise.convert(source.mode) def noise_source_histogram( source: Image.Image, dims: Point, _origin: Point, **kw ) -> Image.Image: - r, g, b = source.split() + r, g, b, *_a = source.split() width, height = dims size = width * height @@ -112,6 +114,7 @@ def noise_source_histogram( 256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size ) + # needs to be RGB for pixel manipulation noise = Image.new("RGB", (width, height)) for x in range(width): @@ -119,4 +122,4 @@ def noise_source_histogram( i = get_pixel_index(x, y, width) noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i])) - return noise + return noise.convert(source.mode) diff --git a/api/onnx_web/image/source_filter.py 
b/api/onnx_web/image/source_filter.py index ea6e0d12..99c6b5eb 100644 --- a/api/onnx_web/image/source_filter.py +++ b/api/onnx_web/image/source_filter.py @@ -47,7 +47,7 @@ def source_filter_noise( source: Image.Image, strength: float = 0.5, ): - noise = noise_source_histogram(source, source.size) + noise = noise_source_histogram(source, source.size, (0, 0)) return ImageChops.blend(source, noise, strength) diff --git a/api/onnx_web/image/utils.py b/api/onnx_web/image/utils.py index 80972080..a1264ff0 100644 --- a/api/onnx_web/image/utils.py +++ b/api/onnx_web/image/utils.py @@ -1,3 +1,5 @@ +from typing import Tuple + from PIL import Image, ImageChops from ..params import Border, Size @@ -13,12 +15,12 @@ def expand_image( fill="white", noise_source=noise_source_histogram, mask_filter=mask_filter_none, -): +) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]: size = Size(*source.size).add_border(expand) size = tuple(size) origin = (expand.left, expand.top) - full_source = Image.new("RGB", size, fill) + full_source = Image.new(source.mode, size, fill) full_source.paste(source, origin) # new mask pixels need to be filled with white so they will be replaced diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 9e47cc27..6de5dc39 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -23,6 +23,7 @@ from .server.load import ( load_platforms, load_wildcards, ) +from .server.plugin import load_plugins, register_plugins from .server.static import register_static_routes from .server.utils import check_paths from .utils import is_debug @@ -43,15 +44,32 @@ def main(): server = ServerContext.from_environ() apply_patches(server) check_paths(server) + + # debug options + if server.debug: + import debugpy + + debugpy.listen(5678) + logger.warning("waiting for debugger") + debugpy.wait_for_client() + gc.set_debug(gc.DEBUG_STATS) + + # register plugins + exports = load_plugins(server) + success = register_plugins(exports) + if success: + logger.info("all plugins loaded successfully") + else: + logger.warning("error loading plugins") + + # load additional resources load_extras(server) load_models(server) load_params(server) load_platforms(server) load_wildcards(server) - if is_debug(): - gc.set_debug(gc.DEBUG_STATS) - + # misc server options if not server.show_progress: disable_progress_bar() disable_progress_bars() diff --git a/api/onnx_web/models/meta.py b/api/onnx_web/models/meta.py index dcd43c25..6aaa4e52 100644 --- a/api/onnx_web/models/meta.py +++ b/api/onnx_web/models/meta.py @@ -1,18 +1,21 @@ -from typing import Literal +from typing import List, Literal -NetworkType = Literal["inversion", "lora"] +NetworkType = Literal["control", "inversion", "lora"] class NetworkModel: name: str + tokens: List[str] type: NetworkType - def __init__(self, name: str, type: NetworkType) -> None: + def __init__(self, name: str, type: NetworkType, tokens=None) -> None: self.name = name + self.tokens = tokens or [] self.type = type def tojson(self): return { "name": self.name, + "tokens": self.tokens, "type": self.type, } diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 17f29744..ec76ce3d 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -57,7 +57,7 @@ def json_params( upscale: Optional[UpscaleParams] = None, border: Optional[Border] = None, highres: Optional[HighresParams] = None, - parent: Dict = None, + parent: Optional[Dict] = None, ) -> Any: json = { "input_size": size.tojson(), @@ -158,6 +158,7 @@ def make_output_name( size: Size, extras: 
Optional[List[Optional[Param]]] = None, count: Optional[int] = None, + offset: int = 0, ) -> List[str]: count = count or params.batch now = int(time()) @@ -183,7 +184,7 @@ def make_output_name( return [ f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}" - for i in range(count) + for i in range(offset, count + offset) ] diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 28825ad4..5a84aa1a 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -14,7 +14,7 @@ Point = Tuple[int, int] class SizeChart(IntEnum): - unlimited = 0 + micro = 64 mini = 128 # small tile for very expensive models half = 256 # half tile for outpainting auto = 512 # auto tile size @@ -25,6 +25,7 @@ class SizeChart(IntEnum): hd16k = 2**14 hd32k = 2**15 hd64k = 2**16 + max = 2**32 # should be a reasonable upper limit for now class TileOrder: @@ -140,7 +141,7 @@ class DeviceParams: if self.options is None: return self.provider else: - return self.provider # (self.provider, self.options) + return (self.provider, self.options) def sess_options(self, cache=True) -> SessionOptions: if cache and self.sess_options_cache is not None: @@ -201,11 +202,14 @@ class ImageParams: batch: int control: Optional[NetworkModel] input_prompt: str - input_negative_prompt: str + input_negative_prompt: Optional[str] loopback: int tiled_vae: bool - tiles: int - overlap: float + unet_tile: int + unet_overlap: float + vae_tile: int + vae_overlap: float + denoise: int def __init__( self, @@ -224,9 +228,11 @@ class ImageParams: input_negative_prompt: Optional[str] = None, loopback: int = 0, tiled_vae: bool = False, - tiles: int = 512, - overlap: float = 0.25, - stride: int = 64, + unet_overlap: float = 0.25, + unet_tile: int = 512, + vae_overlap: float = 0.25, + vae_tile: int = 512, + denoise: int = 3, ) -> None: self.model = model self.pipeline = pipeline @@ -243,14 +249,16 @@ class ImageParams: self.input_negative_prompt = input_negative_prompt or negative_prompt self.loopback = loopback self.tiled_vae = tiled_vae - self.tiles = tiles - self.overlap = overlap - self.stride = stride + self.unet_overlap = unet_overlap + self.unet_tile = unet_tile + self.vae_overlap = vae_overlap + self.vae_tile = vae_tile + self.denoise = denoise def do_cfg(self): return self.cfg > 1.0 - def get_valid_pipeline(self, group: str, pipeline: str = None) -> str: + def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str: pipeline = pipeline or self.pipeline # if the correct pipeline was already requested, simply use that @@ -259,7 +267,14 @@ class ImageParams: # otherwise, check for additional allowed pipelines if group == "img2img": - if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "pix2pix"]: + if pipeline in [ + "controlnet", + "img2img-sdxl", + "lpw", + "panorama", + "panorama-sdxl", + "pix2pix", + ]: return pipeline elif pipeline == "txt2img-sdxl": return "img2img-sdxl" @@ -267,7 +282,7 @@ class ImageParams: if pipeline in ["controlnet", "lpw", "panorama"]: return pipeline elif group == "txt2img": - if pipeline in ["lpw", "panorama", "txt2img-sdxl"]: + if pipeline in ["lpw", "panorama", "panorama-sdxl", "txt2img-sdxl"]: return pipeline logger.debug("pipeline %s is not valid for %s", pipeline, group) @@ -280,7 +295,7 @@ class ImageParams: return self.pipeline == "lpw" def is_panorama(self): - return self.pipeline == "panorama" + return self.pipeline in ["panorama", "panorama-sdxl"] def is_pix2pix(self): return self.pipeline == "pix2pix" @@ -305,9 +320,11 @@ class ImageParams: 
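The ort_provider() fix above means a device with options now yields a (provider, options) pair, which is the form ONNX Runtime accepts directly in its providers list. A small sketch; the model path and device id are placeholders:

```python
from onnxruntime import InferenceSession

# (provider name, provider options) tuples are passed straight through to ORT,
# with a plain CPU fallback after them; the path below is a placeholder
providers = [("CUDAExecutionProvider", {"device_id": 1}), "CPUExecutionProvider"]
session = InferenceSession(
    "../models/stable-diffusion-onnx-v1-5/unet/model.onnx",
    providers=providers,
)
print(session.get_providers())
```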
"input_negative_prompt": self.input_negative_prompt, "loopback": self.loopback, "tiled_vae": self.tiled_vae, - "tiles": self.tiles, - "overlap": self.overlap, - "stride": self.stride, + "unet_overlap": self.unet_overlap, + "unet_tile": self.unet_tile, + "vae_overlap": self.vae_overlap, + "vae_tile": self.vae_tile, + "denoise": self.denoise, } def with_args(self, **kwargs): @@ -327,9 +344,11 @@ class ImageParams: kwargs.get("input_negative_prompt", self.input_negative_prompt), kwargs.get("loopback", self.loopback), kwargs.get("tiled_vae", self.tiled_vae), - kwargs.get("tiles", self.tiles), - kwargs.get("overlap", self.overlap), - kwargs.get("stride", self.stride), + kwargs.get("unet_overlap", self.unet_overlap), + kwargs.get("unet_tile", self.unet_tile), + kwargs.get("vae_overlap", self.vae_overlap), + kwargs.get("vae_tile", self.vae_tile), + kwargs.get("denoise", self.denoise), ) @@ -351,6 +370,17 @@ class StageParams: self.tile_order = tile_order self.tile_size = tile_size + def with_args( + self, + **kwargs, + ): + return StageParams( + name=kwargs.get("name", self.name), + outscale=kwargs.get("outscale", self.outscale), + tile_order=kwargs.get("tile_order", self.tile_order), + tile_size=kwargs.get("tile_size", self.tile_size), + ) + class UpscaleParams: def __init__( @@ -459,10 +489,14 @@ class HighresParams: self.method = method self.iterations = iterations + def outscale(self) -> int: + return self.scale**self.iterations + def resize(self, size: Size) -> Size: + outscale = self.outscale() return Size( - size.width * (self.scale**self.iterations), - size.height * (self.scale**self.iterations), + size.width * outscale, + size.height * outscale, ) def tojson(self): diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 7035f680..a9162f0f 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -1,12 +1,14 @@ from io import BytesIO from logging import getLogger from os import path +from typing import Any, Dict from flask import Flask, jsonify, make_response, request, url_for from jsonschema import validate from PIL import Image from ..chain import CHAIN_STAGES, ChainPipeline +from ..chain.result import StageResult from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers from ..diffusers.run import ( run_blend_pipeline, @@ -17,7 +19,7 @@ from ..diffusers.run import ( ) from ..diffusers.utils import replace_wildcards from ..output import json_params, make_output_name -from ..params import Border, Size, StageParams, TileOrder, UpscaleParams +from ..params import Size, StageParams, TileOrder from ..transformers.run import run_txt2txt_pipeline from ..utils import ( base_join, @@ -49,10 +51,11 @@ from .load import ( get_wildcard_data, ) from .params import ( - border_from_request, - highres_from_request, + build_border, + build_highres, + build_upscale, + pipeline_from_json, pipeline_from_request, - upscale_from_request, ) from .utils import wrap_route @@ -167,8 +170,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): size = Size(source.width, source.height) device, params, _size = pipeline_from_request(server, "img2img") - upscale = upscale_from_request() - highres = highres_from_request() + upscale = build_upscale() + highres = build_highres() source_filter = get_from_list( request.args, "sourceFilter", list(get_source_filters().keys()) ) @@ -216,12 +219,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): def txt2img(server: ServerContext, pool: DevicePoolExecutor): device, params, size = 
pipeline_from_request(server, "txt2img") - upscale = upscale_from_request() - highres = highres_from_request() + upscale = build_upscale() + highres = build_highres() replace_wildcards(params, get_wildcard_data()) - output = make_output_name(server, "txt2img", params, size) + output = make_output_name(server, "txt2img", params, size, count=params.batch) job_name = output[0] pool.submit( @@ -250,7 +253,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): if mask_file is None: return error_reply("mask image is required") - source = Image.open(BytesIO(source_file.read())).convert("RGB") + source = Image.open(BytesIO(source_file.read())).convert("RGBA") size = Size(source.width, source.height) mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA") @@ -270,9 +273,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): ) device, params, _size = pipeline_from_request(server, "inpaint") - expand = border_from_request() - upscale = upscale_from_request() - highres = highres_from_request() + expand = build_border() + upscale = build_upscale() + highres = build_highres() fill_color = get_not_empty(request.args, "fillColor", "white") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") @@ -340,8 +343,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") device, params, size = pipeline_from_request(server) - upscale = upscale_from_request() - highres = highres_from_request() + upscale = build_upscale() + highres = build_highres() replace_wildcards(params, get_wildcard_data()) @@ -366,47 +369,70 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) +# keys that are specially parsed by params and should not show up in with_args +CHAIN_POP_KEYS = ["model", "control"] + + def chain(server: ServerContext, pool: DevicePoolExecutor): - logger.debug( - "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() - ) - body = request.form.get("chain") or request.files.get("chain") - if body is None: - return error_reply("chain pipeline must have a body") + if request.is_json: + logger.debug("chain pipeline request with JSON body") + data = request.get_json() + else: + logger.debug( + "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() + ) + + body = request.form.get("chain") or request.files.get("chain") + if body is None: + return error_reply("chain pipeline must have a body") + + data = load_config_str(body) - data = load_config_str(body) schema = load_config("./schemas/chain.yaml") - logger.debug("validating chain request: %s against %s", data, schema) validate(data, schema) - # get defaults from the regular parameters - device, params, size = pipeline_from_request(server) - output = make_output_name(server, "chain", params, size) - job_name = output[0] - - replace_wildcards(params, get_wildcard_data()) + device, base_params, base_size = pipeline_from_json( + server, data=data.get("defaults") + ) + # start building the pipeline pipeline = ChainPipeline() for stage_data in data.get("stages", []): stage_class = CHAIN_STAGES[stage_data.get("type")] - kwargs = stage_data.get("params", {}) + kwargs: Dict[str, Any] = stage_data.get("params", {}) logger.info("request stage: %s, %s", stage_class.__name__, kwargs) + # TODO: combine base params with stage params + _device, params, size = pipeline_from_json(server, data=kwargs) + 
replace_wildcards(params, get_wildcard_data()) + + # remove parsed keys, like model names (which become paths) + for pop_key in CHAIN_POP_KEYS: + if pop_key in kwargs: + kwargs.pop(pop_key) + + if "seed" in kwargs and kwargs["seed"] == -1: + kwargs.pop("seed") + + # replace kwargs with parsed versions + kwargs["params"] = params + kwargs["size"] = size + + border = build_border(kwargs) + kwargs["border"] = border + + upscale = build_upscale(kwargs) + kwargs["upscale"] = upscale + + # prepare the stage metadata stage = StageParams( stage_data.get("name", stage_class.__name__), - tile_size=get_size(kwargs.get("tile_size")), + tile_size=get_size(kwargs.get("tiles")), outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), ) - if "border" in kwargs: - border = Border.even(int(kwargs.get("border"))) - kwargs["border"] = border - - if "upscale" in kwargs: - upscale = UpscaleParams(kwargs.get("upscale")) - kwargs["upscale"] = upscale - + # load any images related to this stage stage_source_name = "source:%s" % (stage.name) stage_mask_name = "mask:%s" % (stage.name) @@ -436,20 +462,25 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.info("running chain pipeline with %s stages", len(pipeline.stages)) + output = make_output_name( + server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0) + ) + job_name = output[0] + # build and run chain pipeline - empty_source = Image.new("RGB", (size.width, size.height)) pool.submit( job_name, pipeline, server, - params, - empty_source, - output=output[0], - size=size, + base_params, + StageResult.empty(), + output=output, + size=base_size, needs_device=device, ) - return jsonify(json_params(output, params, size)) + step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size)) + return jsonify(json_params(output, step_params, base_size)) def blend(server: ServerContext, pool: DevicePoolExecutor): @@ -471,7 +502,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): sources.append(source) device, params, size = pipeline_from_request(server) - upscale = upscale_from_request() + upscale = build_upscale() output = make_output_name(server, "upscale", params, size) job_name = output[0] diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 8af1eb2a..034fc3c6 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -5,18 +5,44 @@ from typing import List, Optional import torch -from ..utils import get_boolean +from ..utils import get_boolean, get_list from .model_cache import ModelCache logger = getLogger(__name__) +DEFAULT_ANY_PLATFORM = True DEFAULT_CACHE_LIMIT = 5 DEFAULT_JOB_LIMIT = 10 DEFAULT_IMAGE_FORMAT = "png" DEFAULT_SERVER_VERSION = "v0.10.0" +DEFAULT_SHOW_PROGRESS = True +DEFAULT_WORKER_RETRIES = 3 class ServerContext: + bundle_path: str + model_path: str + output_path: str + params_path: str + cors_origin: str + any_platform: bool + block_platforms: List[str] + default_platform: str + image_format: str + cache_limit: int + cache_path: str + show_progress: bool + optimizations: List[str] + extra_models: List[str] + job_limit: int + memory_limit: int + admin_token: str + server_version: str + worker_retries: int + feature_flags: List[str] + plugins: List[str] + debug: bool + def __init__( self, bundle_path: str = ".", @@ -24,19 +50,23 @@ class ServerContext: output_path: str = ".", params_path: str = ".", cors_origin: str = "*", - any_platform: bool = True, + any_platform: bool = DEFAULT_ANY_PLATFORM, block_platforms: Optional[List[str]] 
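With the handler reworked as above, the chain endpoint can be driven with an ordinary JSON body instead of a form-encoded chain file: the defaults feed pipeline_from_json for the base parameters and each stage carries its own params. A hedged sketch of such a request; the URL, stage names, and parameter values are placeholders, and the exact nesting is governed by the chain schema further down in this diff:

```python
import requests

chain = {
    "defaults": {
        "prompt": "a giant muffin",
        "scheduler": "ddim",
        "seed": 0,
    },
    "stages": [
        {
            "name": "generate",
            "type": "source-txt2img",
            "params": {"width": 512, "height": 512},
        },
        {
            "name": "upscale",
            "type": "upscale-resrgan",
            "params": {"outscale": 2},
        },
    ],
}

# request.is_json is true for this call, so the handler skips the form parsing path
resp = requests.post("http://127.0.0.1:5000/api/chain", json=chain)
print(resp.json())
```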
= None, default_platform: Optional[str] = None, image_format: str = DEFAULT_IMAGE_FORMAT, cache_limit: int = DEFAULT_CACHE_LIMIT, cache_path: Optional[str] = None, - show_progress: bool = True, + show_progress: bool = DEFAULT_SHOW_PROGRESS, optimizations: Optional[List[str]] = None, extra_models: Optional[List[str]] = None, job_limit: int = DEFAULT_JOB_LIMIT, memory_limit: Optional[int] = None, admin_token: Optional[str] = None, server_version: Optional[str] = DEFAULT_SERVER_VERSION, + worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES, + feature_flags: Optional[List[str]] = None, + plugins: Optional[List[str]] = None, + debug: bool = False, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -56,6 +86,10 @@ class ServerContext: self.memory_limit = memory_limit self.admin_token = admin_token or token_urlsafe() self.server_version = server_version + self.worker_retries = worker_retries + self.feature_flags = feature_flags or [] + self.plugins = plugins or [] + self.debug = debug self.cache = ModelCache(self.cache_limit) @@ -72,26 +106,41 @@ class ServerContext: model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - # others - cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), - any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), - block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), + cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"), + any_platform=get_boolean( + environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM + ), + block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), - image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), + image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), - show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), - optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), - extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","), + show_progress=get_boolean( + environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS + ), + optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"), + extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"), job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), memory_limit=memory_limit, admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), server_version=environ.get( "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION ), + worker_retries=int( + environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) + ), + feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"), + plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), + debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), ) + def has_feature(self, flag: str) -> bool: + return flag in self.feature_flags + + def has_optimization(self, opt: str) -> bool: + return opt in self.optimizations + def torch_dtype(self): - if "torch-fp16" in self.optimizations: + if self.has_optimization("torch-fp16"): return torch.float16 else: return torch.float32 diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index b59bb73a..052a5c42 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -134,25 +134,44 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str: def 
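Since most of the new options are read in from_environ, the debug and plugin features can be enabled entirely through the environment. A sketch of that configuration; the plugin module names and feature flag are placeholders:

```python
from os import environ

from onnx_web.server.context import ServerContext

# comma-separated variables are split by get_list; booleans accept 1/t/true/y/yes
environ["ONNX_WEB_PLUGINS"] = "my_plugin_pkg.module,another_plugin"
environ["ONNX_WEB_WORKER_RETRIES"] = "5"
environ["ONNX_WEB_FEATURE_FLAGS"] = "panorama-highres"
environ["ONNX_WEB_DEBUG"] = "true"

server = ServerContext.from_environ()
assert server.plugins == ["my_plugin_pkg.module", "another_plugin"]
assert server.worker_retries == 5
assert server.has_feature("panorama-highres")
assert server.debug is True
```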
apply_patch_basicsr(server: ServerContext): logger.debug("patching BasicSR module") - import basicsr.utils.download_util + try: + import basicsr.utils.download_util - basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl - basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server) + basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl + basicsr.utils.download_util.load_file_from_url = partial( + patch_cache_path, server + ) + except ImportError: + logger.info("unable to import basicsr utils for patching") + except AttributeError: + logger.warning("unable to patch basicsr utils") def apply_patch_codeformer(server: ServerContext): logger.debug("patching CodeFormer module") - import codeformer.facelib.utils.misc + try: + import codeformer.facelib.utils.misc - codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl - codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server) + codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl + codeformer.facelib.utils.misc.load_file_from_url = partial( + patch_cache_path, server + ) + except ImportError: + logger.info("unable to import codeformer utils for patching") + except AttributeError: + logger.warning("unable to patch codeformer utils") def apply_patch_facexlib(server: ServerContext): logger.debug("patching Facexlib module") - import facexlib.utils + try: + import facexlib.utils - facexlib.utils.load_file_from_url = partial(patch_cache_path, server) + facexlib.utils.load_file_from_url = partial(patch_cache_path, server) + except ImportError: + logger.info("unable to import facexlib for patching") + except AttributeError: + logger.warning("unable to patch facexlib utils") def apply_patches(server: ServerContext): diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index a3759ff6..6bf1de2d 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list) # Loaded from extra_models extra_hashes: Dict[str, str] = {} extra_strings: Dict[str, Any] = {} +extra_tokens: Dict[str, List[str]] = {} def get_config_params(): @@ -160,9 +161,10 @@ def load_extras(server: ServerContext): """ global extra_hashes global extra_strings + global extra_tokens - labels = {} - strings = {} + labels: Dict[str, str] = {} + strings: Dict[str, Any] = {} extra_schema = load_config("./schemas/extras.yaml") @@ -210,6 +212,14 @@ def load_extras(server: ServerContext): else: labels[model_name] = model["label"] + if "tokens" in model: + logger.debug( + "collecting tokens for model %s from %s", + model_name, + file, + ) + extra_tokens[model_name] = model["tokens"] + if "inversions" in model: for inversion in model["inversions"]: if "label" in inversion: @@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None: ) logger.debug("loaded Textual Inversion models from disk: %s", inversion_models) network_models.extend( - [NetworkModel(model, "inversion") for model in inversion_models] + [ + NetworkModel(model, "inversion", tokens=extra_tokens.get(model, [])) + for model in inversion_models + ] ) lora_models = list_model_globs( @@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None: base_path=path.join(server.model_path, "lora"), ) logger.debug("loaded LoRA models from disk: %s", lora_models) - network_models.extend([NetworkModel(model, "lora") for model in lora_models]) + network_models.extend( + [ + NetworkModel(model, "lora", 
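The extra_tokens map collected here is what load_models later attaches to each NetworkModel, so clients can surface the trigger tokens for a network. A hedged sketch of the corresponding extras entry, written as the equivalent Python dict and assuming the tokens are declared on the network definition as the updated schema later in this diff allows; the name, source, and token are placeholders:

```python
# equivalent of a YAML or JSON extras file entry; values are placeholders
extras = {
    "networks": [
        {
            "name": "cute-animals",
            "source": "https://example.com/cute-animals.safetensors",
            "type": "lora",
            "tokens": ["cute animal"],
        }
    ]
}
```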
tokens=extra_tokens.get(model, [])) + for model in lora_models + ] + ) def load_params(server: ServerContext) -> None: @@ -397,7 +415,7 @@ def load_platforms(server: ServerContext) -> None: ): if potential == "cuda" or potential == "rocm": for i in range(torch.cuda.device_count()): - options = { + options: Dict[str, Union[int, str]] = { "device_id": i, } diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 21da25f4..6525d4ae 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -51,7 +51,7 @@ class ModelCache: return for i in range(len(cache)): - t, k, v = cache[i] + t, k, _v = cache[i] if tag == t and key != k: logger.debug("updating model cache: %s %s", tag, key) cache[i] = (tag, key, value) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 2598e819..e641b953 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -1,10 +1,10 @@ from logging import getLogger -from typing import Tuple +from typing import Dict, Optional, Tuple -import numpy as np from flask import request from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers +from ..diffusers.utils import random_seed from ..params import ( Border, DeviceParams, @@ -34,143 +34,122 @@ from .utils import get_model_path logger = getLogger(__name__) -def pipeline_from_request( - server: ServerContext, - default_pipeline: str = "txt2img", -) -> Tuple[DeviceParams, ImageParams, Size]: - user = request.remote_addr - +def build_device( + _server: ServerContext, + data: Dict[str, str], +) -> Optional[DeviceParams]: # platform stuff device = None - device_name = request.args.get("platform") + device_name = data.get("platform") if device_name is not None and device_name != "any": for platform in get_available_platforms(): if platform.device == device_name: device = platform + return device + + +def build_params( + server: ServerContext, + default_pipeline: str, + data: Dict[str, str], +) -> ImageParams: # diffusion model - model = get_not_empty(request.args, "model", get_config_value("model")) + model = get_not_empty(data, "model", get_config_value("model")) model_path = get_model_path(server, model) control = None - control_name = request.args.get("control") + control_name = data.get("control") for network in get_network_models(): if network.name == control_name: control = network # pipeline stuff pipeline = get_from_list( - request.args, "pipeline", get_available_pipelines(), default_pipeline + data, "pipeline", get_available_pipelines(), default_pipeline ) - scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers()) + scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers()) if scheduler is None: scheduler = get_config_value("scheduler") # prompt does not come from config - prompt = request.args.get("prompt", "") - negative_prompt = request.args.get("negativePrompt", None) + prompt = data.get("prompt", "") + negative_prompt = data.get("negativePrompt", None) if negative_prompt is not None and negative_prompt.strip() == "": negative_prompt = None # image params batch = get_and_clamp_int( - request.args, + data, "batch", get_config_value("batch"), get_config_value("batch", "max"), get_config_value("batch", "min"), ) cfg = get_and_clamp_float( - request.args, + data, "cfg", get_config_value("cfg"), get_config_value("cfg", "max"), get_config_value("cfg", "min"), ) eta = get_and_clamp_float( - request.args, + data, "eta", get_config_value("eta"), 
get_config_value("eta", "max"), get_config_value("eta", "min"), ) loopback = get_and_clamp_int( - request.args, + data, "loopback", get_config_value("loopback"), get_config_value("loopback", "max"), get_config_value("loopback", "min"), ) steps = get_and_clamp_int( - request.args, + data, "steps", get_config_value("steps"), get_config_value("steps", "max"), get_config_value("steps", "min"), ) - height = get_and_clamp_int( - request.args, - "height", - get_config_value("height"), - get_config_value("height", "max"), - get_config_value("height", "min"), + tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae")) + unet_overlap = get_and_clamp_float( + data, + "unet_overlap", + get_config_value("unet_overlap"), + get_config_value("unet_overlap", "max"), + get_config_value("unet_overlap", "min"), ) - width = get_and_clamp_int( - request.args, - "width", - get_config_value("width"), - get_config_value("width", "max"), - get_config_value("width", "min"), + unet_tile = get_and_clamp_int( + data, + "unet_tile", + get_config_value("unet_tile"), + get_config_value("unet_tile", "max"), + get_config_value("unet_tile", "min"), ) - tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE")) - tiles = get_and_clamp_int( - request.args, - "tiles", - get_config_value("tiles"), - get_config_value("tiles", "max"), - get_config_value("tiles", "min"), + vae_overlap = get_and_clamp_float( + data, + "vae_overlap", + get_config_value("vae_overlap"), + get_config_value("vae_overlap", "max"), + get_config_value("vae_overlap", "min"), ) - overlap = get_and_clamp_float( - request.args, - "overlap", - get_config_value("overlap"), - get_config_value("overlap", "max"), - get_config_value("overlap", "min"), - ) - stride = get_and_clamp_int( - request.args, - "stride", - get_config_value("stride"), - get_config_value("stride", "max"), - get_config_value("stride", "min"), + vae_tile = get_and_clamp_int( + data, + "vae_tile", + get_config_value("vae_tile"), + get_config_value("vae_tile", "max"), + get_config_value("vae_tile", "min"), ) - if stride > tiles: - logger.info("limiting stride to tile size, %s > %s", stride, tiles) - stride = tiles - - seed = int(request.args.get("seed", -1)) + seed = int(data.get("seed", -1)) if seed == -1: - # this one can safely use np.random because it produces a single value - seed = np.random.randint(np.iinfo(np.int32).max) - - logger.info( - "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", - user, - steps, - scheduler, - model_path, - pipeline, - device or "any device", - width, - height, - cfg, - seed, - prompt, - ) + seed = random_seed() params = ImageParams( model_path, @@ -186,38 +165,65 @@ def pipeline_from_request( control=control, loopback=loopback, tiled_vae=tiled_vae, - tiles=tiles, - overlap=overlap, - stride=stride, + unet_overlap=unet_overlap, + unet_tile=unet_tile, + vae_overlap=vae_overlap, + vae_tile=vae_tile, ) - size = Size(width, height) - return (device, params, size) + + return params -def border_from_request() -> Border: +def build_size( + _server: ServerContext, + data: Dict[str, str], +) -> Size: + height = get_and_clamp_int( + data, + "height", + get_config_value("height"), + get_config_value("height", "max"), + get_config_value("height", "min"), + ) + width = get_and_clamp_int( + data, + "width", + get_config_value("width"), + get_config_value("width", "max"), + get_config_value("width", "min"), + ) + return Size(width, height) + + +def build_border( + data: Dict[str, str] = None, +) -> Border: + if data 
is None: + data = request.args + left = get_and_clamp_int( - request.args, + data, "left", get_config_value("left"), get_config_value("left", "max"), get_config_value("left", "min"), ) right = get_and_clamp_int( - request.args, + data, "right", get_config_value("right"), get_config_value("right", "max"), get_config_value("right", "min"), ) top = get_and_clamp_int( - request.args, + data, "top", get_config_value("top"), get_config_value("top", "max"), get_config_value("top", "min"), ) bottom = get_and_clamp_int( - request.args, + data, "bottom", get_config_value("bottom"), get_config_value("bottom", "max"), @@ -227,46 +233,51 @@ def border_from_request() -> Border: return Border(left, right, top, bottom) -def upscale_from_request() -> UpscaleParams: +def build_upscale( + data: Dict[str, str] = None, +) -> UpscaleParams: + if data is None: + data = request.args + denoise = get_and_clamp_float( - request.args, + data, "denoise", get_config_value("denoise"), get_config_value("denoise", "max"), get_config_value("denoise", "min"), ) scale = get_and_clamp_int( - request.args, + data, "scale", get_config_value("scale"), get_config_value("scale", "max"), get_config_value("scale", "min"), ) outscale = get_and_clamp_int( - request.args, + data, "outscale", get_config_value("outscale"), get_config_value("outscale", "max"), get_config_value("outscale", "min"), ) - upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) - correction = get_from_list(request.args, "correction", get_correction_models()) - faces = get_not_empty(request.args, "faces", "false") == "true" + upscaling = get_from_list(data, "upscaling", get_upscaling_models()) + correction = get_from_list(data, "correction", get_correction_models()) + faces = get_not_empty(data, "faces", "false") == "true" face_outscale = get_and_clamp_int( - request.args, + data, "faceOutscale", get_config_value("faceOutscale"), get_config_value("faceOutscale", "max"), get_config_value("faceOutscale", "min"), ) face_strength = get_and_clamp_float( - request.args, + data, "faceStrength", get_config_value("faceStrength"), get_config_value("faceStrength", "max"), get_config_value("faceStrength", "min"), ) - upscale_order = request.args.get("upscaleOrder", "correction-first") + upscale_order = data.get("upscaleOrder", "correction-first") return UpscaleParams( upscaling, @@ -282,37 +293,43 @@ def upscale_from_request() -> UpscaleParams: ) -def highres_from_request() -> HighresParams: - enabled = get_boolean(request.args, "highres", get_config_value("highres")) +def build_highres( + data: Dict[str, str] = None, +) -> HighresParams: + if data is None: + data = request.args + + enabled = get_boolean(data, "highres", get_config_value("highres")) iterations = get_and_clamp_int( - request.args, + data, "highresIterations", get_config_value("highresIterations"), get_config_value("highresIterations", "max"), get_config_value("highresIterations", "min"), ) - method = get_from_list(request.args, "highresMethod", get_highres_methods()) + method = get_from_list(data, "highresMethod", get_highres_methods()) scale = get_and_clamp_int( - request.args, + data, "highresScale", get_config_value("highresScale"), get_config_value("highresScale", "max"), get_config_value("highresScale", "min"), ) steps = get_and_clamp_int( - request.args, + data, "highresSteps", get_config_value("highresSteps"), get_config_value("highresSteps", "max"), get_config_value("highresSteps", "min"), ) strength = get_and_clamp_float( - request.args, + data, "highresStrength", 
get_config_value("highresStrength"), get_config_value("highresStrength", "max"), get_config_value("highresStrength", "min"), ) + return HighresParams( enabled, scale, @@ -321,3 +338,50 @@ def highres_from_request() -> HighresParams: method=method, iterations=iterations, ) + + +PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size] + + +def pipeline_from_json( + server: ServerContext, + data: Dict[str, str], + default_pipeline: str = "txt2img", +) -> PipelineParams: + """ + Like pipeline_from_request but expects a nested structure. + """ + + device = build_device(server, data.get("device", data)) + params = build_params(server, default_pipeline, data.get("params", data)) + size = build_size(server, data.get("params", data)) + + return (device, params, size) + + +def pipeline_from_request( + server: ServerContext, + default_pipeline: str = "txt2img", +) -> PipelineParams: + user = request.remote_addr + + device = build_device(server, request.args) + params = build_params(server, default_pipeline, request.args) + size = build_size(server, request.args) + + logger.info( + "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", + user, + params.steps, + params.scheduler, + params.model, + params.pipeline, + device or "any device", + size.width, + size.height, + params.cfg, + params.seed, + params.prompt, + ) + + return (device, params, size) diff --git a/api/onnx_web/server/plugin.py b/api/onnx_web/server/plugin.py new file mode 100644 index 00000000..022047df --- /dev/null +++ b/api/onnx_web/server/plugin.py @@ -0,0 +1,61 @@ +from importlib import import_module +from logging import getLogger +from typing import Any, Callable, Dict + +from onnx_web.chain.stages import add_stage +from onnx_web.diffusers.load import add_pipeline +from onnx_web.server.context import ServerContext + +logger = getLogger(__name__) + + +class PluginExports: + pipelines: Dict[str, Any] + stages: Dict[str, Any] + + def __init__(self, pipelines=None, stages=None) -> None: + self.pipelines = pipelines or {} + self.stages = stages or {} + + +PluginModule = Callable[[ServerContext], PluginExports] + + +def load_plugins(server: ServerContext) -> PluginExports: + combined_exports = PluginExports() + + for plugin in server.plugins: + logger.info("loading plugin module: %s", plugin) + try: + module: PluginModule = import_module(plugin) + exports = module(server) + + for name, pipeline in exports.pipelines.items(): + if name in combined_exports.pipelines: + logger.warning( + "multiple plugins exported a pipeline named %s", name + ) + else: + combined_exports.pipelines[name] = pipeline + + for name, stage in exports.stages.items(): + if name in combined_exports.stages: + logger.warning("multiple plugins exported a stage named %s", name) + else: + combined_exports.stages[name] = stage + except Exception: + logger.exception("error importing plugin") + + return combined_exports + + +def register_plugins(exports: PluginExports) -> bool: + success = True + + for name, pipeline in exports.pipelines.items(): + success = success and add_pipeline(name, pipeline) + + for name, stage in exports.stages.items(): + success = success and add_stage(name, stage) + + return success diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 465e5aa8..d047ec0d 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -18,6 +18,11 @@ logger = getLogger(__name__) SAFE_CHARS = "._-" +def split_list(val: str) -> List[str]: + parts = [part.strip() for part in val.split(",")] + return [part for part in parts if 
len(part) > 0] + + def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") return path.join(base, tail_path) @@ -28,7 +33,16 @@ def is_debug() -> bool: def get_boolean(args: Any, key: str, default_value: bool) -> bool: - return args.get(key, str(default_value)).lower() in ("1", "t", "true", "y", "yes") + val = args.get(key, str(default_value)) + + if isinstance(val, bool): + return val + + return val.lower() in ("1", "t", "true", "y", "yes") + + +def get_list(args: Any, key: str, default="") -> List[str]: + return split_list(args.get(key, default)) def get_and_clamp_float( @@ -61,13 +75,13 @@ def get_from_list( def get_from_map( - args: Any, key: str, values: Dict[str, TElem], default: TElem + args: Any, key: str, values: Dict[str, TElem], default_key: str ) -> TElem: - selected = args.get(key, default) + selected = args.get(key, default_key) if selected in values: return values[selected] else: - return values[default] + return values[default_key] def get_not_empty(args: Any, key: str, default: TElem) -> TElem: @@ -195,6 +209,8 @@ def load_config(file: str) -> Dict: return load_yaml(file) elif ext in [".json"]: return load_json(file) + else: + raise ValueError("unknown config file extension") def load_config_str(raw: str) -> Dict: diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 22851f77..2d6d0278 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -25,6 +25,7 @@ class WorkerContext: idle: "Value[bool]" timeout: float retries: int + initial_retries: int def __init__( self, @@ -36,6 +37,8 @@ class WorkerContext: progress: "Queue[ProgressCommand]", active_pid: "Value[int]", idle: "Value[bool]", + retries: int, + timeout: float, ): self.job = None self.name = name @@ -47,12 +50,13 @@ class WorkerContext: self.active_pid = active_pid self.last_progress = None self.idle = idle - self.timeout = 1.0 - self.retries = 3 # TODO: get from env + self.initial_retries = retries + self.retries = retries + self.timeout = timeout def start(self, job: str) -> None: self.job = job - self.retries = 3 + self.retries = self.initial_retries self.set_cancel(cancel=False) self.set_idle(idle=False) @@ -82,7 +86,7 @@ class WorkerContext: return 0 def get_progress_callback(self) -> ProgressCallback: - from ..chain.base import ChainProgress + from ..chain.pipeline import ChainProgress def on_progress(step: int, timestep: int, latents: Any): on_progress.step = step diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index c65ebf10..3b0d32a8 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -86,15 +86,15 @@ class DevicePoolExecutor: self.logs = Queue(self.max_pending_per_worker) self.rlock = Lock() - def start(self) -> None: + def start(self, *args) -> None: self.create_health_worker() self.create_logger_worker() self.create_progress_worker() for device in self.devices: - self.create_device_worker(device) + self.create_device_worker(device, *args) - def create_device_worker(self, device: DeviceParams) -> None: + def create_device_worker(self, device: DeviceParams, *args) -> None: name = device.device # always recreate queues @@ -124,15 +124,17 @@ class DevicePoolExecutor: pending=self.pending[name], active_pid=current, idle=self.worker_idle[name], + retries=self.server.worker_retries, + timeout=self.progress_interval, ) self.context[name] = context worker = Process( name=f"onnx-web worker: {name}", target=worker_main, - args=(context, self.server), + 
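A few quick illustrations of the helper changes in utils.py; the values are only examples:

```python
from onnx_web.utils import get_boolean, get_from_map, split_list

# split_list trims whitespace and drops empty entries
assert split_list(" a, b,,c ") == ["a", "b", "c"]

# get_boolean now tolerates values that are already booleans, not just strings
assert get_boolean({"debug": True}, "debug", False) is True
assert get_boolean({}, "debug", False) is False

# get_from_map now takes the *key* of the default entry rather than its value
filters = {"none": "keep", "gaussian": "blur"}
assert get_from_map({}, "filter", filters, "none") == "keep"
```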
args=(context, self.server, *args), + daemon=True, ) - worker.daemon = True self.workers[name] = worker logger.debug("starting worker for device %s", device) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 361c1150..55ebcaac 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -27,10 +27,14 @@ MEMORY_ERRORS = [ ] -def worker_main(worker: WorkerContext, server: ServerContext): - apply_patches(server) +def worker_main( + worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True +): setproctitle("onnx-web worker: %s" % (worker.device.device)) + if patch: + apply_patches(server) + logger.trace( "checking in from worker with providers: %s", get_available_providers() ) @@ -46,7 +50,7 @@ def worker_main(worker: WorkerContext, server: ServerContext): getpid(), worker.get_active(), ) - exit(EXIT_REPLACED) + return exit(EXIT_REPLACED) # wait briefly for the next job job = worker.pending.get(timeout=worker.timeout) @@ -69,15 +73,15 @@ def worker_main(worker: WorkerContext, server: ServerContext): except KeyboardInterrupt: logger.debug("worker got keyboard interrupt") worker.fail() - exit(EXIT_INTERRUPT) + return exit(EXIT_INTERRUPT) except RetryException: logger.exception("retry error in worker, exiting") worker.fail() - exit(EXIT_ERROR) + return exit(EXIT_ERROR) except ValueError: logger.exception("value error in worker, exiting") worker.fail() - exit(EXIT_ERROR) + return exit(EXIT_ERROR) except Exception as e: e_str = str(e) # restart the worker on memory errors @@ -85,7 +89,7 @@ def worker_main(worker: WorkerContext, server: ServerContext): if e_mem in e_str: logger.error("detected out-of-memory error, exiting: %s", e) worker.fail() - exit(EXIT_MEMORY) + return exit(EXIT_MEMORY) # carry on for other errors logger.exception( diff --git a/api/params.json b/api/params.json index c4a1ee32..d9cd6fa0 100644 --- a/api/params.json +++ b/api/params.json @@ -98,7 +98,7 @@ "highresSteps": { "default": 0, "min": 1, - "max": 200, + "max": 500, "step": 1 }, "highresStrength": { @@ -141,12 +141,6 @@ "max": 4, "step": 1 }, - "overlap": { - "default": 0.25, - "min": 0.0, - "max": 0.9, - "step": 0.01 - }, "pipeline": { "default": "", "keys": [ @@ -188,7 +182,7 @@ "steps": { "default": 25, "min": 1, - "max": 200, + "max": 300, "step": 1 }, "strength": { @@ -197,21 +191,9 @@ "max": 1, "step": 0.01 }, - "stride": { - "default": 128, - "min": 64, - "max": 512, - "step": 64 - }, - "tiledVAE": { + "tiled_vae": { "default": false }, - "tiles": { - "default": 512, - "min": 128, - "max": 2048, - "step": 128 - }, "tileOrder": { "default": "spiral", "keys": [ @@ -225,6 +207,18 @@ "max": 1024, "step": 8 }, + "unet_overlap": { + "default": 0.25, + "min": 0.0, + "max": 0.9, + "step": 0.01 + }, + "unet_tile": { + "default": 512, + "min": 128, + "max": 2048, + "step": 128 + }, "upscaleOrder": { "default": "correction-first", "keys": [ @@ -237,6 +231,18 @@ "default": "", "keys": [] }, + "vae_overlap": { + "default": 0.25, + "min": 0.0, + "max": 0.9, + "step": 0.01 + }, + "vae_tile": { + "default": 512, + "min": 256, + "max": 1024, + "step": 128 + }, "width": { "default": 512, "min": 128, diff --git a/api/pyproject.toml b/api/pyproject.toml index efed7be3..5d69e906 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -9,12 +9,14 @@ skip_glob = ["*/lpw.py"] [tool.mypy] # ignore_missing_imports = true exclude = [ - "onnx_web.diffusers.lpw_stable_diffusion_onnx" + "onnx_web.diffusers.pipelines.controlnet", + "onnx_web.diffusers.pipelines.lpw", + 
"onnx_web.diffusers.pipelines.pix2pix" ] [[tool.mypy.overrides]] module = [ -"arpeggio", + "arpeggio", "basicsr.archs.rrdbnet_arch", "basicsr.utils.download_util", "basicsr.utils", @@ -27,8 +29,10 @@ module = [ "compel", "controlnet_aux", "cv2", + "debugpy", "diffusers", "diffusers.configuration_utils", + "diffusers.image_processor", "diffusers.loaders", "diffusers.models.attention_processor", "diffusers.models.autoencoder_kl", @@ -41,9 +45,10 @@ module = [ "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion", "diffusers.pipelines.onnx_utils", "diffusers.pipelines.paint_by_example", + "diffusers.pipelines.pipeline_utils", "diffusers.pipelines.stable_diffusion", "diffusers.pipelines.stable_diffusion.convert_from_ckpt", - "diffusers.pipeline_utils", + "diffusers.pipelines.stable_diffusion_xl", "diffusers.schedulers", "diffusers.utils.logging", "facexlib.utils", @@ -56,11 +61,17 @@ module = [ "mediapipe", "onnxruntime", "onnxruntime.transformers.float16", + "optimum.exporters.onnx", + "optimum.onnxruntime", + "optimum.onnxruntime.modeling_diffusion", + "optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img", + "optimum.pipelines.diffusers.pipeline_utils", "piexif", "piexif.helper", "realesrgan", "realesrgan.archs.srvgg_arch", "safetensors", + "scipy", "timm.models.layers", "transformers", "win10toast" diff --git a/api/schemas/chain.yaml b/api/schemas/chain.yaml index 1b41c503..e24593f6 100644 --- a/api/schemas/chain.yaml +++ b/api/schemas/chain.yaml @@ -46,17 +46,31 @@ $defs: patternProperties: "^[-_A-Za-z]+$": oneOf: + - type: boolean - type: number - type: string + - type: "null" request_chain: type: array items: $ref: "#/$defs/request_stage" + request_defaults: + type: object + properties: + txt2img: + $ref: "#/$defs/image_params" + img2img: + $ref: "#/$defs/image_params" + type: object additionalProperties: False required: [stages] properties: + defaults: + $ref: "#/$defs/request_defaults" + platform: + type: string stages: $ref: "#/$defs/request_chain" diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index eea6dbf9..518023e0 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -10,34 +10,53 @@ $defs: - type: number - type: string - lora_network: + tensor_format: + type: string + enum: [bin, ckpt, onnx, pt, pth, safetensors] + + embedding_network: type: object required: [name, source] properties: - name: - type: string - source: - type: string + format: + $ref: "#/$defs/tensor_format" label: type: string - weight: - type: number - - textual_inversion_network: - type: object - required: [name, source] - properties: - name: - type: string - source: - type: string - format: + model: type: string enum: [concept, embeddings] - label: + name: + type: string + source: type: string token: type: string + type: + type: string + const: inversion # TODO: add embedding + weight: + type: number + + lora_network: + type: object + required: [name, source, type] + properties: + label: + type: string + model: + type: string + enum: [cloneofsimo, sd-scripts] + name: + type: string + source: + type: string + tokens: + type: array + items: + type: string + type: + type: string + const: lora weight: type: number @@ -46,8 +65,7 @@ $defs: required: [name, source] properties: format: - type: string - enum: [bin, ckpt, onnx, pt, pth, safetensors] + $ref: "#/$defs/tensor_format" half: type: boolean label: @@ -85,7 +103,7 @@ $defs: inversions: type: array items: - $ref: "#/$defs/textual_inversion_network" + $ref: "#/$defs/embedding_network" loras: type: array 
items: @@ -100,6 +118,7 @@ $defs: panorama, pix2pix, txt2img, + txt2img-sdxl, upscale, ] vae: @@ -141,31 +160,6 @@ $defs: source: type: string - source_network: - type: object - required: [name, source, type] - properties: - format: - type: string - enum: [bin, ckpt, onnx, pt, pth, safetensors] - model: - type: string - enum: [ - # inversion - concept, - embeddings, - # lora - cloneofsimo, - sd-scripts - ] - name: - type: string - source: - type: string - type: - type: string - enum: [inversion, lora] - translation: type: object additionalProperties: False @@ -193,7 +187,9 @@ properties: networks: type: array items: - $ref: "#/$defs/source_network" + oneOf: + - $ref: "#/$defs/lora_network" + - $ref: "#/$defs/embedding_network" sources: type: array items: diff --git a/api/scripts/onnx-lora.py b/api/scripts/onnx-lora.py new file mode 100644 index 00000000..14e72d14 --- /dev/null +++ b/api/scripts/onnx-lora.py @@ -0,0 +1,74 @@ +from argparse import ArgumentParser +from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors +from os import path +from onnx.checker import check_model +from onnx.external_data_helper import ( + convert_model_to_external_data, + write_external_data_tensors, +) +from onnxruntime import InferenceSession, SessionOptions +from logging import getLogger + +from onnx_web.convert.utils import ConversionContext + +logger = getLogger(__name__) + + +if __name__ == "__main__": + context = ConversionContext.from_environ() + parser = ArgumentParser() + parser.add_argument("--base", type=str) + parser.add_argument("--dest", type=str) + parser.add_argument("--type", type=str, choices=["text_encoder", "unet"]) + parser.add_argument("--lora_models", nargs="+", type=str, default=[]) + parser.add_argument("--lora_weights", nargs="+", type=float, default=[]) + + args = parser.parse_args() + logger.info( + "merging %s with %s with weights: %s", + args.lora_models, + args.base, + args.lora_weights, + ) + + default_weight = 1.0 / len(args.lora_models) + while len(args.lora_weights) < len(args.lora_models): + args.lora_weights.append(default_weight) + + blend_model = blend_loras( + context, + args.base, + list(zip(args.lora_models, args.lora_weights)), + args.type, + ) + if args.dest is None or args.dest == "" or args.dest == ":load": + # convert to external data and save to memory + (bare_model, external_data) = buffer_external_data_tensors(blend_model) + logger.info("saved external data for %s nodes", len(external_data)) + + external_names, external_values = zip(*external_data) + opts = SessionOptions() + opts.add_external_initializers(list(external_names), list(external_values)) + sess = InferenceSession( + bare_model.SerializeToString(), + sess_options=opts, + providers=["CPUExecutionProvider"], + ) + logger.info( + "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()] + ) + else: + convert_model_to_external_data( + blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb" + ) + bare_model = write_external_data_tensors(blend_model, args.dest) + dest_file = path.join(args.dest, f"lora-{args.type}.onnx") + + with open(dest_file, "w+b") as model_file: + model_file.write(bare_model.SerializeToString()) + + logger.info("successfully saved blended model: %s", dest_file) + + check_model(dest_file) + + logger.info("checked blended model") diff --git a/api/scripts/test-refs/blend-512-muffin-white-0.png b/api/scripts/test-refs/blend-512-muffin-white-0.png index 103f582c..f218fb5a 100644 --- 
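The new onnx-lora.py script is a thin wrapper around blend_loras, so the same blend can be run directly when experimenting. A hedged sketch; the model paths, LoRA file, and weight are placeholders that assume the usual converted-model layout:

```python
from onnx_web.convert.diffusion.lora import blend_loras
from onnx_web.convert.utils import ConversionContext

context = ConversionContext.from_environ()

# blend a single LoRA into the base UNet at 75% weight; paths are placeholders
blended = blend_loras(
    context,
    "../models/stable-diffusion-onnx-v1-5/unet/model.onnx",
    [("../models/lora/cute-animals.safetensors", 0.75)],
    "unet",
)
```

Leaving --dest empty (or passing :load) makes the script load the blended model through an InferenceSession with external initializers instead of writing it to disk, which is a convenient smoke test before saving anything.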
a/api/scripts/test-refs/blend-512-muffin-white-0.png +++ b/api/scripts/test-refs/blend-512-muffin-white-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:be25e2a6252de2cd830c6421f75313067bf3ef29904138536615899c7169ac57 -size 573669 +oid sha256:14833dae2dafa6eb2fe9184087a4ba0781b2be881e2fdedb5dca09baf1843799 +size 572108 diff --git a/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png b/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png index 6a1fbb30..4bfe6bd7 100644 --- a/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png +++ b/api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d1a04ffe1ac885d30c782dd4413a5ed0b14632ab197c5fae5ad86b6f978b39e7 -size 1440395 +oid sha256:613ce059320abadb89f4adf00546d45a20d504ea508106499ceca78df389515f +size 1469930 diff --git a/api/scripts/test-refs/outpaint-even-256-0.png b/api/scripts/test-refs/outpaint-even-256-0.png index fdc925e8..c69f41b5 100644 --- a/api/scripts/test-refs/outpaint-even-256-0.png +++ b/api/scripts/test-refs/outpaint-even-256-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6d7d666f57ef942b5b7706b93df8831c06a6872287b44e012bb5a18959d38a63 -size 2018986 +oid sha256:60f019055396a71d387f9c575acac1802b639cdf8ae95ea0287278fa68f051ac +size 1968972 diff --git a/api/scripts/test-refs/outpaint-horizontal-512-0.png b/api/scripts/test-refs/outpaint-horizontal-512-0.png index e6669885..311c20fa 100644 --- a/api/scripts/test-refs/outpaint-horizontal-512-0.png +++ b/api/scripts/test-refs/outpaint-horizontal-512-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0679d0198ec0f321a83c214a1fa178c64aa1f9139d2ba560dac76b653d94fc01 -size 1580841 +oid sha256:b1d1b3417ce63553a15c80e601cbccde7b2ef4af05eef624cfc0d82d5e6a8b35 +size 1579738 diff --git a/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png b/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png index 56761cfc..14d22eee 100644 --- a/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png +++ b/api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:151c811488c933d9b858b70157da49ef626e2ce99207120fddc3791238e5a065 -size 1949395 +oid sha256:f6c4cd00f206bc3127c888dbb026edcef08fcc86be0487bf73abb77fd64bc419 +size 1680355 diff --git a/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png b/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png index f467648a..fdf38af0 100644 --- a/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png +++ b/api/scripts/test-refs/outpaint-panorama-vertical-512-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2311f9f2c273065a854636113f43fbbae642792e7c7f265c8bb3cbda0e308182 -size 1894231 +oid sha256:2c5486aebb193a2cfc155553934b33ad8fe224ef7fbf04b56dbefdea3ac14a30 +size 1584202 diff --git a/api/scripts/test-refs/outpaint-vertical-512-0.png b/api/scripts/test-refs/outpaint-vertical-512-0.png index 2933eb8e..d5f52320 100644 --- a/api/scripts/test-refs/outpaint-vertical-512-0.png +++ b/api/scripts/test-refs/outpaint-vertical-512-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8caa4f99dc1520a5d30d563090aa6cc02db89c0abb3eec7f8c3f8750c9dbf20f -size 1514972 +oid sha256:e59324850e153e768c7f849026fb0753440906a54223d3583f0b18caec6c8ff0 +size 1523624 diff --git a/api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png 
b/api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png index 9e68e223..b52fbe68 100644 --- a/api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png +++ b/api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ef46449cb985601a9fd1facdbb2cdef3e6f146fd3ad7bf3745540b1a571be2dc -size 1317089 +oid sha256:67efb0d6d889c6fddf590876415a7b1cdfd488a3d63082a8c022c42fa1c022ed +size 1386630 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png index 05c93b7d..7919d7d7 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d4f5506f9f3836dc3383f695c533b1ff2873f60c76b65a37f0850e2231a29495 -size 522478 +oid sha256:93def1c9b1355ed33d25916df16f037a11ba4ba3ee6bcd487d58818371fc7ad5 +size 526093 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png index dca24a75..00408050 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ce2c5e5510c93d5d4e2e7688cd132e1c6e167e7715173e1250e86f341ac47172 -size 503943 +oid sha256:bb9043d076c9084fb3f74fa5f58ec734c781ef403ae9723be261dff27e191f5e +size 494092 diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png b/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png index 7b1a789a..59c02b8a 100644 --- a/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png +++ b/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5dee66e2a5b003c6e892e2f5e7b7ba95c471ee301cdad34e11199083544a46d8 -size 502780 +oid sha256:582669cbcc215d32728c18166a46312c583462a7f438ff69108e4854b6c5edf7 +size 501499 diff --git a/api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png b/api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png index 0b713be1..7e52eeb7 100644 --- a/api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png +++ b/api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1a982c32e952f9fe1d9f0ba1b93f177f04f11b23603736ddc157ece2d45d915e -size 7313660 +oid sha256:30c2845fc375cd965a574c88f1da9e9c6b1d679514e5f3c4578a5d962726d4dd +size 6498264 diff --git a/api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png b/api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png index c37edc5e..a6f36771 100644 --- a/api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png +++ b/api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a7c381c34884e1c1304fd7f7a7b631c6ab4ba3694ececd99c3b5f43cba9ff5b6 -size 7313667 +oid sha256:9153a0b9b12a0c9c7309b9023477d74fc824e3820cb19aaadaf99cb80be72ace +size 6711175 diff --git a/api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png b/api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png index cda047d7..a6f36771 100644 --- a/api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png +++ b/api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:f3f34f029ca3b7324b2f7ff851e5af86454cd02c0b7445676655bfd486b8c2c0 -size 7313651 +oid sha256:9153a0b9b12a0c9c7309b9023477d74fc824e3820cb19aaadaf99cb80be72ace +size 6711175 diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index 2e6d7fd2..42806b55 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -30,6 +30,10 @@ FAST_TEST = 10 SLOW_TEST = 25 VERY_SLOW_TEST = 75 +STRICT_TEST = 1e-4 +LOOSE_TEST = 1e-2 +VERY_LOOSE_TEST = 0.025 + def test_path(relpath: str) -> str: return path.join(path.dirname(__file__), relpath) @@ -41,7 +45,7 @@ class TestCase: name: str, query: str, max_attempts: int = FAST_TEST, - mse_threshold: float = 1e-4, + mse_threshold: float = STRICT_TEST, source: Union[Image.Image, List[Image.Image]] = None, mask: Image.Image = None, ) -> None: @@ -65,6 +69,7 @@ TEST_DATA = [ TestCase( "txt2img-sd-v1-5-512-muffin-deis", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis", + mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v1-5-512-muffin-dpm", @@ -73,10 +78,12 @@ TEST_DATA = [ TestCase( "txt2img-sd-v1-5-512-muffin-heun", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun", + mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v1-5-512-muffin-unipc", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi", + mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v2-1-512-muffin", @@ -84,7 +91,7 @@ TEST_DATA = [ ), TestCase( "txt2img-sd-v2-1-768-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768", max_attempts=SLOW_TEST, ), TestCase( @@ -106,7 +113,7 @@ TEST_DATA = [ ), TestCase( "img2img-sd-v1-5-256-pumpkin", - "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none", + "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256", source="txt2img-sd-v1-5-256-muffin-0", ), TestCase( @@ -130,7 +137,7 @@ TEST_DATA = [ source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", max_attempts=SLOW_TEST, - mse_threshold=0.025, + mse_threshold=VERY_LOOSE_TEST, ), TestCase( "outpaint-vertical-512", @@ -141,7 +148,7 @@ TEST_DATA = [ source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", max_attempts=SLOW_TEST, - mse_threshold=0.010, + mse_threshold=LOOSE_TEST, ), TestCase( "outpaint-horizontal-512", @@ -152,7 +159,7 @@ TEST_DATA = [ source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", max_attempts=SLOW_TEST, - mse_threshold=0.010, + mse_threshold=LOOSE_TEST, ), TestCase( "upscale-resrgan-x2-1024-muffin", @@ -229,7 +236,7 @@ TEST_DATA = [ source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", max_attempts=VERY_SLOW_TEST, - mse_threshold=0.025, + mse_threshold=VERY_LOOSE_TEST, ), TestCase( "outpaint-panorama-vertical-512", @@ -240,7 +247,7 @@ TEST_DATA = [ source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", max_attempts=VERY_SLOW_TEST, - mse_threshold=0.025, + mse_threshold=VERY_LOOSE_TEST, ), TestCase( "outpaint-panorama-horizontal-512", @@ -251,7 +258,7 @@ TEST_DATA = [ source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", max_attempts=VERY_SLOW_TEST, - mse_threshold=0.025, + mse_threshold=VERY_LOOSE_TEST, ), TestCase( "upscale-resrgan-x4-codeformer-2048-muffin", @@ -260,6 +267,7 @@ TEST_DATA = [ "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0" ), source="txt2img-sd-v1-5-512-muffin-0", + max_attempts=SLOW_TEST, ), TestCase( 
"upscale-resrgan-x4-gfpgan-2048-muffin", @@ -268,6 +276,7 @@ TEST_DATA = [ "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0" ), source="txt2img-sd-v1-5-512-muffin-0", + max_attempts=SLOW_TEST, ), TestCase( "upscale-swinir-x4-codeformer-2048-muffin", @@ -276,6 +285,7 @@ TEST_DATA = [ "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0" ), source="txt2img-sd-v1-5-512-muffin-0", + max_attempts=SLOW_TEST, ), TestCase( "upscale-swinir-x4-gfpgan-2048-muffin", @@ -284,6 +294,7 @@ TEST_DATA = [ "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0" ), source="txt2img-sd-v1-5-512-muffin-0", + max_attempts=SLOW_TEST, ), TestCase( "upscale-sd-x4-codeformer-2048-muffin", @@ -305,18 +316,18 @@ TEST_DATA = [ ), TestCase( "txt2img-panorama-1024x768-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true", max_attempts=VERY_SLOW_TEST, ), TestCase( "img2img-panorama-1024x768-pumpkin", - "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true", + "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true", source="txt2img-panorama-1024x768-muffin-0", max_attempts=VERY_SLOW_TEST, ), TestCase( "txt2img-sd-v1-5-tall-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768", ), TestCase( "upscale-resrgan-x4-tall-muffin", @@ -325,6 +336,7 @@ TEST_DATA = [ "&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0" ), source="txt2img-sd-v1-5-tall-muffin-0", + max_attempts=SLOW_TEST, ), # TODO: non-square controlnet ] @@ -335,6 +347,39 @@ class TestError(Exception): return super().__str__() +class TestResult: + error: Optional[str] + mse: Optional[float] + name: str + passed: bool + + def __init__(self, name: str, error = None, passed = True, mse = None) -> None: + self.error = error + self.mse = mse + self.name = name + self.passed = passed + + def __repr__(self) -> str: + if self.passed: + if self.mse is not None: + return f"{self.name} ({self.mse})" + else: + return self.name + else: + if self.mse is not None: + return f"{self.name}: {self.error} ({self.mse})" + else: + return f"{self.name}: {self.error}" + + @classmethod + def passed(self, name: str, mse = None): + return TestResult(name, mse=mse) + + @classmethod + def failed(self, name: str, error: str, mse = None): + return TestResult(name, error=error, mse=mse, passed=False) + + def parse_args(args: List[str]): parser = ArgumentParser( prog="onnx-web release tests", @@ -441,14 +486,14 @@ def run_test( host: str, test: TestCase, mse_mult: float = 1.0, -) -> bool: +) -> TestResult: """ Generate an image, wait for it to be ready, and calculate the MSE from the reference. 
""" keys = generate_images(host, test) if keys is None: - raise ValueError("could not generate image") + return TestResult.failed(test.name, "could not generate image") ready = False for attempt in tqdm(range(test.max_attempts)): @@ -461,13 +506,13 @@ def run_test( sleep(6) if not ready: - raise ValueError("image was not ready in time") + return TestResult.failed(test.name, "image was not ready in time") results = download_images(host, keys) - if results is None: - raise ValueError("could not download image") + if results is None or len(results) == 0: + return TestResult.failed(test.name, "could not download image") - passed = True + passed = False for i in range(len(results)): result = results[i] result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) @@ -476,14 +521,19 @@ def run_test( ref = Image.open(ref_name) if path.exists(ref_name) else None mse = find_mse(result, ref) + threshold = test.mse_threshold * mse_mult - if mse < (test.mse_threshold * mse_mult): - logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold) + if mse < threshold: + logger.info("MSE within threshold: %.5f < %.5f", mse, threshold) + passed = True else: - logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold) - passed = False + logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold) + return TestResult.failed(test.name, error="MSE above threshold", mse=mse) - return passed + if passed: + return TestResult.passed(test.name) + else: + return TestResult.failed(test.name, "no images tested") def main(): @@ -504,24 +554,26 @@ def main(): passed = [] failed = [] for test in tests: - test_passed = False + result = None for _i in range(3): try: logger.info("starting test: %s", test.name) - if run_test(args.host, test, mse_mult=args.mse): + result = run_test(args.host, test, mse_mult=args.mse) + if result.passed: logger.info("test passed: %s", test.name) - test_passed = True break else: logger.warning("test failed: %s", test.name) except Exception: logger.exception("error running test for %s", test.name) + result = TestResult.failed(test.name, "TODO: exception message") - if test_passed: - passed.append(test.name) - else: - failed.append(test.name) + if result is not None: + if result.passed: + passed.append(result) + else: + failed.append(result) logger.info("%s of %s tests passed", len(passed), len(tests)) failed = list(set(failed)) diff --git a/api/tests/chain/__init__.py b/api/tests/chain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/chain/test_base.py b/api/tests/chain/test_base.py new file mode 100644 index 00000000..a0f5463b --- /dev/null +++ b/api/tests/chain/test_base.py @@ -0,0 +1,26 @@ +import unittest + +from onnx_web.chain.pipeline import ChainProgress + + +class ChainProgressTests(unittest.TestCase): + def test_accumulate_with_reset(self): + def parent(step, timestep, latents): + pass + + progress = ChainProgress(parent) + progress(5, 1, None) + progress(0, 1, None) + progress(5, 1, None) + + self.assertEqual(progress.get_total(), 10) + + def test_start_value(self): + def parent(step, timestep, latents): + pass + + progress = ChainProgress(parent, 5) + self.assertEqual(progress.get_total(), 5) + + progress(10, 1, None) + self.assertEqual(progress.get_total(), 10) diff --git a/api/tests/chain/test_blend_grid.py b/api/tests/chain/test_blend_grid.py new file mode 100644 index 00000000..0e6188b1 --- /dev/null +++ b/api/tests/chain/test_blend_grid.py @@ -0,0 +1,23 @@ +import unittest + +from PIL import Image + 
+from onnx_web.chain.blend_grid import BlendGridStage +from onnx_web.chain.result import StageResult + + +class BlendGridStageTests(unittest.TestCase): + def test_stage(self): + stage = BlendGridStage() + sources = StageResult( + images=[ + Image.new("RGB", (64, 64), "black"), + Image.new("RGB", (64, 64), "white"), + Image.new("RGB", (64, 64), "black"), + Image.new("RGB", (64, 64), "white"), + ] + ) + result = stage.run(None, None, None, None, sources, height=2, width=2) + + self.assertEqual(len(result), 5) + self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_img2img.py b/api/tests/chain/test_blend_img2img.py new file mode 100644 index 00000000..9d6f71d9 --- /dev/null +++ b/api/tests/chain/test_blend_img2img.py @@ -0,0 +1,47 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_img2img import BlendImg2ImgStage +from onnx_web.chain.result import StageResult +from onnx_web.params import ImageParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models + + +class BlendImg2ImgStageTests(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_stage(self): + stage = BlendImg2ImgStage() + params = ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "euler-a", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ) + server = ServerContext(model_path="../models", output_path="../outputs") + worker = WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 0, + 0.1, + ) + sources = StageResult( + images=[ + Image.new("RGB", (64, 64), "black"), + ] + ) + result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) + + self.assertEqual(len(result), 1) + self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_linear.py b/api/tests/chain/test_blend_linear.py new file mode 100644 index 00000000..76a2715a --- /dev/null +++ b/api/tests/chain/test_blend_linear.py @@ -0,0 +1,23 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_linear import BlendLinearStage +from onnx_web.chain.result import StageResult + + +class BlendLinearStageTests(unittest.TestCase): + def test_stage(self): + stage = BlendLinearStage() + sources = StageResult( + images=[ + Image.new("RGB", (64, 64), "black"), + ] + ) + stage_source = Image.new("RGB", (64, 64), "white") + result = stage.run( + None, None, None, None, sources, alpha=0.5, stage_source=stage_source + ) + + self.assertEqual(len(result), 1) + self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127)) diff --git a/api/tests/chain/test_blend_mask.py b/api/tests/chain/test_blend_mask.py new file mode 100644 index 00000000..4fcb8130 --- /dev/null +++ b/api/tests/chain/test_blend_mask.py @@ -0,0 +1,25 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_mask import BlendMaskStage +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, UpscaleParams + + +class BlendMaskStageTests(unittest.TestCase): + def test_empty(self): + stage = BlendMaskStage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + stage_mask=Image.new("RGBA", (64, 64)), + ) + + self.assertEqual(len(result), 0) diff --git 
a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py new file mode 100644 index 00000000..fa764554 --- /dev/null +++ b/api/tests/chain/test_correct_codeformer.py @@ -0,0 +1,46 @@ +import unittest + +from onnx_web.chain.correct_codeformer import CorrectCodeformerStage +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.server.hacks import apply_patches +from onnx_web.worker.context import WorkerContext +from tests.helpers import ( + TEST_MODEL_CORRECTION_CODEFORMER, + test_device, + test_needs_models, +) + + +class CorrectCodeformerStageTests(unittest.TestCase): + @test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER]) + def test_empty(self): + server = ServerContext(model_path="../models", output_path="../outputs") + apply_patches(server) + + worker = WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 0, + 0.1, + ) + stage = CorrectCodeformerStage() + sources = StageResult.empty() + result = stage.run( + worker, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_correct_gfpgan.py b/api/tests/chain/test_correct_gfpgan.py new file mode 100644 index 00000000..9f8b6cb3 --- /dev/null +++ b/api/tests/chain/test_correct_gfpgan.py @@ -0,0 +1,44 @@ +import unittest + +from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.server.hacks import apply_patches +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/correction-gfpgan-v1-3" + + +class CorrectGFPGANStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + server = ServerContext(model_path="../models", output_path="../outputs") + apply_patches(server) + + worker = WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 0, + 0.1, + ) + stage = CorrectGFPGANStage() + sources = StageResult.empty() + result = stage.run( + worker, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(TEST_MODEL), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_reduce_crop.py b/api/tests/chain/test_reduce_crop.py new file mode 100644 index 00000000..bfc7adc4 --- /dev/null +++ b/api/tests/chain/test_reduce_crop.py @@ -0,0 +1,24 @@ +import unittest + +from onnx_web.chain.reduce_crop import ReduceCropStage +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class ReduceCropStageTests(unittest.TestCase): + def test_empty(self): + stage = ReduceCropStage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_reduce_thumbnail.py b/api/tests/chain/test_reduce_thumbnail.py new file mode 100644 index 00000000..8b129672 --- /dev/null +++ b/api/tests/chain/test_reduce_thumbnail.py @@ -0,0 +1,28 @@ +import unittest + +from PIL import Image + +from 
onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class ReduceThumbnailStageTests(unittest.TestCase): + def test_empty(self): + stage_source = Image.new("RGB", (64, 64)) + stage = ReduceThumbnailStage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + stage_source=stage_source, + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_noise.py b/api/tests/chain/test_source_noise.py new file mode 100644 index 00000000..37c99bfa --- /dev/null +++ b/api/tests/chain/test_source_noise.py @@ -0,0 +1,26 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.source_noise import SourceNoiseStage +from onnx_web.image.noise_source import noise_source_fill_edge +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class SourceNoiseStageTests(unittest.TestCase): + def test_empty(self): + stage = SourceNoiseStage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + noise_source=noise_source_fill_edge, + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_s3.py b/api/tests/chain/test_source_s3.py new file mode 100644 index 00000000..59bbb72f --- /dev/null +++ b/api/tests/chain/test_source_s3.py @@ -0,0 +1,26 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.source_s3 import SourceS3Stage +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class SourceS3StageTests(unittest.TestCase): + def test_empty(self): + stage = SourceS3Stage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + bucket="test", + source_keys=[], + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_url.py b/api/tests/chain/test_source_url.py new file mode 100644 index 00000000..4d03dedb --- /dev/null +++ b/api/tests/chain/test_source_url.py @@ -0,0 +1,25 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.source_url import SourceURLStage +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class SourceURLStageTests(unittest.TestCase): + def test_empty(self): + stage = SourceURLStage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + source_urls=[], + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py new file mode 100644 index 00000000..c27cb077 --- /dev/null +++ b/api/tests/chain/test_tile.py @@ -0,0 +1,140 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.result import StageResult +from onnx_web.chain.tile import ( + complete_tile, + generate_tile_grid, + generate_tile_spiral, + make_tile_grads, + needs_tile, + process_tile_stack, +) +from onnx_web.params import Size + + +class TestCompleteTile(unittest.TestCase): + def test_with_complete_tile(self): + 
partial = Image.new("RGB", (64, 64)) + output = complete_tile(partial, 64) + + self.assertEqual(output.size, (64, 64)) + + def test_with_partial_tile(self): + partial = Image.new("RGB", (64, 32)) + output = complete_tile(partial, 64) + + self.assertEqual(output.size, (64, 64)) + + def test_with_nothing(self): + output = complete_tile(None, 64) + + self.assertIsNone(output) + + +class TestNeedsTile(unittest.TestCase): + def test_with_undersized_source(self): + small = Image.new("RGB", (32, 32)) + + self.assertFalse(needs_tile(64, 64, source=small)) + + def test_with_oversized_source(self): + large = Image.new("RGB", (64, 64)) + + self.assertTrue(needs_tile(32, 32, source=large)) + + def test_with_undersized_size(self): + small = Size(32, 32) + + self.assertFalse(needs_tile(64, 64, size=small)) + + def test_with_oversized_size(self): + large = Size(64, 64) + + self.assertTrue(needs_tile(32, 32, size=large)) + + def test_with_nothing(self): + self.assertFalse(needs_tile(32, 32)) + + +class TestTileGrads(unittest.TestCase): + def test_center_tile(self): + grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64) + + self.assertEqual(grad_x, [0, 1, 1, 0]) + self.assertEqual(grad_y, [0, 1, 1, 0]) + + def test_vertical_edge_tile(self): + grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8) + + self.assertEqual(grad_x, [0, 1, 1, 0]) + self.assertEqual(grad_y, [1, 1, 1, 1]) + + def test_horizontal_edge_tile(self): + grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64) + + self.assertEqual(grad_x, [1, 1, 1, 1]) + self.assertEqual(grad_y, [0, 1, 1, 0]) + + +class TestGenerateTileGrid(unittest.TestCase): + def test_grid_complete(self): + tiles = generate_tile_grid(16, 16, 8, 0.0) + + self.assertEqual(len(tiles), 4) + self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)]) + + def test_grid_no_overlap(self): + tiles = generate_tile_grid(64, 64, 8, 0.0) + + self.assertEqual(len(tiles), 64) + self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) + self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)]) + + def test_grid_50_overlap(self): + tiles = generate_tile_grid(64, 64, 8, 0.5) + + self.assertEqual(len(tiles), 256) + self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) + self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)]) + + +class TestGenerateTileSpiral(unittest.TestCase): + def test_spiral_complete(self): + tiles = generate_tile_spiral(16, 16, 8, 0.0) + + self.assertEqual(len(tiles), 4) + self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) + + def test_spiral_no_overlap(self): + tiles = generate_tile_spiral(64, 64, 8, 0.0) + + self.assertEqual(len(tiles), 64) + self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) + self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) + + def test_spiral_50_overlap(self): + tiles = generate_tile_spiral(64, 64, 8, 0.5) + + self.assertEqual(len(tiles), 225) + self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) + self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) + + +class TestProcessTileStack(unittest.TestCase): + def test_grid_full(self): + source = Image.new("RGB", (64, 64)) + blend = process_tile_stack( + StageResult(images=[source]), 32, 1, [], generate_tile_grid + ) + + self.assertEqual(blend[0].size, (64, 64)) + + def test_grid_partial(self): + source = Image.new("RGB", (72, 72)) + blend = process_tile_stack( + StageResult(images=[source]), 32, 1, [], generate_tile_grid + ) + + self.assertEqual(blend[0].size, (72, 72)) diff --git 
a/api/tests/chain/test_upscale_bsrgan.py b/api/tests/chain/test_upscale_bsrgan.py new file mode 100644 index 00000000..f93b800c --- /dev/null +++ b/api/tests/chain/test_upscale_bsrgan.py @@ -0,0 +1,41 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_bsrgan import UpscaleBSRGANStage +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/upscaling-bsrgan-x4" + + +class UpscaleBSRGANStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + stage = UpscaleBSRGANStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext( + model_path="../models", + ), + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(TEST_MODEL), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py new file mode 100644 index 00000000..096eea54 --- /dev/null +++ b/api/tests/chain/test_upscale_highres.py @@ -0,0 +1,22 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_highres import UpscaleHighresStage +from onnx_web.params import HighresParams, UpscaleParams + + +class UpscaleHighresStageTests(unittest.TestCase): + def test_empty(self): + stage = UpscaleHighresStage() + sources = StageResult.empty() + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_outpaint.py b/api/tests/chain/test_upscale_outpaint.py new file mode 100644 index 00000000..261a8d45 --- /dev/null +++ b/api/tests/chain/test_upscale_outpaint.py @@ -0,0 +1,50 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_outpaint import UpscaleOutpaintStage +from onnx_web.params import Border, HighresParams, ImageParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_models + + +class UpscaleOutpaintStageTests(unittest.TestCase): + @test_needs_models(["../models/stable-diffusion-onnx-v1-inpainting"]) + def test_empty(self): + stage = UpscaleOutpaintStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext(), + None, + ImageParams( + "../models/stable-diffusion-onnx-v1-inpainting", + "inpaint", + "euler", + "test", + 5.0, + 1, + 1, + ), + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams("stable-diffusion-onnx-v1-inpainting"), + border=Border.even(0), + dims=(), + tile_mask=Image.new("RGB", (64, 64)), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_resrgan.py b/api/tests/chain/test_upscale_resrgan.py new file mode 100644 index 00000000..f832767f --- /dev/null +++ b/api/tests/chain/test_upscale_resrgan.py @@ -0,0 +1,39 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_resrgan import UpscaleRealESRGANStage +from 
onnx_web.params import HighresParams, StageParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/upscaling-real-esrgan-x4-v3" + + +class UpscaleRealESRGANStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + stage = UpscaleRealESRGANStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext(model_path="../models"), + StageParams(), + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_swinir.py b/api/tests/chain/test_upscale_swinir.py new file mode 100644 index 00000000..dfa9676e --- /dev/null +++ b/api/tests/chain/test_upscale_swinir.py @@ -0,0 +1,39 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_swinir import UpscaleSwinIRStage +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/upscaling-swinir-real-large-x4" + + +class UpscaleSwinIRStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + stage = UpscaleSwinIRStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext(), + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(TEST_MODEL), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/convert/__init__.py b/api/tests/convert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/convert/diffusion/__init__.py b/api/tests/convert/diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py new file mode 100644 index 00000000..a39b0b8b --- /dev/null +++ b/api/tests/convert/diffusion/test_lora.py @@ -0,0 +1,357 @@ +import unittest + +import numpy as np +import torch +from onnx import GraphProto, ModelProto, NodeProto +from onnx.numpy_helper import from_array + +from onnx_web.convert.diffusion.lora import ( + blend_node_conv_gemm, + blend_node_matmul, + blend_weights_loha, + blend_weights_lora, + buffer_external_data_tensors, + fix_initializer_name, + fix_node_name, + fix_xl_names, + interp_to_match, + kernel_slice, + sum_weights, +) + + +class SumWeightsTests(unittest.TestCase): + def test_same_shape(self): + weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4))) + self.assertEqual(weights.shape, (4, 4)) + + def test_1x1_kernel(self): + weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4))) + self.assertEqual(weights.shape, (4, 4, 1, 1)) + + weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1))) + self.assertEqual(weights.shape, (4, 4, 1, 1)) + + def test_3x3_kernel(self): + """ + weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4))) + self.assertEqual(weights.shape, (4, 4, 1, 1)) + """ + pass + + +class BufferExternalDataTensorTests(unittest.TestCase): + def test_basic_external(self): + model = ModelProto( + 
graph=GraphProto( + initializer=[ + from_array(np.zeros((4, 4))), + ], + ) + ) + (slim_model, external_weights) = buffer_external_data_tensors(model) + + self.assertEqual( + len(slim_model.graph.initializer), len(model.graph.initializer) + ) + self.assertEqual(len(external_weights), 1) + + +class FixInitializerKeyTests(unittest.TestCase): + def test_fix_name(self): + inputs = [ + "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight" + ] + outputs = [ + "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight" + ] + + for input, output in zip(inputs, outputs): + self.assertEqual(fix_initializer_name(input), output) + + +class FixNodeNameTests(unittest.TestCase): + def test_fix_name(self): + inputs = [ + "lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight", + "_prefix", + ] + outputs = [ + "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight", + "prefix", + ] + + for input, output in zip(inputs, outputs): + self.assertEqual(fix_node_name(input), output) + + +class FixXLNameTests(unittest.TestCase): + def test_empty(self): + nodes = {} + fixed = fix_xl_names(nodes, []) + + self.assertEqual(fixed, {}) + + def test_input_block(self): + nodes = { + "input_block_proj.lora_down.weight": {}, + } + fixed = fix_xl_names( + nodes, + [ + NodeProto(name="/down_blocks_proj/MatMul"), + ], + ) + + self.assertEqual( + fixed, + { + "down_blocks_proj": nodes["input_block_proj.lora_down.weight"], + }, + ) + + def test_middle_block(self): + nodes = { + "middle_block_proj.lora_down.weight": {}, + } + fixed = fix_xl_names( + nodes, + [ + NodeProto(name="/mid_blocks_proj/MatMul"), + ], + ) + + self.assertEqual( + fixed, + { + "mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"], + }, + ) + + def test_output_block(self): + pass + + def test_text_model(self): + pass + + def test_unknown_block(self): + pass + + def test_unmatched_block(self): + nodes = { + "lora_unet.input_block.lora_down.weight": {}, + } + fixed = fix_xl_names(nodes, [NodeProto(name="test")]) + + self.assertEqual(fixed, nodes) + + def test_output_projection(self): + nodes = { + "output_block_proj_out.lora_down.weight": {}, + } + fixed = fix_xl_names( + nodes, + [ + NodeProto(name="/up_blocks_proj_out/MatMul"), + ], + ) + + self.assertEqual( + fixed, + { + "up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"], + }, + ) + + +class KernelSliceTests(unittest.TestCase): + def test_within_kernel(self): + self.assertEqual( + kernel_slice(1, 1, (3, 3, 3, 3)), + (1, 1), + ) + + def test_outside_kernel(self): + self.assertEqual( + kernel_slice(9, 9, (3, 3, 3, 3)), + (2, 2), + ) + + +class InterpToMatchTests(unittest.TestCase): + def test_same_shape(self): + ref = np.zeros((4, 4)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + def test_different_one_dim(self): + ref = np.zeros((4, 2)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + def test_different_both_dims(self): + ref = np.zeros((2, 2)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + +class BlendLoRATests(unittest.TestCase): + def test_blend_unet(self): + """ + blend_loras(None, "test", [], "unet") + """ + pass + + def test_blend_text_encoder(self): + """ + blend_loras(None, "test", [], "text_encoder") + """ + pass + + def test_blend_text_encoder_index(self): + """ + blend_loras(None, "test", [], 
"text_encoder", model_index=2) + """ + pass + + def test_unmatched_keys(self): + pass + + def test_xl_keys(self): + """ + blend_loras(None, "test", [], "unet", xl=True) + """ + pass + + def test_node_dtype(self): + pass + + +class BlendWeightsLoHATests(unittest.TestCase): + def test_blend_t1_t2(self): + # blend einsum: i j k l, j r, i p -> p r k l + i = 32 + j = 4 + k = 1 + l = 1 # NOQA + p = 2 + r = 4 + + model = { + "foo.hada_t1": torch.from_numpy(np.ones((i, j, k, l))), + "foo.hada_t2": torch.from_numpy(np.ones((i, j, k, l))), + "foo.hada_w1_a": torch.from_numpy(np.ones((i, p))), + "foo.hada_w1_b": torch.from_numpy(np.ones((j, r))), + "foo.hada_w2_a": torch.from_numpy(np.ones((i, p))), + "foo.hada_w2_b": torch.from_numpy(np.ones((j, r))), + "foo.alpha": torch.tensor(1), + } + key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32) + self.assertEqual(result.shape, (p, r, k, l)) + + def test_blend_w1_w2(self): + model = { + "foo.hada_w1_a": torch.from_numpy(np.ones((4, 1))), + "foo.hada_w1_b": torch.from_numpy(np.ones((1, 4))), + "foo.hada_w2_a": torch.from_numpy(np.ones((4, 1))), + "foo.hada_w2_b": torch.from_numpy(np.ones((1, 4))), + "foo.alpha": torch.tensor(1), + } + key, result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4)) + + def test_blend_no_dim(self): + """ + model = { + "foo.hada_w1_a": torch.from_numpy(np.ones((1, 4))), + "foo.hada_w1_b": torch.from_numpy(np.ones((4, 1))), + "foo.hada_w2_a": torch.from_numpy(np.ones((1, 4))), + "foo.hada_w2_b": torch.from_numpy(np.ones((4, 1))), + } + result = blend_weights_loha("foo.hada_w1_a", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4)) + """ + + +class BlendWeightsLoRATests(unittest.TestCase): + def test_blend_kernel_none(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((1, 4))), + "foo.lora_up": torch.from_numpy(np.ones((4, 1))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4)) + + def test_blend_kernel_1x1(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))), + "foo.lora_up": torch.from_numpy(np.ones((4, 1, 1, 1))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4, 1, 1)) + + def test_blend_kernel_3x3(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((1, 4, 3, 3))), + "foo.lora_up": torch.from_numpy(np.ones((4, 1, 3, 3))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4, 3, 3)) + + def test_blend_kernel_3x3_cp_decomp(self): + model = { + "foo.lora_down": torch.from_numpy(np.ones((2, 4, 1, 1))), + "foo.lora_mid": torch.from_numpy(np.ones((2, 2, 3, 3))), + "foo.lora_up": torch.from_numpy(np.ones((4, 2, 1, 1))), + "foo.alpha": 1, + } + key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) + self.assertEqual(result.shape, (4, 4, 3, 3)) + + def test_blend_unknown(self): + pass + + +class BlendNodeConvGemmTests(unittest.TestCase): + def test_blend_kernel_1x1_and_1x1(self): + node = from_array(np.ones((4, 4, 1, 1))) + result = blend_node_conv_gemm(node, np.ones((4, 4, 1, 1))) + + self.assertEqual(result.dims, [4, 4, 1, 1]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_kernel_1x1_and_none(self): + node = from_array(np.ones((4, 4, 1, 1))) + result = blend_node_conv_gemm(node, 
np.ones((4, 4))) + + self.assertEqual(result.dims, [4, 4, 1, 1]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_other_matching(self): + node = from_array(np.ones((4, 4))) + result = blend_node_conv_gemm(node, np.ones((4, 4))) + + self.assertEqual(result.dims, [4, 4]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_other_mismatched(self): + pass + + +class BlendNodeMatMulTests(unittest.TestCase): + def test_blend_matching(self): + node = from_array(np.ones((4, 4))) + result = blend_node_matmul(node, np.ones((4, 4)), "test") + + self.assertEqual(result.dims, [4, 4]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) + + def test_blend_mismatched(self): + node = from_array(np.ones((4, 4))) + result = blend_node_matmul(node, np.ones((2, 2)), "test") + + self.assertEqual(result.dims, [4, 4]) + self.assertEqual(len(result.raw_data), 4 * 4 * 8) diff --git a/api/tests/convert/diffusion/test_textual_inversion.py b/api/tests/convert/diffusion/test_textual_inversion.py new file mode 100644 index 00000000..287907b6 --- /dev/null +++ b/api/tests/convert/diffusion/test_textual_inversion.py @@ -0,0 +1,283 @@ +import unittest + +import numpy as np +import torch +from onnx import GraphProto, ModelProto +from onnx.numpy_helper import from_array, to_array + +from onnx_web.convert.diffusion.textual_inversion import ( + blend_embedding_concept, + blend_embedding_embeddings, + blend_embedding_node, + blend_embedding_parameters, + detect_embedding_format, +) + +TEST_DIMS = (8, 8) +TEST_DIMS_EMBEDS = (1, *TEST_DIMS) + +TEST_MODEL_EMBEDS = { + "string_to_token": { + "test": 1, + }, + "string_to_param": { + "test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, +} + + +class DetectEmbeddingFormatTests(unittest.TestCase): + def test_concept(self): + embedding = { + "": "test", + } + self.assertEqual(detect_embedding_format(embedding), "concept") + + def test_parameters(self): + embedding = { + "emb_params": "test", + } + self.assertEqual(detect_embedding_format(embedding), "parameters") + + def test_embeddings(self): + embedding = { + "string_to_token": "test", + "string_to_param": "test", + } + self.assertEqual(detect_embedding_format(embedding), "embeddings") + + def test_unknown(self): + embedding = { + "what_is_this": "test", + } + self.assertEqual(detect_embedding_format(embedding), None) + + +class BlendEmbeddingConceptTests(unittest.TestCase): + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) + + def test_missing_base_token(self): + embeds = {} + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + + def test_existing_token(self): + embeds = { + "": np.ones(TEST_DIMS), + } + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["", "test"]) + + def test_missing_token(self): + embeds = {} + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) + + keys = list(embeds.keys()) + 
keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["", "test"]) + + +class BlendEmbeddingParametersTests(unittest.TestCase): + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) + + def test_missing_base_token(self): + embeds = {} + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + + def test_existing_token(self): + embeds = { + "test": np.ones(TEST_DIMS_EMBEDS), + } + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + def test_missing_token(self): + embeds = {} + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + +class BlendEmbeddingEmbeddingsTests(unittest.TestCase): + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) + + def test_missing_base_token(self): + embeds = {} + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + + def test_existing_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + def test_missing_token(self): + embeds = {} + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + +class BlendEmbeddingNodeTests(unittest.TestCase): + def test_expand_weights(self): + weights = from_array(np.ones(TEST_DIMS)) + weights.name = "text_model.embeddings.token_embedding.weight" + + model = ModelProto( + graph=GraphProto( + initializer=[ + weights, + ] + ) + ) + + embeds = {} + blend_embedding_node( + model, + { + "convert_tokens_to_ids": lambda t: t, + }, + embeds, + 2, + ) + + result = to_array(model.graph.initializer[0]) + + self.assertEqual(len(model.graph.initializer), 1) + self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8) + + +class BlendTextualInversionsTests(unittest.TestCase): + def test_blend_multi_concept(self): + pass + + def test_blend_multi_parameters(self): + pass + + def test_blend_multi_embeddings(self): + pass + + def test_blend_multi_mixed(self): + pass diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py new file mode 100644 index 
00000000..4281adbc --- /dev/null +++ b/api/tests/convert/test_utils.py @@ -0,0 +1,241 @@ +import unittest + +from onnx_web.convert.utils import ( + DEFAULT_OPSET, + ConversionContext, + download_progress, + remove_prefix, + resolve_tensor, + source_format, + tuple_to_correction, + tuple_to_diffusion, + tuple_to_source, + tuple_to_upscaling, +) +from tests.helpers import TEST_MODEL_UPSCALING_SWINIR, test_needs_models + + +class ConversionContextTests(unittest.TestCase): + def test_from_environ(self): + context = ConversionContext.from_environ() + self.assertEqual(context.opset, DEFAULT_OPSET) + + def test_map_location(self): + context = ConversionContext.from_environ() + self.assertEqual(context.map_location.type, "cpu") + + +class DownloadProgressTests(unittest.TestCase): + def test_download_example(self): + path = download_progress([("https://example.com", "/tmp/example-dot-com")]) + self.assertEqual(path, "/tmp/example-dot-com") + + +class TupleToSourceTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_source(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_source(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_source(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned as-is with extra fields + second = tuple_to_source(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + +class TupleToCorrectionTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_correction(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_correction(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_correction(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned with extra fields + second = tuple_to_source(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + def test_scale_tuple(self): + source = tuple_to_correction(["foo", "bar", 2]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_half_tuple(self): + source = tuple_to_correction(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_opset_tuple(self): + source = tuple_to_correction(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_all_tuple(self): + source = tuple_to_correction(["foo", "bar", 2, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["scale"], 2) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) + + +class TupleToDiffusionTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_diffusion(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_diffusion(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_diffusion(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned with extra fields + second = 
tuple_to_diffusion(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + def test_single_vae_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_half_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_opset_tuple(self): + source = tuple_to_diffusion(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_all_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["single_vae"], True) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) + + +class TupleToUpscalingTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_upscaling(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_upscaling(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_upscaling(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned with extra fields + second = tuple_to_source(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + def test_scale_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 2]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_half_tuple(self): + source = tuple_to_upscaling(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_opset_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_all_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 2, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["scale"], 2) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) + + +class SourceFormatTests(unittest.TestCase): + def test_with_format(self): + result = source_format( + { + "format": "foo", + } + ) + self.assertEqual(result, "foo") + + def test_source_known_extension(self): + result = source_format( + { + "source": "foo.safetensors", + } + ) + self.assertEqual(result, "safetensors") + + def test_source_unknown_extension(self): + result = source_format({"source": "foo.none"}) + self.assertEqual(result, None) + + def test_incomplete_model(self): + self.assertIsNone(source_format({})) + + +class RemovePrefixTests(unittest.TestCase): + def test_with_prefix(self): + self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar") + + def test_without_prefix(self): + self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") + + +class LoadTorchTests(unittest.TestCase): + pass + + +class LoadTensorTests(unittest.TestCase): + pass + + +class ResolveTensorTests(unittest.TestCase): + @test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) + def test_resolve_existing(self): + self.assertEqual( + resolve_tensor("../models/.cache/upscaling-swinir"), + TEST_MODEL_UPSCALING_SWINIR, + ) + + def test_resolve_missing(self): + self.assertIsNone(resolve_tensor("missing")) diff 
--git a/api/tests/helpers.py b/api/tests/helpers.py new file mode 100644 index 00000000..f0c10edd --- /dev/null +++ b/api/tests/helpers.py @@ -0,0 +1,51 @@ +from multiprocessing import Queue, Value +from os import path +from typing import List +from unittest import skipUnless + +from onnx_web.params import DeviceParams +from onnx_web.worker.context import WorkerContext + + +def test_needs_models(models: List[str]): + return skipUnless( + all([path.exists(model) for model in models]), "model does not exist" + ) + + +def test_needs_onnx_models(models: List[str]): + return skipUnless( + all([path.exists(f"{model}.onnx") for model in models]), "model does not exist" + ) + + +def test_device() -> DeviceParams: + return DeviceParams("cpu", "CPUExecutionProvider") + + +def test_worker() -> WorkerContext: + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + return WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + + +TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth" +TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5" +TEST_MODEL_DIFFUSION_SD15_INPAINT = "../models/stable-diffusion-onnx-v1-inpainting" +TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth" diff --git a/api/tests/image/__init__.py b/api/tests/image/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/image/test_mask_filter.py b/api/tests/image/test_mask_filter.py new file mode 100644 index 00000000..e36470e4 --- /dev/null +++ b/api/tests/image/test_mask_filter.py @@ -0,0 +1,33 @@ +import unittest + +from PIL import Image + +from onnx_web.image.mask_filter import ( + mask_filter_gaussian_multiply, + mask_filter_gaussian_screen, + mask_filter_none, +) + + +class MaskFilterNoneTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_none(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) + + +class MaskFilterGaussianMultiplyTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_gaussian_multiply(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) + + +class MaskFilterGaussianScreenTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_gaussian_screen(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) diff --git a/api/tests/image/test_source_filter.py b/api/tests/image/test_source_filter.py new file mode 100644 index 00000000..fb44073e --- /dev/null +++ b/api/tests/image/test_source_filter.py @@ -0,0 +1,37 @@ +import unittest + +from PIL import Image + +from onnx_web.image.source_filter import ( + source_filter_gaussian, + source_filter_noise, + source_filter_none, +) +from onnx_web.server.context import ServerContext + + +class SourceFilterNoneTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_none(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterGaussianTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_gaussian(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterNoiseTests(unittest.TestCase): + def test_basic(self): + dims = 
(64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_noise(server, source) + self.assertEqual(result.size, dims) diff --git a/api/tests/image/test_utils.py b/api/tests/image/test_utils.py new file mode 100644 index 00000000..215bb10b --- /dev/null +++ b/api/tests/image/test_utils.py @@ -0,0 +1,24 @@ +import unittest + +from PIL import Image + +from onnx_web.image.utils import expand_image +from onnx_web.params import Border + + +class ExpandImageTests(unittest.TestCase): + def test_expand(self): + result = expand_image( + Image.new("RGB", (8, 8)), + Image.new("RGB", (8, 8), "white"), + Border.even(4), + ) + self.assertEqual(result[0].size, (16, 16)) + + def test_masked(self): + result = expand_image( + Image.new("RGB", (8, 8), "red"), + Image.new("RGB", (8, 8), "white"), + Border.even(4), + ) + self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0)) diff --git a/api/tests/mocks.py b/api/tests/mocks.py new file mode 100644 index 00000000..ef95d754 --- /dev/null +++ b/api/tests/mocks.py @@ -0,0 +1,43 @@ +from typing import Any, Optional + + +class MockPipeline: + # flags + slice_size: Optional[str] + vae_slicing: Optional[bool] + sequential_offload: Optional[bool] + model_offload: Optional[bool] + xformers: Optional[bool] + + # stubs + _encode_prompt: Optional[Any] + unet: Optional[Any] + vae_decoder: Optional[Any] + vae_encoder: Optional[Any] + + def __init__(self) -> None: + self.slice_size = None + self.vae_slicing = None + self.sequential_offload = None + self.model_offload = None + self.xformers = None + + self._encode_prompt = None + self.unet = None + self.vae_decoder = None + self.vae_encoder = None + + def enable_attention_slicing(self, slice_size: str = None): + self.slice_size = slice_size + + def enable_vae_slicing(self): + self.vae_slicing = True + + def enable_sequential_cpu_offload(self): + self.sequential_offload = True + + def enable_model_cpu_offload(self): + self.model_offload = True + + def enable_xformers_memory_efficient_attention(self): + self.xformers = True diff --git a/api/tests/models/__init__.py b/api/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/models/test_meta.py b/api/tests/models/test_meta.py new file mode 100644 index 00000000..458c8c37 --- /dev/null +++ b/api/tests/models/test_meta.py @@ -0,0 +1,12 @@ +import unittest + +from onnx_web.models.meta import NetworkModel + + +class NetworkModelTests(unittest.TestCase): + def test_json(self): + model = NetworkModel("test", "inversion") + json = model.tojson() + + self.assertIn("name", json) + self.assertIn("type", json) diff --git a/api/tests/prompt/test_parser.py b/api/tests/prompt/test_parser.py index 20c03341..15c91d6c 100644 --- a/api/tests/prompt/test_parser.py +++ b/api/tests/prompt/test_parser.py @@ -1,7 +1,9 @@ import unittest + from onnx_web.prompt.grammar import PromptPhrase from onnx_web.prompt.parser import parse_prompt_onnx + class ParserTests(unittest.TestCase): def test_single_word_phrase(self): res = parse_prompt_onnx(None, "foo (bar) bin", debug=False) @@ -11,7 +13,7 @@ class ParserTests(unittest.TestCase): str(["foo"]), str(PromptPhrase(["bar"], weight=1.5)), str(["bin"]), - ] + ], ) def test_multi_word_phrase(self): @@ -22,7 +24,7 @@ class ParserTests(unittest.TestCase): str(["foo", "bar"]), str(PromptPhrase(["middle", "words"], weight=1.5)), str(["bin", "bun"]), - ] + ], ) def test_nested_phrase(self): @@ -31,7 +33,7 @@ class ParserTests(unittest.TestCase): [str(i) for i in res], [ str(["foo"]), - 
str(PromptPhrase(["bar"], weight=(1.5 ** 3))), + str(PromptPhrase(["bar"], weight=(1.5**3))), str(["bin"]), - ] + ], ) diff --git a/api/tests/server/test_load.py b/api/tests/server/test_load.py new file mode 100644 index 00000000..b04df9ef --- /dev/null +++ b/api/tests/server/test_load.py @@ -0,0 +1,110 @@ +import unittest + +from onnx_web.server.context import ServerContext +from onnx_web.server.load import ( + get_available_platforms, + get_config_params, + get_correction_models, + get_diffusion_models, + get_extra_hashes, + get_extra_strings, + get_highres_methods, + get_mask_filters, + get_network_models, + get_noise_sources, + get_source_filters, + get_upscaling_models, + get_wildcard_data, + load_extras, + load_models, +) + + +class ConfigParamTests(unittest.TestCase): + def test_before_setup(self): + params = get_config_params() + self.assertIsNotNone(params) + + +class AvailablePlatformTests(unittest.TestCase): + def test_before_setup(self): + platforms = get_available_platforms() + self.assertIsNotNone(platforms) + + +class CorrectModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_correction_models() + self.assertIsNotNone(models) + + +class DiffusionModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_diffusion_models() + self.assertIsNotNone(models) + + +class NetworkModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_network_models() + self.assertIsNotNone(models) + + +class UpscalingModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_upscaling_models() + self.assertIsNotNone(models) + + +class WildcardDataTests(unittest.TestCase): + def test_before_setup(self): + wildcards = get_wildcard_data() + self.assertIsNotNone(wildcards) + + +class ExtraStringsTests(unittest.TestCase): + def test_before_setup(self): + strings = get_extra_strings() + self.assertIsNotNone(strings) + + +class ExtraHashesTests(unittest.TestCase): + def test_before_setup(self): + hashes = get_extra_hashes() + self.assertIsNotNone(hashes) + + +class HighresMethodTests(unittest.TestCase): + def test_before_setup(self): + methods = get_highres_methods() + self.assertIsNotNone(methods) + + +class MaskFilterTests(unittest.TestCase): + def test_before_setup(self): + filters = get_mask_filters() + self.assertIsNotNone(filters) + + +class NoiseSourceTests(unittest.TestCase): + def test_before_setup(self): + sources = get_noise_sources() + self.assertIsNotNone(sources) + + +class SourceFilterTests(unittest.TestCase): + def test_before_setup(self): + filters = get_source_filters() + self.assertIsNotNone(filters) + + +class LoadExtrasTests(unittest.TestCase): + def test_default_extras(self): + server = ServerContext(extra_models=["../models/extras.json"]) + load_extras(server) + + +class LoadModelsTests(unittest.TestCase): + def test_default_models(self): + server = ServerContext(model_path="../models") + load_models(server) diff --git a/api/tests/server/test_model_cache.py b/api/tests/server/test_model_cache.py index 000065d0..c024b611 100644 --- a/api/tests/server/test_model_cache.py +++ b/api/tests/server/test_model_cache.py @@ -2,33 +2,62 @@ import unittest from onnx_web.server.model_cache import ModelCache -class TestStringMethods(unittest.TestCase): - def test_drop_existing(self): - cache = ModelCache(10) - cache.clear() - cache.set("foo", ("bar",), {}) - self.assertGreater(cache.size, 0) - self.assertEqual(cache.drop("foo", ("bar",)), 1) - def test_drop_missing(self): - cache = ModelCache(10) - cache.clear() - 
cache.set("foo", ("bar",), {}) - self.assertGreater(cache.size, 0) - self.assertEqual(cache.drop("foo", ("bin",)), 0) +class TestModelCache(unittest.TestCase): + def test_drop_existing(self): + cache = ModelCache(10) + cache.clear() + cache.set("foo", ("bar",), {}) + self.assertGreater(cache.size, 0) + self.assertEqual(cache.drop("foo", ("bar",)), 1) - def test_get_existing(self): + def test_drop_missing(self): + cache = ModelCache(10) + cache.clear() + cache.set("foo", ("bar",), {}) + self.assertGreater(cache.size, 0) + self.assertEqual(cache.drop("foo", ("bin",)), 0) + + def test_get_existing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertGreater(cache.size, 0) + self.assertIs(cache.get("foo", ("bar",)), value) + + def test_get_missing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertGreater(cache.size, 0) + self.assertIs(cache.get("foo", ("bin",)), None) + + """ + def test_set_existing(self): cache = ModelCache(10) cache.clear() - value = {} + cache.set("foo", ("bar",), { + "value": 1, + }) + value = { + "value": 2, + } cache.set("foo", ("bar",), value) - self.assertGreater(cache.size, 0) self.assertIs(cache.get("foo", ("bar",)), value) + """ - def test_get_missing(self): - cache = ModelCache(10) - cache.clear() - value = {} - cache.set("foo", ("bar",), value) - self.assertGreater(cache.size, 0) - self.assertIs(cache.get("foo", ("bin",)), None) + def test_set_missing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertIs(cache.get("foo", ("bar",)), value) + + def test_set_zero(self): + cache = ModelCache(0) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertEqual(cache.size, 0) diff --git a/api/tests/test_diffusers/__init__.py b/api/tests/test_diffusers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py new file mode 100644 index 00000000..014f7aa0 --- /dev/null +++ b/api/tests/test_diffusers/test_load.py @@ -0,0 +1,330 @@ +import unittest +from os import path + +import torch +from diffusers import DDIMScheduler + +from onnx_web.diffusers.load import ( + get_available_pipelines, + get_pipeline_schedulers, + get_scheduler_name, + load_controlnet, + load_text_encoders, + load_unet, + load_vae, + optimize_pipeline, + patch_pipeline, +) +from onnx_web.diffusers.patches.unet import UNetWrapper +from onnx_web.diffusers.patches.vae import VAEWrapper +from onnx_web.models.meta import NetworkModel +from onnx_web.params import DeviceParams, ImageParams +from onnx_web.server.context import ServerContext +from tests.mocks import MockPipeline + + +class TestAvailablePipelines(unittest.TestCase): + def test_available_pipelines(self): + pipelines = get_available_pipelines() + + self.assertIn("txt2img", pipelines) + + +class TestPipelineSchedulers(unittest.TestCase): + def test_pipeline_schedulers(self): + schedulers = get_pipeline_schedulers() + + self.assertIn("euler-a", schedulers) + + +class TestSchedulerNames(unittest.TestCase): + def test_valid_name(self): + scheduler = get_scheduler_name(DDIMScheduler) + + self.assertEqual("ddim", scheduler) + + def test_missing_names(self): + self.assertIsNone(get_scheduler_name("test")) + + +class TestOptimizePipeline(unittest.TestCase): + def test_auto_attention_slicing(self): + server = ServerContext( + optimizations=[ + "diffusers-attention-slicing-auto", 
+ ], + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.slice_size, "auto") + + def test_max_attention_slicing(self): + server = ServerContext( + optimizations=[ + "diffusers-attention-slicing-max", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.slice_size, "max") + + def test_vae_slicing(self): + server = ServerContext( + optimizations=[ + "diffusers-vae-slicing", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.vae_slicing, True) + + def test_cpu_offload_sequential(self): + server = ServerContext( + optimizations=[ + "diffusers-cpu-offload-sequential", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.sequential_offload, True) + + def test_cpu_offload_model(self): + server = ServerContext( + optimizations=[ + "diffusers-cpu-offload-model", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.model_offload, True) + + def test_memory_efficient_attention(self): + server = ServerContext( + optimizations=[ + "diffusers-memory-efficient-attention", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.xformers, True) + + +class TestPatchPipeline(unittest.TestCase): + def test_expand_not_lpw(self): + """ + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) + self.assertEqual(pipeline._encode_prompt, expand_prompt) + """ + pass + + def test_unet_wrapper_not_xl(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline( + server, + pipeline, + None, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) + + def test_unet_wrapper_xl(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline( + server, + pipeline, + None, + ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1), + ) + self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) + + def test_vae_wrapper(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline( + server, + pipeline, + None, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) + self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) + + +class TestLoadControlNet(unittest.TestCase): + @unittest.skipUnless( + path.exists("../models/control/canny.onnx"), "model does not exist" + ) + def test_load_existing(self): + """ + Should load a model + """ + components = load_controlnet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + ImageParams( + "test", + "txt2img", + "ddim", + "test", + 1.0, + 10, + 1, + control=NetworkModel("canny", "control"), + ), + ) + self.assertIn("controlnet", components) + + def test_load_missing(self): + """ + Should throw + """ + components = {} + try: + components = load_controlnet( + ServerContext(), + DeviceParams("cpu", "CPUExecutionProvider"), + ImageParams( + "test", + "txt2img", + "ddim", + "test", + 1.0, + 10, + 1, + control=NetworkModel("missing", "control"), + ), + ) + except Exception: + self.assertNotIn("controlnet", components) + return + + self.fail() + + +class TestLoadTextEncoders(unittest.TestCase): + @unittest.skipUnless( + 
path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), + "model does not exist", + ) + def test_load_embeddings(self): + """ + Should add the token to tokenizer + Should increase the encoder dims + """ + components = load_text_encoders( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some embeddings + ], + [], + torch.float32, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("text_encoder", components) + + def test_load_embeddings_xl(self): + pass + + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), + "model does not exist", + ) + def test_load_loras(self): + components = load_text_encoders( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [], + [ + # TODO: add some loras + ], + torch.float32, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("text_encoder", components) + + def test_load_loras_xl(self): + pass + + +class TestLoadUnet(unittest.TestCase): + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), + "model does not exist", + ) + def test_load_unet_loras(self): + components = load_unet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some loras + ], + "unet", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("unet", components) + + def test_load_unet_loras_xl(self): + pass + + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), + "model does not exist", + ) + def test_load_cnet_loras(self): + components = load_unet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some loras + ], + "cnet", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("unet", components) + + +class TestLoadVae(unittest.TestCase): + @unittest.skipUnless( + path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), + "model does not exist", + ) + def test_load_single(self): + """ + Should return single component + """ + components = load_vae( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/upscaling-stable-diffusion-x4", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("vae", components) + self.assertNotIn("vae_decoder", components) + self.assertNotIn("vae_encoder", components) + + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), + "model does not exist", + ) + def test_load_split(self): + """ + Should return split encoder/decoder + """ + components = load_vae( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertNotIn("vae", components) + self.assertIn("vae_decoder", components) + self.assertIn("vae_encoder", components) diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py new file mode 100644 index 00000000..26578f3e --- /dev/null +++ b/api/tests/test_diffusers/test_run.py @@ -0,0 +1,425 @@ +import unittest +from 
multiprocessing import Queue, Value +from os import path + +from PIL import Image + +from onnx_web.diffusers.run import ( + run_blend_pipeline, + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, + run_upscale_pipeline, +) +from onnx_web.image.mask_filter import mask_filter_none +from onnx_web.image.noise_source import noise_source_uniform +from onnx_web.params import ( + Border, + HighresParams, + ImageParams, + Size, + TileOrder, + UpscaleParams, +) +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import ( + TEST_MODEL_DIFFUSION_SD15, + TEST_MODEL_DIFFUSION_SD15_INPAINT, + test_device, + test_needs_models, + test_worker, +) + +TEST_PROMPT = "an astronaut eating a hamburger" +TEST_SCHEDULER = "ddim" + + +class TestTxt2ImgPipeline(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + ), + Size(256, 256), + ["test-txt2img-basic.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-basic.png")) + output = Image.open("../outputs/test-txt2img-basic.png") + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image + + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_batch(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + batch=2, + ), + Size(256, 256), + ["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png")) + + output = Image.open("../outputs/test-txt2img-batch-0.png") + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image + + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_highres(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + unet_tile=256, + ), + Size(256, 256), + ["test-txt2img-highres.png"], + UpscaleParams("test", scale=2, outscale=2), + HighresParams(True, 2, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-highres.png")) + 
output = Image.open("../outputs/test-txt2img-highres.png") + self.assertEqual(output.size, (512, 512)) + + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_highres_batch(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + batch=2, + ), + Size(256, 256), + ["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"], + UpscaleParams("test"), + HighresParams(True, 2, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png")) + + output = Image.open("../outputs/test-txt2img-highres-batch-0.png") + self.assertEqual(output.size, (512, 512)) + + +class TestImg2ImgPipeline(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_basic(self): + worker = test_worker() + worker.start("test") + + source = Image.new("RGB", (64, 64), "black") + run_img2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + ), + ["test-img2img.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + 1.0, + ) + + self.assertTrue(path.exists("../outputs/test-img2img.png")) + + +class TestInpaintPipeline(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) + def test_basic_white(self): + worker = test_worker() + worker.start("test") + + source = Image.new("RGB", (64, 64), "black") + mask = Image.new("RGB", (64, 64), "white") + run_inpaint_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15_INPAINT, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + unet_tile=64, + ), + Size(*source.size), + ["test-inpaint-white.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + mask, + Border.even(0), + noise_source_uniform, + mask_filter_none, + "white", + TileOrder.spiral, + False, + 0.0, + ) + + self.assertTrue(path.exists("../outputs/test-inpaint-white.png")) + + @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) + def test_basic_black(self): + worker = test_worker() + worker.start("test") + + source = Image.new("RGB", (64, 64), "black") + mask = Image.new("RGB", (64, 64), "black") + run_inpaint_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15_INPAINT, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + unet_tile=64, + ), + Size(*source.size), + ["test-inpaint-black.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + mask, + Border.even(0), + noise_source_uniform, + mask_filter_none, + "black", + TileOrder.spiral, + False, + 0.0, + ) + + self.assertTrue(path.exists("../outputs/test-inpaint-black.png")) + + +class TestUpscalePipeline(unittest.TestCase): + @test_needs_models(["../models/upscaling-stable-diffusion-x4"]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = 
Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + source = Image.new("RGB", (64, 64), "black") + run_upscale_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + "../models/upscaling-stable-diffusion-x4", + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + ), + Size(256, 256), + ["test-upscale.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + ) + + self.assertTrue(path.exists("../outputs/test-upscale.png")) + + +class TestBlendPipeline(unittest.TestCase): + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + source = Image.new("RGBA", (64, 64), "black") + mask = Image.new("RGBA", (64, 64), "white") + run_blend_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + TEST_SCHEDULER, + TEST_PROMPT, + 3.0, + 1, + 1, + ), + Size(64, 64), + ["test-blend.png"], + UpscaleParams("test"), + [source, source], + mask, + ) + + self.assertTrue(path.exists("../outputs/test-blend.png")) diff --git a/api/tests/test_diffusers/test_utils.py b/api/tests/test_diffusers/test_utils.py new file mode 100644 index 00000000..0e576d8b --- /dev/null +++ b/api/tests/test_diffusers/test_utils.py @@ -0,0 +1,144 @@ +import unittest + +import numpy as np + +from onnx_web.diffusers.utils import ( + expand_alternative_ranges, + expand_interval_ranges, + get_inversions_from_prompt, + get_latents_from_seed, + get_loras_from_prompt, + get_scaled_latents, + get_tile_latents, + pop_random, + slice_prompt, +) +from onnx_web.params import Size + + +class TestExpandIntervalRanges(unittest.TestCase): + def test_prompt_with_no_ranges(self): + prompt = "an astronaut eating a hamburger" + result = expand_interval_ranges(prompt) + self.assertEqual(prompt, result) + + def test_prompt_with_range(self): + prompt = "an astronaut-{1,4} eating a hamburger" + result = expand_interval_ranges(prompt) + self.assertEqual( + result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger" + ) + + +class TestExpandAlternativeRanges(unittest.TestCase): + def test_prompt_with_no_ranges(self): + prompt = "an astronaut eating a hamburger" + result = expand_alternative_ranges(prompt) + self.assertEqual([prompt], result) + + def test_ranges_match(self): + prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)" + result = expand_alternative_ranges(prompt) + self.assertEqual( + result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"] + ) + + +class TestInversionsFromPrompt(unittest.TestCase): + def test_get_inversions(self): + prompt = " an astronaut eating an embedding" + result, tokens = get_inversions_from_prompt(prompt) + + self.assertEqual(result, " an astronaut eating an embedding") + self.assertEqual(tokens, [("test", 1.0)]) + + +class TestLoRAsFromPrompt(unittest.TestCase): + def test_get_loras(self): + prompt = " an astronaut eating a LoRA" + result, tokens = get_loras_from_prompt(prompt) + + self.assertEqual(result, " an astronaut eating a LoRA") + self.assertEqual(tokens, [("test", 1.0)]) + + +class 
TestLatentsFromSeed(unittest.TestCase): + def test_batch_size(self): + latents = get_latents_from_seed(1, Size(64, 64), batch=4) + self.assertEqual(latents.shape, (4, 4, 8, 8)) + + def test_consistency(self): + latents1 = get_latents_from_seed(1, Size(64, 64)) + latents2 = get_latents_from_seed(1, Size(64, 64)) + self.assertTrue(np.array_equal(latents1, latents2)) + + +class TestTileLatents(unittest.TestCase): + def test_full_tile(self): + partial = np.zeros((1, 1, 64, 64)) + full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) + self.assertEqual(full.shape, (1, 1, 8, 8)) + + def test_contract_tile(self): + partial = np.zeros((1, 1, 64, 64)) + full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32)) + self.assertEqual(full.shape, (1, 1, 4, 4)) + + def test_expand_tile(self): + partial = np.zeros((1, 1, 32, 32)) + full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) + self.assertEqual(full.shape, (1, 1, 8, 8)) + + +class TestScaledLatents(unittest.TestCase): + def test_scale_up(self): + latents = get_latents_from_seed(1, Size(16, 16)) + scaled = get_scaled_latents(1, Size(16, 16), scale=2) + self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0]) + + def test_scale_down(self): + latents = get_latents_from_seed(1, Size(16, 16)) + scaled = get_scaled_latents(1, Size(16, 16), scale=0.5) + self.assertEqual( + ( + latents[0, 0, 0, 0] + + latents[0, 0, 0, 1] + + latents[0, 0, 1, 0] + + latents[0, 0, 1, 1] + ) + / 4, + scaled[0, 0, 0, 0], + ) + + +class TestReplaceWildcards(unittest.TestCase): + pass + + +class TestPopRandom(unittest.TestCase): + def test_pop(self): + items = ["1", "2", "3"] + pop_random(items) + self.assertEqual(len(items), 2) + + +class TestRepairNaN(unittest.TestCase): + def test_unchanged(self): + pass + + def test_missing(self): + pass + + +class TestSlicePrompt(unittest.TestCase): + def test_slice_no_delimiter(self): + slice = slice_prompt("foo", 1) + self.assertEqual(slice, "foo") + + def test_slice_within_range(self): + slice = slice_prompt("foo || bar", 1) + self.assertEqual(slice, " bar") + + def test_slice_outside_range(self): + slice = slice_prompt("foo || bar", 9) + self.assertEqual(slice, " bar") diff --git a/api/tests/test_params.py b/api/tests/test_params.py index 0f84cfab..09eb7576 100644 --- a/api/tests/test_params.py +++ b/api/tests/test_params.py @@ -2,6 +2,7 @@ import unittest from onnx_web.params import Border, Size + class BorderTests(unittest.TestCase): def test_json(self): border = Border.even(0) diff --git a/api/tests/test_test.py b/api/tests/test_test.py index c6ee310c..ae70622f 100644 --- a/api/tests/test_test.py +++ b/api/tests/test_test.py @@ -1,5 +1,6 @@ import unittest + # just to get CI happy class ErrorTest(unittest.TestCase): def test(self): diff --git a/api/tests/worker/__init__.py b/api/tests/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py new file mode 100644 index 00000000..ea709156 --- /dev/null +++ b/api/tests/worker/test_pool.py @@ -0,0 +1,136 @@ +import unittest +from multiprocessing import Event +from time import sleep +from typing import Optional + +from onnx_web.params import DeviceParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.pool import DevicePoolExecutor + +TEST_JOIN_TIMEOUT = 0.2 + +lock = Event() + + +def test_job(*args, **kwargs): + lock.wait() + + +def wait_job(*args, **kwargs): + sleep(0.5) + + +class TestWorkerPool(unittest.TestCase): + # lock: Optional[Event] + pool: 
Optional[DevicePoolExecutor] + + def setUp(self) -> None: + self.pool = None + + def tearDown(self) -> None: + if self.pool is not None: + self.pool.join() + + def test_no_devices(self): + server = ServerContext() + self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() + + def test_fake_worker(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() + self.assertEqual(len(self.pool.workers), 1) + + def test_cancel_pending(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() + + self.pool.submit("test", wait_job, lock=lock) + self.assertEqual(self.pool.done("test"), (True, None)) + + self.assertTrue(self.pool.cancel("test")) + self.assertEqual(self.pool.done("test"), (False, None)) + + def test_cancel_running(self): + pass + + def test_next_device(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() + + self.assertEqual(self.pool.get_next_device(), 0) + + def test_needs_device(self): + device1 = DeviceParams("cpu1", "CPUProvider") + device2 = DeviceParams("cpu2", "CPUProvider") + server = ServerContext() + self.pool = DevicePoolExecutor( + server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT + ) + self.pool.start() + + self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) + + def test_done_running(self): + """ + TODO: flaky + """ + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start(lock) + self.pool.submit("test", test_job) + sleep(5.0) + + pending, _progress = self.pool.done("test") + self.assertFalse(pending) + + def test_done_pending(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start(lock) + + self.pool.submit("test1", test_job) + self.pool.submit("test2", test_job) + self.assertTrue(self.pool.done("test2"), (True, None)) + + lock.set() + + def test_done_finished(self): + """ + TODO: flaky + """ + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start() + self.pool.submit("test", wait_job) + self.assertEqual(self.pool.done("test"), (True, None)) + + sleep(5.0) + pending, _progress = self.pool.done("test") + self.assertFalse(pending) + + def test_recycle_live(self): + pass + + def test_recycle_dead(self): + pass + + def test_running_status(self): + pass diff --git a/api/tests/worker/test_worker.py b/api/tests/worker/test_worker.py new file mode 100644 index 00000000..6365fac9 --- /dev/null +++ b/api/tests/worker/test_worker.py @@ -0,0 +1,209 @@ +import unittest +from multiprocessing import Queue, Value +from os import getpid + +from onnx_web.errors import RetryException +from onnx_web.server.context import ServerContext +from onnx_web.worker.command import JobCommand +from onnx_web.worker.context import WorkerContext +from onnx_web.worker.worker import ( + EXIT_ERROR, + EXIT_INTERRUPT, + EXIT_MEMORY, + EXIT_REPLACED, + 
MEMORY_ERRORS, + worker_main, +) +from tests.helpers import test_device + + +def main_memory(_worker): + raise MemoryError(MEMORY_ERRORS[0]) + + +def main_retry(_worker): + raise RetryException() + + +def main_interrupt(_worker): + raise KeyboardInterrupt() + + +class WorkerMainTests(unittest.TestCase): + def test_pending_exception_empty(self): + pass + + def test_pending_exception_interrupt(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + job = JobCommand("test", "test", main_interrupt, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.put(job) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) + + self.assertEqual(status, EXIT_INTERRUPT) + + def test_pending_exception_retry(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + job = JobCommand("test", "test", main_retry, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.put(job) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) + + self.assertEqual(status, EXIT_ERROR) + + def test_pending_exception_value(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.close() + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) + + self.assertEqual(status, EXIT_ERROR) + + def test_pending_exception_other_memory(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + job = JobCommand("test", "test", main_memory, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) + + pending.put(job) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) + + self.assertEqual(status, EXIT_MEMORY) + + def test_pending_exception_other_unknown(self): + pass + + def test_pending_replaced(self): + status = None + + def exit(exit_status): + nonlocal status + status = exit_status + + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", 0) + idle = Value("L", False) + + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) + + self.assertEqual(status, EXIT_REPLACED) diff --git a/common/pipelines/codeformer.json b/common/pipelines/codeformer.json index 1d026a56..5ce967c1 100644 --- a/common/pipelines/codeformer.json +++ b/common/pipelines/codeformer.json @@ -9,7 +9,7 @@ "name": "save-local", "type": "persist-disk", "params": { - "tile_size": "hd8k" + "tiles": "hd8k" } } ] diff --git a/common/pipelines/complex.json b/common/pipelines/complex.json index 547991e0..f878a889 100644 --- a/common/pipelines/complex.json +++ 
b/common/pipelines/complex.json @@ -23,14 +23,14 @@ "prompt": "a magical wizard in a robe fighting a dragon", "scale": 4, "outscale": 4, - "tile_size": "mini" + "tiles": "mini" } }, { "name": "save-local", "type": "persist-disk", "params": { - "tile_size": "hd8k" + "tiles": "hd8k" } }, { @@ -40,7 +40,7 @@ "bucket": "storage-stable-diffusion", "endpoint_url": "http://scylla.home.holdmyran.ch:8000", "profile_name": "ceph", - "tile_size": "hd8k" + "tiles": "hd8k" } } ] diff --git a/common/pipelines/outpaint.json b/common/pipelines/outpaint.json index cf45aeb1..925de349 100644 --- a/common/pipelines/outpaint.json +++ b/common/pipelines/outpaint.json @@ -20,7 +20,7 @@ "name": "save-local", "type": "persist-disk", "params": { - "tile_size": "hd8k" + "tiles": "hd8k" } } ] diff --git a/docs/chain-pipelines.md b/docs/chain-pipelines.md index 23a8c7ae..3737a18c 100644 --- a/docs/chain-pipelines.md +++ b/docs/chain-pipelines.md @@ -66,7 +66,7 @@ and can also save intermediate output, such as the result of a `source-txt2img` "name": "save-local", "type": "persist-disk", "params": { - "tile_size": "hd8k" + "tiles": "hd8k" } } ] diff --git a/docs/readme-sdxl.png b/docs/readme-sdxl.png new file mode 100644 index 00000000..eb0b2223 --- /dev/null +++ b/docs/readme-sdxl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:145b8a98ecf5cfd4948d5ab17d28b34a8fc63cbb3b2c5e3f94b4411538733a59 +size 1633570 diff --git a/docs/setup-guide.md b/docs/setup-guide.md index 1035b09e..078b3c2a 100644 --- a/docs/setup-guide.md +++ b/docs/setup-guide.md @@ -16,6 +16,7 @@ This guide covers the setup process for onnx-web, including downloading the Wind - [For CPU everywhere: PyTorch CPU and ONNX runtime CPU](#for-cpu-everywhere-pytorch-cpu-and-onnx-runtime-cpu) - [For Nvidia everywhere: Install PyTorch GPU and ONNX GPU](#for-nvidia-everywhere-install-pytorch-gpu-and-onnx-gpu) - [Test the models](#test-the-models) + - [Download the web UI bundle](#download-the-web-ui-bundle) - [Windows-specific methods](#windows-specific-methods) - [Windows all-in-one bundle](#windows-all-in-one-bundle) - [Windows Python installer](#windows-python-installer) @@ -215,6 +216,24 @@ If the script works, there will be an image of an astronaut in `outputs/test.png If you get any errors, check [the known errors section of the user guide](user-guide.md#known-errors). +### Download the web UI bundle + +Once the server environment is working, you will need the latest files for the web UI. This is a Javascript bundle and +you can download a pre-built copy from Github or compile your own. + +From [the `gh-pages` branch](https://github.com/ssube/onnx-web/tree/gh-pages), select the version matching your server +and download all three files: + +- `bundle/main.js` +- `config.json` +- `index.html` + +Copy them into your local `api/gui` folder, making sure to keep the `main.js` bundle in the `bundle` subfolder. + +For example, for a v0.11 server, copy the files from https://github.com/ssube/onnx-web/tree/gh-pages/v0.11.0 into your +local copy of https://github.com/ssube/onnx-web/tree/main/api/gui and +https://github.com/ssube/onnx-web/tree/main/api/gui/bundle. + ## Windows-specific methods These methods are specific to Windows, tested on Windows 10, and still experimental. 
They should provide an easier
diff --git a/docs/user-guide.md b/docs/user-guide.md
index e5baa12b..3812c87c 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -32,6 +32,9 @@ Please see [the server admin guide](server-admin.md) for details on how to confi
   - [Prompt tokens](#prompt-tokens)
     - [LoRA and LyCORIS tokens](#lora-and-lycoris-tokens)
     - [Textual Inversion tokens](#textual-inversion-tokens)
+    - [Prompt stages](#prompt-stages)
+    - [Region tokens](#region-tokens)
+    - [Reseed tokens (region seeds)](#reseed-tokens-region-seeds)
     - [CLIP skip tokens](#clip-skip-tokens)
   - [Long prompt weighting syntax](#long-prompt-weighting-syntax)
 - [Pipelines](#pipelines)
@@ -414,6 +417,71 @@ much less useful. For a concept called `cubex` with the token ``, the avai
 - ``
 - `cubex-0`
 
+#### Prompt stages
+
+You can provide a different prompt for the highres and upscaling stages of an image using prompt stages. Each stage
+of a prompt is separated by `||` and can include its own LoRAs, embeddings, and regions. If you are using multiple
+iterations of highres, each iteration can have its own prompt stage. This can help you avoid recursive body parts
+and some other weird mutations that can be caused by iterating over a subject prompt.
+
+For example, a prompt like `human being sitting on wet grass, outdoors, bright sunny day` is likely to produce many
+small people mixed in with the grass when used with highres. This becomes even worse with 2+ iterations. However,
+changing that prompt to `human being sitting on wet grass, outdoors, bright sunny day || outdoors, bright sunny day, detailed, intricate, HDR`
+will use the second stage as the prompt for highres: `outdoors, bright sunny day, detailed, intricate, HDR`.
+
+This allows you to add and refine details, textures, and even the style of the image during the highres pass.
+
+Prompt stages are only used during upscaling if you are using the Stable Diffusion upscaling model.
+
+#### Region tokens
+
+You can use a different prompt for part of the image using `` tokens. Region tokens are more complicated
+than the other tokens and have more parameters, which may change in the future.
+
+```none
+
+```
+
+- `top`, `left`, `bottom`, and `right` define the four corners of a rectangle
+  - must be integers
+  - will be rounded down to the nearest multiple of 8
+- `strength` defines the ratio between the two prompts
+  - must be a float or integer
+  - strength should be between 0.0 and 100.0
+    - 2.0 to 5.0 generally works
+    - 100.0 completely replaces the base prompt
+    - < 0 does weird things
+  - more UNet overlap will require greater strength
+- `feather` defines the blending between the two prompts
+  - must be a float or integer
+  - this is similar to UNet and VAE overlap
+  - feather should be between 0.0 and 0.5
+    - 0.0 will cause hard edges
+    - 0.25 is a good default
+- the region has its own `prompt`
+  - any characters _except_ `>`
+  - if the region prompt ends with `+`, the base prompt will be appended to it
+    - this can help the region blend with the rest of the image better
+  - ` autumn forest, detailed background, 4k, HDR` will use two prompts:
+    - `small dog, autumn forest, detailed background, 4k, HDR` for the region
+    - `autumn forest, detailed background, 4k, HDR` for the rest of the image
+
+#### Reseed tokens (region seeds)
+
+You can use a different seed for part of the image using `` tokens. Reseed tokens will replace the initial
+latents in the selected rectangle. There will be some small differences between images due to how the latents are
+interpreted by the UNet, but the seeded area should be similar to an image of the same size and seed.
+
+```none
+
+```
+
+- `top`, `left`, `bottom`, and `right` define the four corners of a rectangle
+  - must be integers
+  - will be rounded down to the nearest multiple of 8
+- the region has its own `seed`
+  - must be an integer
+
 #### CLIP skip tokens
 
 You can skip the last layers of the CLIP text encoder using the `clip` token:
diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts
index 6bed1e46..64fc89dc 100644
--- a/gui/src/client/api.ts
+++ b/gui/src/client/api.ts
@@ -2,28 +2,31 @@ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
 
 import { ServerParams } from '../config.js';
-import { range } from '../utils.js';
 import {
-  ApiClient,
-  BaseImgParams,
-  BlendParams,
   FilterResponse,
-  HighresParams,
   ImageResponse,
   ImageResponseWithRetry,
+  ModelResponse,
+  ReadyResponse,
+  RetryParams,
+  WriteExtrasResponse,
+} from '../types/api.js';
+import { ChainPipeline } from '../types/chain.js';
+import { ExtrasFile } from '../types/model.js';
+import {
+  BaseImgParams,
+  BlendParams,
+  HighresParams,
   Img2ImgParams,
   InpaintParams,
   ModelParams,
-  ModelResponse,
   OutpaintParams,
-  ReadyResponse,
-  RetryParams,
   Txt2ImgParams,
   UpscaleParams,
   UpscaleReqParams,
-  WriteExtrasResponse,
-} from './types.js';
-import { ExtrasFile } from '../types.js';
+} from '../types/params.js';
+import { range } from '../utils.js';
+import { ApiClient } from './base.js';
 
 /**
  * Fixed precision for integer parameters.
@@ -67,10 +70,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams):
   url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT));
   url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT));
   url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER));
-  url.searchParams.append('tiledVAE', String(params.tiledVAE));
-  url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER));
-  url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT));
-  url.searchParams.append('stride', params.stride.toFixed(FIXED_INTEGER));
+  url.searchParams.append('tiled_vae', String(params.tiled_vae));
+  url.searchParams.append('unet_overlap', params.unet_overlap.toFixed(FIXED_FLOAT));
+  url.searchParams.append('unet_tile', params.unet_tile.toFixed(FIXED_INTEGER));
+  url.searchParams.append('vae_overlap', params.vae_overlap.toFixed(FIXED_FLOAT));
+  url.searchParams.append('vae_tile', params.vae_tile.toFixed(FIXED_INTEGER));
 
   if (doesExist(params.scheduler)) {
     url.searchParams.append('scheduler', params.scheduler);
@@ -430,6 +434,22 @@ export function makeClient(root: string, token: Maybe = undefined, f = f
         }
       };
     },
+    async chain(model: ModelParams, chain: ChainPipeline): Promise {
+      const url = makeApiUrl(root, 'chain');
+      const body = JSON.stringify({
+        ...chain,
+        platform: model.platform,
+      });
+
+      // eslint-disable-next-line no-return-await
+      return await parseRequest(url, {
+        body,
+        headers: {
+          'Content-Type': 'application/json',
+        },
+        method: 'POST',
+      });
+    },
     async ready(key: string): Promise {
       const path = makeApiUrl(root, 'ready');
       path.searchParams.append('output', key);
diff --git a/gui/src/client/base.ts b/gui/src/client/base.ts
new file mode 100644
index 00000000..70e96706
--- /dev/null
+++ b/gui/src/client/base.ts
@@ -0,0 +1,110 @@
+import { ServerParams } from '../config.js';
+import { ExtrasFile } from '../types/model.js';
+import {
WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; +import { ChainPipeline } from '../types/chain.js'; +import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js'; + +export interface ApiClient { + extras(): Promise; + + writeExtras(extras: ExtrasFile): Promise; + + /** + * List the available filter masks for inpaint. + */ + filters(): Promise; + + /** + * List the available models. + */ + models(): Promise; + + /** + * List the available noise sources for inpaint. + */ + noises(): Promise>; + + /** + * Get the valid server parameters to validate image parameters. + */ + params(): Promise; + + /** + * Get the available pipelines. + */ + pipelines(): Promise>; + + /** + * Get the available hardware acceleration platforms. + */ + platforms(): Promise>; + + /** + * List the available pipeline schedulers. + */ + schedulers(): Promise>; + + /** + * Load extra strings from the server. + */ + strings(): Promise; + }>>; + + /** + * Start a txt2img pipeline. + */ + txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + + /** + * Start an im2img pipeline. + */ + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + + /** + * Start an inpaint pipeline. + */ + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + + /** + * Start an outpaint pipeline. + */ + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + + /** + * Start an upscale pipeline. + */ + upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + + /** + * Start a blending pipeline. + */ + blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; + + chain(model: ModelParams, chain: ChainPipeline): Promise; + + /** + * Check whether job has finished and its output is ready. + */ + ready(key: string): Promise; + + /** + * Cancel an existing job. + */ + cancel(key: string): Promise; + + /** + * Retry a previous job using the same parameters. + */ + retry(params: RetryParams): Promise; + + /** + * Restart the image job workers. + */ + restart(): Promise; + + /** + * Check the status of the image job workers. + */ + status(): Promise>; +} diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 97f785a8..561417a3 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -1,6 +1,6 @@ import { BaseError } from 'noicejs'; -import { ApiClient } from './types.js'; +import { ApiClient } from './base.js'; export class NoServerError extends BaseError { constructor() { @@ -39,6 +39,9 @@ export const LOCAL_CLIENT = { async outpaint(model, params, upscale) { throw new NoServerError(); }, + async chain(model, chain) { + throw new NoServerError(); + }, async noises() { throw new NoServerError(); }, diff --git a/gui/src/client/types.ts b/gui/src/client/types.ts deleted file mode 100644 index b9d22495..00000000 --- a/gui/src/client/types.ts +++ /dev/null @@ -1,381 +0,0 @@ -import { ServerParams } from '../config.js'; -import { ExtrasFile } from '../types.js'; - -/** - * Shared parameters for anything using models, which is pretty much everything. 
- */ -export interface ModelParams { - /** - * The diffusion model to use. - */ - model: string; - - /** - * Specialized pipeline to use. - */ - pipeline: string; - - /** - * The hardware acceleration platform to use. - */ - platform: string; - - /** - * The upscaling model to use. - */ - upscaling: string; - - /** - * The correction model to use. - */ - correction: string; - - /** - * ControlNet to be used. - */ - control: string; -} - -/** - * Shared parameters for most of the image requests. - */ -export interface BaseImgParams { - scheduler: string; - prompt: string; - negativePrompt?: string; - - batch: number; - tiledVAE: boolean; - tiles: number; - overlap: number; - stride: number; - - cfg: number; - steps: number; - seed: number; - eta: number; -} - -/** - * Parameters for txt2img requests. - */ -export interface Txt2ImgParams extends BaseImgParams { - width: number; - height: number; -} - -/** - * Parameters for img2img requests. - */ -export interface Img2ImgParams extends BaseImgParams { - source: Blob; - - loopback: number; - sourceFilter: string; - strength: number; -} - -/** - * Parameters for inpaint requests. - */ -export interface InpaintParams extends BaseImgParams { - mask: Blob; - source: Blob; - - filter: string; - noise: string; - strength: number; - fillColor: string; - tileOrder: string; -} - -/** - * Additional parameters for outpaint border. - * - * @todo should be nested under inpaint/outpaint params - */ -export interface OutpaintPixels { - enabled: boolean; - - left: number; - right: number; - top: number; - bottom: number; -} - -/** - * Parameters for outpaint requests. - */ -export type OutpaintParams = InpaintParams & OutpaintPixels; - -/** - * Additional parameters for the inpaint brush. - * - * These are not currently sent to the server and only stored in state. - * - * @todo move to state - */ -export interface BrushParams { - color: number; - size: number; - strength: number; -} - -/** - * Additional parameters for upscaling. May be sent with most other requests to run a post-pipeline. - */ -export interface UpscaleParams { - enabled: boolean; - upscaleOrder: string; - - denoise: number; - scale: number; - outscale: number; - - faces: boolean; - faceStrength: number; - faceOutscale: number; -} - -/** - * Parameters for upscale requests. - */ -export interface UpscaleReqParams extends BaseImgParams { - source: Blob; -} - -/** - * Parameters for blend requests. - */ -export interface BlendParams { - sources: Array; - mask: Blob; -} - -export interface HighresParams { - enabled: boolean; - - highresIterations: number; - highresMethod: string; - highresScale: number; - highresSteps: number; - highresStrength: number; -} - -/** - * Output image data within the response. - */ -export interface ImageOutput { - key: string; - url: string; -} - -/** - * Output image size, after upscaling and outscale. - */ -export interface ImageSize { - width: number; - height: number; -} - -/** - * General response for most image requests. - */ -export interface ImageResponse { - outputs: Array; - params: Required & Required; - size: ImageSize; -} - -/** - * Status response from the ready endpoint. - */ -export interface ReadyResponse { - cancelled: boolean; - failed: boolean; - progress: number; - ready: boolean; -} - -export interface NetworkModel { - name: string; - type: 'control' | 'inversion' | 'lora'; - // TODO: add token - // TODO: add layer/token count -} - -export interface FilterResponse { - mask: Array; - source: Array; -} - -/** - * List of available models. 
- */ -export interface ModelResponse { - correction: Array; - diffusion: Array; - networks: Array; - upscaling: Array; -} - -export interface WriteExtrasResponse { - file: string; - successful: Array; - errors: Array; -} - -export type RetryParams = { - type: 'txt2img'; - model: ModelParams; - params: Txt2ImgParams; - upscale?: UpscaleParams; - highres?: HighresParams; -} | { - type: 'img2img'; - model: ModelParams; - params: Img2ImgParams; - upscale?: UpscaleParams; - highres?: HighresParams; -} | { - type: 'inpaint'; - model: ModelParams; - params: InpaintParams; - upscale?: UpscaleParams; - highres?: HighresParams; -} | { - type: 'outpaint'; - model: ModelParams; - params: OutpaintParams; - upscale?: UpscaleParams; - highres?: HighresParams; -} | { - type: 'upscale'; - model: ModelParams; - params: UpscaleReqParams; - upscale?: UpscaleParams; - highres?: HighresParams; -} | { - type: 'blend'; - model: ModelParams; - params: BlendParams; - upscale?: UpscaleParams; -}; - -export interface ImageResponseWithRetry { - image: ImageResponse; - retry: RetryParams; -} - -export interface ImageMetadata { - highres: HighresParams; - outputs: string | Array; - params: Txt2ImgParams | Img2ImgParams | InpaintParams; - upscale: UpscaleParams; - - input_size: ImageSize; - size: ImageSize; -} - -export interface ApiClient { - extras(): Promise; - - writeExtras(extras: ExtrasFile): Promise; - - /** - * List the available filter masks for inpaint. - */ - filters(): Promise; - - /** - * List the available models. - */ - models(): Promise; - - /** - * List the available noise sources for inpaint. - */ - noises(): Promise>; - - /** - * Get the valid server parameters to validate image parameters. - */ - params(): Promise; - - /** - * Get the available pipelines. - */ - pipelines(): Promise>; - - /** - * Get the available hardware acceleration platforms. - */ - platforms(): Promise>; - - /** - * List the available pipeline schedulers. - */ - schedulers(): Promise>; - - /** - * Load extra strings from the server. - */ - strings(): Promise; - }>>; - - /** - * Start a txt2img pipeline. - */ - txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; - - /** - * Start an im2img pipeline. - */ - img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; - - /** - * Start an inpaint pipeline. - */ - inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; - - /** - * Start an outpaint pipeline. - */ - outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; - - /** - * Start an upscale pipeline. - */ - upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; - - /** - * Start a blending pipeline. - */ - blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; - - /** - * Check whether job has finished and its output is ready. - */ - ready(key: string): Promise; - - /** - * Cancel an existing job. - */ - cancel(key: string): Promise; - - /** - * Retry a previous job using the same parameters. - */ - retry(params: RetryParams): Promise; - - /** - * Restart the image job workers. - */ - restart(): Promise; - - /** - * Check the status of the image job workers. 
- */ - status(): Promise>; -} diff --git a/gui/src/client/utils.ts b/gui/src/client/utils.ts new file mode 100644 index 00000000..967c818d --- /dev/null +++ b/gui/src/client/utils.ts @@ -0,0 +1,158 @@ +import { doesExist } from '@apextoaster/js-utils'; +import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; +import { ChainPipeline, ChainStageParams, STRING_PARAMETERS } from '../types/chain.js'; + +export interface PipelineVariable { + parameter: 'prompt' | 'cfg' | 'seed' | 'steps' | 'eta' | 'scheduler' | 'token'; + value: string; +} + +export interface PipelineGrid { + enabled: boolean; + columns: PipelineVariable; + rows: PipelineVariable; +} + +export const EXPR_STRICT_NUMBER = /^-?\d+$/; +export const EXPR_NUMBER_RANGE = /^(-?\d+)-(-?\d+)$/; + +export const MAX_SEED_SIZE = 32; +export const MAX_SEED = (2**MAX_SEED_SIZE) - 1; + +export function replacePromptTokens(grid: PipelineGrid, params: Txt2ImgParams, columnValue: string | number, rowValue: string | number): {prompt: string} { + const result = { + negativePrompt: params.negativePrompt, + prompt: params.prompt, + }; + + if (grid.columns.parameter === 'token') { + result.prompt = result.prompt.replace('__column__', columnValue.toString()); + + if (doesExist(result.negativePrompt)) { + result.negativePrompt = result.negativePrompt.replace('__column__', columnValue.toString()); + } + } + + if (grid.rows.parameter === 'token') { + result.prompt = result.prompt.replace('__row__', rowValue.toString()); + + if (doesExist(result.negativePrompt)) { + result.negativePrompt = result.negativePrompt.replace('__row__', rowValue.toString()); + } + } + + return result; +} + +export function newSeed(): number { + return Math.floor(Math.random() * MAX_SEED); +} + +export function replaceRandomSeeds(key: string, values: Array): Array { + if (key !== 'seed') { + return values; + } + + return values.map((it) => { + // eslint-disable-next-line @typescript-eslint/no-magic-numbers + if (it === '-1' || it === -1) { + return newSeed(); + } + + return it; + }); +} + +export function rangeSplit(parameter: string, value: string): Array { + const csv = value.split(',').map((it) => it.trim()); + + if (STRING_PARAMETERS.includes(parameter)) { + return csv; + } + + return csv.flatMap((it) => expandRanges(it)); +} + +export function expandRanges(range: string): Array { + if (EXPR_STRICT_NUMBER.test(range)) { + // entirely numeric, return after parsing + const val = parseInt(range, 10); + return [val]; + } + + if (EXPR_NUMBER_RANGE.test(range)) { + const match = EXPR_NUMBER_RANGE.exec(range); + if (doesExist(match)) { + const [_full, startStr, endStr] = Array.from(match); + const start = parseInt(startStr, 10); + const end = parseInt(endStr, 10); + + return new Array(end - start).fill(0).map((_value, idx) => idx + start); + } + } + + return []; +} + +export const GRID_TILE_SIZE = 8192; + +// eslint-disable-next-line max-params +export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline { + const pipeline: ChainPipeline = { + defaults: { + ...model, + ...params, + }, + stages: [], + }; + + const tiles: ChainStageParams = { + tiles: GRID_TILE_SIZE, + }; + + const rows = replaceRandomSeeds(grid.rows.parameter, rangeSplit(grid.rows.parameter, grid.rows.value)); + const columns = replaceRandomSeeds(grid.columns.parameter, rangeSplit(grid.columns.parameter, grid.columns.value)); + + let i = 0; + + for (const row of rows) { + 
for (const column of columns) { + const prompt = replacePromptTokens(grid, params, column, row); + + pipeline.stages.push({ + name: `cell-${i}`, + type: 'source-txt2img', + params: { + ...params, + ...prompt, + ...model, + ...tiles, + [grid.columns.parameter]: column, + [grid.rows.parameter]: row, + }, + }); + + i += 1; + } + } + + pipeline.stages.push({ + name: 'grid', + type: 'blend-grid', + params: { + ...params, + ...model, + ...tiles, + height: rows.length, + width: columns.length, + }, + }); + + pipeline.stages.push({ + name: 'save', + type: 'persist-disk', + params: tiles, + }); + + return pipeline; +} diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index 5ed9de77..baf022b3 100644 --- a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -21,9 +21,10 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BaseImgParams, HighresParams, ImageMetadata, Txt2ImgParams, UpscaleParams } from '../client/types.js'; import { OnnxState, StateContext } from '../state.js'; -import { DeepPartial } from '../types.js'; +import { ImageMetadata } from '../types/api.js'; +import { DeepPartial } from '../types/model.js'; +import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; const { useState } = React; diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx index fe683e67..bb3ac6c9 100644 --- a/gui/src/components/card/ErrorCard.tsx +++ b/gui/src/components/card/ErrorCard.tsx @@ -1,4 +1,4 @@ -import { mustExist } from '@apextoaster/js-utils'; +import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils'; import { Delete, Replay } from '@mui/icons-material'; import { Alert, Box, Card, CardContent, IconButton, Tooltip } from '@mui/material'; import { Stack } from '@mui/system'; @@ -9,13 +9,13 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ImageResponse, ReadyResponse, RetryParams } from '../../client/types.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js'; export interface ErrorCardProps { image: ImageResponse; ready: ReadyResponse; - retry: RetryParams; + retry: Maybe; } export function ErrorCard(props: ErrorCardProps) { @@ -30,8 +30,11 @@ export function ErrorCard(props: ErrorCardProps) { async function retryImage() { removeHistory(image); - const { image: nextImage, retry: nextRetry } = await client.retry(retryParams); - pushHistory(nextImage, nextRetry); + + if (doesExist(retryParams)) { + const { image: nextImage, retry: nextRetry } = await client.retry(retryParams); + pushHistory(nextImage, nextRetry); + } } const retry = useMutation(retryImage); diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx index bb1b6e17..7c35acb2 100644 --- a/gui/src/components/card/ImageCard.tsx +++ b/gui/src/components/card/ImageCard.tsx @@ -8,8 +8,8 @@ import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ImageResponse } from '../../client/types.js'; import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ImageResponse } from '../../types/api.js'; import { range, visibleIndex } from '../../utils.js'; export interface 
ImageCardProps { diff --git a/gui/src/components/card/LoadingCard.tsx b/gui/src/components/card/LoadingCard.tsx index 1f339060..e0fcdb68 100644 --- a/gui/src/components/card/LoadingCard.tsx +++ b/gui/src/components/card/LoadingCard.tsx @@ -8,9 +8,9 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ImageResponse } from '../../client/types.js'; import { POLL_TIME } from '../../config.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ImageResponse } from '../../types/api.js'; const LOADING_PERCENT = 100; const LOADING_OVERAGE = 99; diff --git a/gui/src/components/control/HighresControl.tsx b/gui/src/components/control/HighresControl.tsx index 83b10581..91525b21 100644 --- a/gui/src/components/control/HighresControl.tsx +++ b/gui/src/components/control/HighresControl.tsx @@ -5,8 +5,8 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { HighresParams } from '../../client/types.js'; import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { HighresParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; export interface HighresControlProps { diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index c0b7b6dc..8271c700 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -1,3 +1,4 @@ +/* eslint-disable camelcase */ import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; import { Casino } from '@mui/icons-material'; import { Button, Checkbox, FormControlLabel, Stack } from '@mui/material'; @@ -9,9 +10,9 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BaseImgParams } from '../../client/types.js'; import { STALE_TIME } from '../../config.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { BaseImgParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; import { PromptInput } from '../input/PromptInput.js'; import { QueryList } from '../input/QueryList.js'; @@ -47,9 +48,6 @@ export function ImageControl(props: ImageControlProps) { staleTime: STALE_TIME, }); - // max stride is the lesser of tile size and server's max stride - const maxStride = Math.min(state.tiles, params.stride.max); - return { + label={t('parameter.unet_tile')} + min={params.unet_tile.min} + max={params.unet_tile.max} + step={params.unet_tile.step} + value={state.unet_tile} + onChange={(unet_tile) => { props.onChange({ ...state, - tiles, + unet_tile, }); }} /> { + label={t('parameter.unet_overlap')} + min={params.unet_overlap.min} + max={params.unet_overlap.max} + step={params.unet_overlap.step} + value={state.unet_overlap} + onChange={(unet_overlap) => { props.onChange({ ...state, - overlap, - }); - }} - /> - { - props.onChange({ - ...state, - stride, + unet_overlap, }); }} /> { props.onChange({ ...state, - tiledVAE: state.tiledVAE === false, + tiled_vae: state.tiled_vae === false, }); }} />} /> + { + props.onChange({ + ...state, + vae_tile, + }); + }} + /> + { + props.onChange({ + ...state, + vae_overlap, + }); + }} + /> PipelineGrid; + setGrid: (grid: Partial) => void; +} + +export type VariableKey = 'prompt' | 'steps' | 'seed'; + +export function 
VariableControl(props: VariableControlProps) { + const store = mustExist(useContext(StateContext)); + const grid = useStore(store, props.selectGrid); + + const stack = [ + + + props.setGrid({ + enabled: grid.enabled === false, + })} + />} + /> + + , + ]; + + if (grid.enabled) { + stack.push( + + + Columns + + + props.setGrid({ + columns: { + parameter: grid.columns.parameter, + value: event.target.value, + }, + })} /> + , + + + Rows + + + props.setGrid({ + rows: { + parameter: grid.rows.parameter, + value: event.target.value, + } + })} /> + + ); + } + + return {...stack}; +} + +export function parameterList(exclude?: Array) { + const items = []; + + for (const variable of VARIABLE_PARAMETERS) { + if (variable !== 'token' && doesExist(exclude) && exclude.includes(variable)) { + continue; + } + + items.push({variable}); + } + + return items; +} diff --git a/gui/src/components/input/MaskCanvas.tsx b/gui/src/components/input/MaskCanvas.tsx index 0bc4673c..ae7f2723 100644 --- a/gui/src/components/input/MaskCanvas.tsx +++ b/gui/src/components/input/MaskCanvas.tsx @@ -5,9 +5,9 @@ import { throttle } from 'lodash'; import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'; import { useTranslation } from 'react-i18next'; -import { BrushParams } from '../../client/types.js'; import { SAVE_TIME } from '../../config.js'; import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; +import { BrushParams } from '../../types/params.js'; import { imageFromBlob } from '../../utils.js'; import { NumericField } from './NumericField'; diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index 2fc52ca3..cdaa6751 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -1,5 +1,5 @@ -import { mustExist } from '@apextoaster/js-utils'; -import { TextField } from '@mui/material'; +import { Maybe, doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; +import { Chip, TextField } from '@mui/material'; import { Stack } from '@mui/system'; import { useQuery } from '@tanstack/react-query'; import * as React from 'react'; @@ -10,8 +10,9 @@ import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; import { ClientContext, OnnxState, StateContext } from '../../state.js'; import { QueryMenu } from '../input/QueryMenu.js'; +import { ModelResponse } from '../../types/api.js'; -const { useContext } = React; +const { useContext, useMemo } = React; /** * @todo replace with a selector @@ -48,26 +49,29 @@ export function PromptInput(props: PromptInputProps) { staleTime: STALE_TIME, }); - const tokens = splitPrompt(prompt); - const groups = Math.ceil(tokens.length / PROMPT_GROUP); - const { t } = useTranslation(); - const helper = t('input.prompt.tokens', { - groups, - tokens: tokens.length, - }); - function addToken(type: string, name: string, weight = 1.0) { + function addNetwork(type: string, name: string, weight = 1.0) { onChange({ prompt: `<${type}:${name}:1.0> ${prompt}`, negativePrompt, }); } + function addToken(name: string) { + onChange({ + prompt: `${prompt}, ${name}`, + }); + } + + const tokens = useMemo(() => { + const networks = extractNetworks(prompt); + return getNetworkTokens(models.data, networks); + }, [prompt, models.data]); + return { @@ -77,6 +81,13 @@ export function PromptInput(props: PromptInputProps) { }); }} /> + + {tokens.map((token) => addToken(token)} + />)} + result.networks.filter((network) => network.type === 
'inversion').map((network) => network.name), }} onSelect={(name) => { - addToken('inversion', name); + addNetwork('inversion', name); }} /> result.networks.filter((network) => network.type === 'lora').map((network) => network.name), }} onSelect={(name) => { - addToken('lora', name); + addNetwork('lora', name); }} /> ; } + +export const ANY_TOKEN = /<([^>]+)>/g; + +export type TokenList = Array<[string, number]>; + +export interface PromptNetworks { + inversion: TokenList; + lora: TokenList; +} + +export function extractNetworks(prompt: string): PromptNetworks { + const inversion: TokenList = []; + const lora: TokenList = []; + + for (const token of prompt.matchAll(ANY_TOKEN)) { + const [_whole, match] = Array.from(token); + const [type, name, weight, ..._rest] = match.split(':'); + + switch (type) { + case 'inversion': + inversion.push([name, parseFloat(weight)]); + break; + case 'lora': + lora.push([name, parseFloat(weight)]); + break; + default: + // ignore others + } + } + + return { + inversion, + lora, + }; +} + +// eslint-disable-next-line sonarjs/cognitive-complexity +export function getNetworkTokens(models: Maybe, networks: PromptNetworks): Array { + const tokens: Set = new Set(); + + if (doesExist(models)) { + for (const [name, _weight] of networks.inversion) { + const model = models.networks.find((it) => it.type === 'inversion' && it.name === name); + if (doesExist(model) && model.type === 'inversion') { + tokens.add(model.token); + } + } + + for (const [name, _weight] of networks.lora) { + const model = models.networks.find((it) => it.type === 'lora' && it.name === name); + if (doesExist(model) && model.type === 'lora') { + for (const token of mustDefault(model.tokens, [])) { + tokens.add(token); + } + } + } + } + + return Array.from(tokens).sort(); +} diff --git a/gui/src/components/input/model/CorrectionModel.tsx b/gui/src/components/input/model/CorrectionModel.tsx index a1db14e6..68f31bf3 100644 --- a/gui/src/components/input/model/CorrectionModel.tsx +++ b/gui/src/components/input/model/CorrectionModel.tsx @@ -2,7 +2,7 @@ import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; -import { CorrectionArch, CorrectionModel, ModelFormat } from '../../../types.js'; +import { CorrectionArch, CorrectionModel, ModelFormat } from '../../../types/model.js'; export interface CorrectionModelInputProps { key?: number | string; diff --git a/gui/src/components/input/model/DiffusionModel.tsx b/gui/src/components/input/model/DiffusionModel.tsx index b7d7b98b..e042b4fa 100644 --- a/gui/src/components/input/model/DiffusionModel.tsx +++ b/gui/src/components/input/model/DiffusionModel.tsx @@ -2,7 +2,7 @@ import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; -import { DiffusionModel, ModelFormat } from '../../../types.js'; +import { DiffusionModel, ModelFormat } from '../../../types/model.js'; export interface DiffusionModelInputProps { key?: number | string; diff --git a/gui/src/components/input/model/ExtraNetwork.tsx b/gui/src/components/input/model/ExtraNetwork.tsx index 2fcdef2b..1d279525 100644 --- a/gui/src/components/input/model/ExtraNetwork.tsx +++ b/gui/src/components/input/model/ExtraNetwork.tsx @@ -2,7 +2,7 @@ import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; -import { ExtraNetwork, 
ModelFormat, NetworkModel, NetworkType } from '../../../types.js'; +import { ExtraNetwork, ModelFormat, NetworkModel, NetworkType } from '../../../types/model.js'; export interface ExtraNetworkInputProps { key?: number | string; diff --git a/gui/src/components/input/model/ExtraSource.tsx b/gui/src/components/input/model/ExtraSource.tsx index c2011123..5a19ad18 100644 --- a/gui/src/components/input/model/ExtraSource.tsx +++ b/gui/src/components/input/model/ExtraSource.tsx @@ -2,7 +2,7 @@ import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; -import { AnyFormat, ExtraSource } from '../../../types.js'; +import { AnyFormat, ExtraSource } from '../../../types/model.js'; export interface ExtraSourceInputProps { key?: number | string; diff --git a/gui/src/components/input/model/UpscalingModel.tsx b/gui/src/components/input/model/UpscalingModel.tsx index c315da91..cf151df9 100644 --- a/gui/src/components/input/model/UpscalingModel.tsx +++ b/gui/src/components/input/model/UpscalingModel.tsx @@ -2,7 +2,7 @@ import { Button, MenuItem, Select, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; -import { ModelFormat, UpscalingArch, UpscalingModel } from '../../../types.js'; +import { ModelFormat, UpscalingArch, UpscalingModel } from '../../../types/model.js'; import { NumericField } from '../NumericField.js'; export interface UpscalingModelInputProps { diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index 5d75ecd8..a1a304d1 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -7,9 +7,9 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../client/types.js'; import { IMAGE_FILTER } from '../../config.js'; import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { range } from '../../utils.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; import { ImageInput } from '../input/ImageInput.js'; diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index 306bdd14..cddb8539 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -5,11 +5,12 @@ import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { shallow} from 'zustand/shallow'; +import { shallow } from 'zustand/shallow'; -import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../client/types.js'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js'; +import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; import { ImageControl } from '../control/ImageControl.js'; import { ModelControl } from '../control/ModelControl.js'; @@ -17,7 +18,6 @@ import { UpscaleControl } from '../control/UpscaleControl.js'; import { ImageInput } from '../input/ImageInput.js'; import { NumericField } 
from '../input/NumericField.js'; import { QueryList } from '../input/QueryList.js'; -import { Profiles } from '../Profiles.js'; export function Img2Img() { const { params } = mustExist(useContext(ConfigContext)); diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index 62fbeb1b..b783f5b5 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -7,9 +7,10 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../client/types.js'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js'; +import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; import { ImageControl } from '../control/ImageControl.js'; import { ModelControl } from '../control/ModelControl.js'; @@ -19,7 +20,6 @@ import { ImageInput } from '../input/ImageInput.js'; import { MaskCanvas } from '../input/MaskCanvas.js'; import { NumericField } from '../input/NumericField.js'; import { QueryList } from '../input/QueryList.js'; -import { Profiles } from '../Profiles.js'; export function Inpaint() { const { params } = mustExist(useContext(ConfigContext)); diff --git a/gui/src/components/tab/Models.tsx b/gui/src/components/tab/Models.tsx index a0cf13b0..e634485a 100644 --- a/gui/src/components/tab/Models.tsx +++ b/gui/src/components/tab/Models.tsx @@ -19,7 +19,7 @@ import { NetworkType, SafetensorFormat, UpscalingModel, -} from '../../types.js'; +} from '../../types/model.js'; import { EditableList } from '../input/EditableList'; import { CorrectionModelInput } from '../input/model/CorrectionModel.js'; import { DiffusionModelInput } from '../input/model/DiffusionModel.js'; diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 76e669ce..921def1f 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -7,23 +7,35 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../client/types.js'; +import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js'; import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js'; +import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; import { ImageControl } from '../control/ImageControl.js'; import { ModelControl } from '../control/ModelControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; +import { VariableControl } from '../control/VariableControl.js'; import { NumericField } from '../input/NumericField.js'; -import { Profiles } from '../Profiles.js'; export function Txt2Img() { const { params } = mustExist(useContext(ConfigContext)); async function generateImage() { const state = store.getState(); - const { image, retry } = await client.txt2img(model, selectParams(state), selectUpscale(state), selectHighres(state)); + const grid = selectVariable(state); + const params2 = 
selectParams(state);
+    const upscale = selectUpscale(state);
+    const highres = selectHighres(state);
-    pushHistory(image, retry);
+    if (grid.enabled) {
+      const chain = makeTxt2ImgGridPipeline(grid, model, params2, upscale, highres);
+      const image = await client.chain(model, chain);
+      pushHistory(image);
+    } else {
+      const { image, retry } = await client.txt2img(model, params2, upscale, highres);
+      pushHistory(image, retry);
+    }
   }
   const client = mustExist(useContext(ClientContext));
@@ -33,7 +45,7 @@ export function Txt2Img() {
   });
   const store = mustExist(useContext(StateContext));
-  const { pushHistory, setHighres, setModel, setParams, setUpscale } = useStore(store, selectActions, shallow);
+  const { pushHistory, setHighres, setModel, setParams, setUpscale, setVariable } = useStore(store, selectActions, shallow);
   const { height, width } = useStore(store, selectReactParams, shallow);
   const model = useStore(store, selectModel);
@@ -79,6 +91,7 @@ export function Txt2Img() {
+
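
The new `gui/src/client/utils.ts` module above drives the parameter grid: each axis value is split on commas, numeric entries are expanded as ranges, and `-1` seeds are replaced with a fresh seed per cell. The following is a minimal sketch of how those helpers appear to behave based on the code in this patch; the import path is illustrative, and it assumes only prompt-like parameters appear in `STRING_PARAMETERS`.

```typescript
import { expandRanges, rangeSplit, replaceRandomSeeds } from './client/utils.js';

// single numbers parse directly; ranges expand with an exclusive upper bound, as written
expandRanges('30');              // [30]
expandRanges('5-8');             // [5, 6, 7]

// comma-separated values are split first, then each numeric entry is expanded
rangeSplit('steps', '20,25-28'); // [20, 25, 26, 27]

// parameters listed in STRING_PARAMETERS (assumed to include 'prompt') are
// only split on commas, never range-expanded
rangeSplit('prompt', 'a cat, a dog'); // ['a cat', 'a dog']

// a -1 seed becomes a fresh random seed for each grid cell
replaceRandomSeeds('seed', [-1, 42]); // [<random>, 42]
```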
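`PromptInput` now parses `<type:name:weight>` tokens out of the prompt with `extractNetworks` and resolves the matching trigger words with `getNetworkTokens`, showing them as clickable chips. A sketch of that parsing under the same caveats: the import paths and the `models` value are illustrative, and the cast stands in for the full `ModelResponse` type, which is not shown in this patch.

```typescript
import { extractNetworks, getNetworkTokens } from './components/input/PromptInput.js';
import { ModelResponse } from './types/api.js';

const prompt = '<lora:detail-slider:0.8> <inversion:night-style:1.0> a castle at dusk';

const networks = extractNetworks(prompt);
// networks.lora      => [['detail-slider', 0.8]]
// networks.inversion => [['night-style', 1.0]]

// a hypothetical /models response fragment containing the referenced networks
const models = {
  networks: [
    { type: 'inversion', name: 'night-style', token: 'night-style-token' },
    { type: 'lora', name: 'detail-slider', tokens: ['detail'] },
  ],
};

// returns the sorted trigger tokens offered as chips below the prompt field;
// cast only because the full ModelResponse shape is not part of this diff
const chips = getNetworkTokens(models as unknown as ModelResponse, networks);
// chips => ['detail', 'night-style-token']
```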
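When the variable grid is enabled, the Txt2Img tab above no longer calls `client.txt2img` directly; it builds a chain pipeline with `makeTxt2ImgGridPipeline` and submits it through `client.chain`. The sketch below shows that path with illustrative import paths and a stubbed client interface standing in for the real API client.

```typescript
import { makeTxt2ImgGridPipeline, PipelineGrid } from './client/utils.js';
import { ChainPipeline } from './types/chain.js';
import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from './types/params.js';

// stand-in for the API client; only the chain call used above is modelled
interface GridClient {
  chain(model: ModelParams, chain: ChainPipeline): Promise<unknown>;
}

export async function generateGrid(
  client: GridClient,
  model: ModelParams,
  params: Txt2ImgParams,
  upscale?: UpscaleParams,
  highres?: HighresParams,
) {
  const grid: PipelineGrid = {
    enabled: true,
    columns: { parameter: 'cfg', value: '5-8' },  // three columns: cfg 5, 6, 7
    rows: { parameter: 'seed', value: '-1,-1' },  // two rows, each with a random seed
  };

  // one source-txt2img stage per cell, then a blend-grid stage sized by the
  // row/column counts, then persist-disk to save the combined image
  const chain = makeTxt2ImgGridPipeline(grid, model, params, upscale, highres);
  return client.chain(model, chain);
}
```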