Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards

BZLibby 2023-12-06 16:29:45 -06:00
commit 17f28aba62
182 changed files with 8803 additions and 2788 deletions


@@ -17,12 +17,14 @@ with a CPU fallback capable of running on laptop-class machines.
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).

-![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
+![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)

## Features

This is an incomplete list of new and interesting features, with links to the user guide:

+- SDXL support
+- LCM support
- hardware acceleration on both AMD and Nvidia
- tested on CUDA, DirectML, and ROCm
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
@@ -37,6 +39,7 @@ This is an incomplete list of new and interesting features, with links to the user guide:
- [txt2img](docs/user-guide.md#txt2img-tab)
- [img2img](docs/user-guide.md#img2img-tab)
- [inpainting](docs/user-guide.md#inpaint-tab), with mask drawing and upload
+- [panorama](docs/user-guide.md#panorama-pipeline), for both SD v1.5 and SDXL
- [upscaling](docs/user-guide.md#upscale-tab), with ONNX acceleration
- [add and use your own models](docs/user-guide.md#adding-your-own-models)
- [convert models from diffusers and SD checkpoints](docs/converting-models.md)
@@ -45,20 +48,24 @@ This is an incomplete list of new and interesting features, with links to the user guide:
- [permanent and prompt-based blending](docs/user-guide.md#permanently-blending-additional-networks)
- [supports LoRA and LyCORIS weights](docs/user-guide.md#lora-tokens)
- [supports Textual Inversion concepts and embeddings](docs/user-guide.md#textual-inversion-tokens)
+- each layer of the embeddings can be controlled and used individually
- ControlNet
- image filters for edge detection and other methods
- with ONNX acceleration
- highres mode
- runs img2img on the results of the other pipelines
- multiple iterations can produce 8k images and larger
+- [multi-stage](docs/user-guide.md#prompt-stages) and [region prompts](docs/user-guide.md#region-tokens)
+- seamlessly combine multiple prompts in the same image
+- provide prompts for different areas in the image and blend them together
+- change the prompt for highres mode and refine details without recursion
- infinite prompt length
- [with long prompt weighting](docs/user-guide.md#long-prompt-weighting)
-- expand and control Textual Inversions per-layer
- [image blending mode](docs/user-guide.md#blend-tab)
- combine images from history
-- upscaling and face correction
-- upscaling with Real ESRGAN or Stable Diffusion
-- face correction with CodeFormer or GFPGAN
+- upscaling and correction
+- upscaling with Real ESRGAN, SwinIR, and Stable Diffusion
+- face correction with CodeFormer and GFPGAN
- [API server can be run remotely](docs/server-admin.md)
- REST API can be served over HTTPS or HTTP
- background processing for all image pipelines
@@ -66,7 +73,7 @@ This is an incomplete list of new and interesting features, with links to the user guide:
- OCI containers provided
- for all supported hardware accelerators
- includes both the API and GUI bundle in a single container
-- runs well on [RunPod](https://www.runpod.io/) and other GPU container hosting services
+- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services

## Contents

api/.vscode/settings.json (new file)

@@ -0,0 +1,11 @@
{
  "python.testing.unittestArgs": [
    "-v",
    "-s",
    "./tests",
    "-p",
    "test_*.py"
  ],
  "python.testing.pytestEnabled": false,
  "python.testing.unittestEnabled": true
}


@@ -1,4 +1,4 @@
-.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload
+.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload style

onnx_env: ## create virtual env
	python -v venv onnx_env

@@ -18,9 +18,10 @@ pip-dev: check-venv
test:
	python -m coverage erase
-	python -m coverage run -m unittest discover -s tests/
+	python -m coverage run -m unittest discover -v -s tests/
	python -m coverage html -i
	python -m coverage xml -i
+	python -m coverage report -i

package: package-dist package-upload

@@ -32,13 +33,21 @@ package-upload:
lint-check:
	black --check onnx_web/
-	isort --check-only --skip __init__.py --filter-files onnx_web
+	black --check tests/
	flake8 onnx_web
+	flake8 tests
+	isort --check-only --skip __init__.py --filter-files onnx_web
+	isort --check-only --skip __init__.py --filter-files tests

lint-fix:
	black onnx_web/
-	isort --skip __init__.py --filter-files onnx_web
+	black tests/
	flake8 onnx_web
+	flake8 tests
+	isort --skip __init__.py --filter-files onnx_web
+	isort --skip __init__.py --filter-files tests

+style: lint-fix

typecheck:
	mypy onnx_web


@@ -1,45 +1,2 @@
-from .base import ChainPipeline, PipelineStage, StageParams
-from .blend_img2img import BlendImg2ImgStage
-from .blend_linear import BlendLinearStage
-from .blend_mask import BlendMaskStage
-from .correct_codeformer import CorrectCodeformerStage
-from .correct_gfpgan import CorrectGFPGANStage
-from .persist_disk import PersistDiskStage
-from .persist_s3 import PersistS3Stage
-from .reduce_crop import ReduceCropStage
-from .reduce_thumbnail import ReduceThumbnailStage
-from .source_noise import SourceNoiseStage
-from .source_s3 import SourceS3Stage
-from .source_txt2img import SourceTxt2ImgStage
-from .source_url import SourceURLStage
-from .upscale_bsrgan import UpscaleBSRGANStage
-from .upscale_highres import UpscaleHighresStage
-from .upscale_outpaint import UpscaleOutpaintStage
-from .upscale_resrgan import UpscaleRealESRGANStage
-from .upscale_simple import UpscaleSimpleStage
-from .upscale_stable_diffusion import UpscaleStableDiffusionStage
-from .upscale_swinir import UpscaleSwinIRStage
-
-CHAIN_STAGES = {
-    "blend-img2img": BlendImg2ImgStage,
-    "blend-inpaint": UpscaleOutpaintStage,
-    "blend-linear": BlendLinearStage,
-    "blend-mask": BlendMaskStage,
-    "correct-codeformer": CorrectCodeformerStage,
-    "correct-gfpgan": CorrectGFPGANStage,
-    "persist-disk": PersistDiskStage,
-    "persist-s3": PersistS3Stage,
-    "reduce-crop": ReduceCropStage,
-    "reduce-thumbnail": ReduceThumbnailStage,
-    "source-noise": SourceNoiseStage,
-    "source-s3": SourceS3Stage,
-    "source-txt2img": SourceTxt2ImgStage,
-    "source-url": SourceURLStage,
-    "upscale-bsrgan": UpscaleBSRGANStage,
-    "upscale-highres": UpscaleHighresStage,
-    "upscale-outpaint": UpscaleOutpaintStage,
-    "upscale-resrgan": UpscaleRealESRGANStage,
-    "upscale-simple": UpscaleSimpleStage,
-    "upscale-stable-diffusion": UpscaleStableDiffusionStage,
-    "upscale-swinir": UpscaleSwinIRStage,
-}
+from .pipeline import ChainPipeline, PipelineStage, StageParams
+from .stages import * # NOQA


@ -1,240 +1,39 @@
from datetime import timedelta from typing import Optional
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from PIL import Image from PIL import Image
from ..errors import RetryException from ..params import ImageParams, Size, SizeChart, StageParams
from ..output import save_image from ..server.context import ServerContext
from ..params import ImageParams, StageParams from ..worker.context import WorkerContext
from ..server import ServerContext from .result import StageResult
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] class BaseStage:
max_tile = SizeChart.auto
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
"""
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages or [])
def append(self, stage: Optional[PipelineStage]):
"""
Append an additional stage to this pipeline.
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
assemble the stage from loose arguments.
"""
if stage is not None:
self.stages.append(stage)
def run( def run(
self, self,
worker: WorkerContext, _worker: WorkerContext,
server: ServerContext, _server: ServerContext,
params: ImageParams, _stage: StageParams,
sources: List[Image.Image], _params: ImageParams,
callback: Optional[ProgressCallback], _sources: StageResult,
**kwargs *,
) -> List[Image.Image]: stage_source: Optional[Image.Image] = None,
return self( **kwargs,
worker, server, params, sources=sources, callback=callback, **kwargs ) -> StageResult:
) raise NotImplementedError() # noqa
def stage(self, callback: BaseStage, params: StageParams, **kwargs): def steps(
self.stages.append((callback, params, kwargs))
return self
def __call__(
self, self,
worker: WorkerContext, _params: ImageParams,
server: ServerContext, _size: Size,
params: ImageParams, ) -> int:
sources: List[Image.Image], return 1 # noqa
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> List[Image.Image]:
"""
DEPRECATED: use `run` instead
"""
if callback is not None:
callback = ChainProgress.from_progress(callback)
start = monotonic() def outputs(
self,
if len(sources) > 0: _params: ImageParams,
logger.info( sources: int,
"running pipeline on %s source images", ) -> int:
len(sources), return sources
)
else:
sources = [None]
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.debug(
"running stage %s with %s source images, parameters: %s",
name,
len(stage_sources) - stage_sources.count(None),
kwargs.keys(),
)
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=source,
)
for source in stage_sources
]
)
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
if must_tile:
stage_outputs = []
for source in stage_sources:
logger.info(
"image larger than tile size of %s, tiling stage",
tile,
)
def stage_tile(
source_tile: Image.Image,
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> Image.Image:
for i in range(worker.retries):
try:
output_tile = stage_pipe.run(
worker,
server,
stage_params,
params,
[source_tile],
tile_mask=tile_mask,
callback=callback,
dims=dims,
**kwargs,
)[0]
if is_debug():
save_image(server, "last-tile.png", output_tile)
return output_tile
except Exception:
logger.exception(
"error while running stage pipeline for tile, retry %s of 3",
i,
)
server.cache.clear()
run_gc([worker.get_device()])
worker.retries = worker.retries - (i + 1)
raise RetryException("exhausted retries on tile")
output = process_tile_order(
stage_params.tile_order,
source,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_outputs.append(output)
stage_sources = stage_outputs
else:
logger.debug("image within tile size of %s, running stage", tile)
for i in range(worker.retries):
try:
stage_outputs = stage_pipe.run(
worker,
server,
stage_params,
params,
stage_sources,
callback=callback,
**kwargs,
)
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws
stage_sources = stage_outputs
break
except Exception:
logger.exception(
"error while running stage pipeline, retry %s of 3", i
)
server.cache.clear()
run_gc([worker.get_device()])
worker.retries = worker.retries - (i + 1)
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(
"finished stage %s with %s results",
name,
len(stage_sources),
)
if is_debug():
save_image(server, "last-stage.png", stage_sources[0])
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with %s results",
duration,
len(stage_sources),
)
return stage_sources


@@ -0,0 +1,40 @@
from logging import getLogger
from typing import Optional

import cv2
from PIL import Image

from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class BlendDenoiseStage(BaseStage):
    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        strength: int = 3,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        logger.info("denoising source images")

        results = []
        for source in sources.as_numpy():
            data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
            data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
            results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))

        return StageResult(arrays=results)
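The new denoise stage is plain OpenCV under the hood. A minimal standalone sketch of the same RGB/BGR round-trip, using a synthetic noisy array in place of a real stage source; the strength value of 3 matches the stage default above:

```python
# Standalone sketch of the denoise round-trip BlendDenoiseStage performs,
# run on a synthetic noisy RGB array instead of a StageResult source.
import cv2
import numpy as np

rng = np.random.default_rng(0)
noisy_rgb = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

bgr = cv2.cvtColor(noisy_rgb, cv2.COLOR_RGB2BGR)             # OpenCV works in BGR
denoised = cv2.fastNlMeansDenoisingColored(bgr, None, 3, 3)  # strength = 3, the stage default
clean_rgb = cv2.cvtColor(denoised, cv2.COLOR_BGR2RGB)

print(clean_rgb.shape, clean_rgb.dtype)  # (64, 64, 3) uint8
```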


@ -0,0 +1,62 @@
from logging import getLogger
from typing import List, Optional
from PIL import Image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class BlendGridStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
*,
height: int,
width: int,
# rows: Optional[List[str]] = None,
# columns: Optional[List[str]] = None,
# title: Optional[str] = None,
order: Optional[List[int]] = None,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> StageResult:
logger.info("combining source images using grid layout")
images = sources.as_image()
ref_image = images[0]
size = Size(*ref_image.size)
output = Image.new(ref_image.mode, (size.width * width, size.height * height))
# TODO: labels
if order is None:
order = range(len(images))
for i in range(len(order)):
x = i % width
y = i // width
n = order[i]
output.paste(images[n], (x * size.width, y * size.height))
return StageResult(images=[*images, output])
def outputs(
self,
_params: ImageParams,
sources: int,
) -> int:
return sources + 1


@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch
@ -10,13 +10,14 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class BlendImg2ImgStage(BaseStage): class BlendImg2ImgStage(BaseStage):
max_tile = SizeChart.unlimited max_tile = SizeChart.max
def run( def run(
self, self,
@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
strength: float, strength: float,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
prompt_index: Optional[int] = None, prompt_index: Optional[int] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
# multi-stage prompting # multi-stage prompting
@ -52,7 +53,7 @@ class BlendImg2ImgStage(BaseStage):
params, params,
pipe_type, pipe_type,
worker.get_device(), worker.get_device(),
inversions=inversions, embeddings=inversions,
loras=loras, loras=loras,
) )
@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
pipe_params["strength"] = strength pipe_params["strength"] = strength
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
if params.is_lpw(): if params.is_lpw():
logger.debug("using LPW pipeline for img2img") logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
@ -81,11 +82,10 @@ class BlendImg2ImgStage(BaseStage):
) )
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
if not params.is_xl(): if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds) pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -102,4 +102,18 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images) outputs.extend(result.images)
return outputs return StageResult(images=outputs)
def steps(
self,
params: ImageParams,
*args,
) -> int:
return params.steps # TODO: multiply by strength
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1


@ -1,12 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,13 +19,18 @@ class BlendLinearStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
alpha: float, alpha: float,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, _callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("blending source images using linear interpolation") logger.info("blending source images using linear interpolation")
return [Image.blend(source, stage_source, alpha) for source in sources] return StageResult(
images=[
Image.blend(source, stage_source, alpha)
for source in sources.as_image()
]
)


@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
@ -8,7 +8,8 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,16 +21,17 @@ class BlendMaskStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, _callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("blending image using mask") logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black") # TODO: does this need an alpha channel?
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask) mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L") mult_mask = mult_mask.convert("L")
@ -37,4 +39,9 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
return [Image.composite(stage_source, source, mult_mask) for source in sources] return StageResult(
images=[
Image.composite(stage_source, source, mult_mask)
for source in sources.as_image()
]
)


@ -1,12 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
# must be within the load function for patch to take effect # must be within the load function for patch to take effect
# TODO: rewrite and remove # TODO: rewrite and remove
from codeformer import CodeFormer from codeformer import CodeFormer
@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
device = worker.get_device() device = worker.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return [pipe(source) for source in sources] return StageResult(images=[pipe(source) for source in sources.as_image()])


@ -1,15 +1,15 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np
from PIL import Image from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -57,12 +57,12 @@ class CorrectGFPGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.correction_model is None: if upscale.correction_model is None:
@ -73,16 +73,15 @@ class CorrectGFPGANStage(BaseStage):
device = worker.get_device() device = worker.get_device()
gfpgan = self.load(server, stage, upscale, device) gfpgan = self.load(server, stage, upscale, device)
outputs = [] outputs = [
for source in sources: gfpgan.enhance(
output = np.array(source) source,
_, _, output = gfpgan.enhance(
output,
has_aligned=False, has_aligned=False,
only_center_face=False, only_center_face=False,
paste_back=True, paste_back=True,
weight=upscale.face_strength, weight=upscale.face_strength,
) )
outputs.append(Image.fromarray(output, "RGB")) for source in sources.as_numpy()
]
return outputs return StageResult(images=outputs)


@ -1,11 +1,11 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import Optional
from ..chain.base import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from .pipeline import ChainPipeline
logger = getLogger(__name__) logger = getLogger(__name__)
@ -43,7 +43,7 @@ def stage_highres(
outscale=highres.scale, outscale=highres.scale,
), ),
chain=chain, chain=chain,
overlap=params.overlap, overlap=params.vae_overlap,
) )
else: else:
logger.debug("using simple upscaling for highres") logger.debug("using simple upscaling for highres")
@ -51,14 +51,14 @@ def stage_highres(
UpscaleSimpleStage(), UpscaleSimpleStage(),
stage, stage,
method=highres.method, method=highres.method,
overlap=params.overlap, overlap=params.vae_overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale), upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
) )
chain.stage( chain.stage(
BlendImg2ImgStage(), BlendImg2ImgStage(),
stage, stage.with_args(outscale=1),
overlap=params.overlap, overlap=params.vae_overlap,
prompt_index=prompt_index + i, prompt_index=prompt_index + i,
strength=highres.strength, strength=highres.strength,
) )


@ -1,33 +1,38 @@
from logging import getLogger from logging import getLogger
from typing import List from typing import List, Optional
from PIL import Image from PIL import Image
from ..output import save_image from ..output import save_image
from ..params import ImageParams, StageParams from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class PersistDiskStage(BaseStage): class PersistDiskStage(BaseStage):
max_tile = SizeChart.max
def run( def run(
self, self,
_worker: WorkerContext, _worker: WorkerContext,
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
output: str, output: List[str],
stage_source: Image.Image, size: Optional[Size] = None,
stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
for source in sources: logger.info("persisting %s images to disk: %s", len(sources), output)
# TODO: append index to output name
dest = save_image(server, output, source, params=params) for source, name in zip(sources.as_image(), output):
dest = save_image(server, name, source, params=params, size=size)
logger.info("saved image to %s", dest) logger.info("saved image to %s", dest)
return sources return sources


@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,26 +21,26 @@ class PersistS3Stage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
output: str, output: List[str],
bucket: str, bucket: str,
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
for source in sources: for source, name in zip(sources.as_image(), output):
data = BytesIO() data = BytesIO()
source.save(data, format=server.image_format) source.save(data, format=server.image_format)
data.seek(0) data.seek(0)
try: try:
s3.upload_fileobj(data, bucket, output) s3.upload_fileobj(data, bucket, name)
logger.info("saved image to s3://%s/%s", bucket, output) logger.info("saved image to s3://%s/%s", bucket, name)
except Exception: except Exception:
logger.exception("error saving image to S3") logger.exception("error saving image to S3")


@ -0,0 +1,281 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from PIL import Image
from ..errors import CancelledException, RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
"""
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages or [])
def append(self, stage: Optional[PipelineStage]):
"""
Append an additional stage to this pipeline.
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
assemble the stage from loose arguments.
"""
if stage is not None:
self.stages.append(stage)
def run(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: StageResult,
callback: Optional[ProgressCallback],
**kwargs,
) -> List[Image.Image]:
result = self(
worker, server, params, sources=sources, callback=callback, **kwargs
)
return result.as_image()
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def steps(self, params: ImageParams, size: Size) -> int:
steps = 0
for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size)
return steps
def outputs(self, params: ImageParams, sources: int) -> int:
outputs = sources
for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs
def __call__(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: StageResult,
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs,
) -> StageResult:
"""
DEPRECATED: use `.run()` instead
"""
if callback is None:
callback = worker.get_progress_callback()
else:
callback = ChainProgress.from_progress(callback)
start = monotonic()
if len(sources) > 0:
logger.info(
"running pipeline on %s source images",
len(sources),
)
else:
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.debug(
"running stage %s with %s source images, parameters: %s",
name,
len(stage_sources),
kwargs.keys(),
)
per_stage_params = params
if "params" in kwargs:
per_stage_params = kwargs["params"]
kwargs.pop("params")
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = has_mask(stage_kwargs) or any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=source,
)
for source in stage_sources.as_image()
]
)
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
if must_tile:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
tile,
)
def stage_tile(
source_tile: List[Image.Image],
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> List[Image.Image]:
for _i in range(worker.retries):
try:
tile_result = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
StageResult(images=source_tile),
tile_mask=tile_mask,
callback=callback,
dims=dims,
**kwargs,
)
if is_debug():
for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
return tile_result
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled while tiling")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline for tile, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
raise RetryException("exhausted retries on tile")
stage_results = process_tile_order(
stage_params.tile_order,
stage_sources,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_sources = StageResult(images=stage_results)
else:
logger.debug(
"image does not contain sources and is within tile size of %s, running stage",
tile,
)
for _i in range(worker.retries):
try:
stage_result = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
stage_sources,
callback=callback,
dims=(0, 0, tile),
**kwargs,
)
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws
stage_sources = stage_result
break
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled during stage")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(
"finished stage %s with %s results",
name,
len(stage_sources),
)
if is_debug():
for j, image in enumerate(stage_sources.as_image()):
save_image(server, f"last-stage-{j}.png", image)
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with %s results",
duration,
len(stage_sources),
)
return stage_sources
MASK_KEYS = ["mask", "stage_mask", "tile_mask"]
def has_mask(args: List[str]) -> bool:
return any([key in args for key in MASK_KEYS])
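`ChainProgress` wraps a caller-supplied progress callback and keeps counting when a stage resets its step counter back to zero. A small sketch of that accumulation, assuming the `onnx_web.chain.pipeline` module path introduced by this commit:

```python
# Sketch of ChainProgress accumulating steps across stages that each restart
# at step 0; the import path is assumed from the new file layout in this commit.
from onnx_web.chain.pipeline import ChainProgress


def report(step: int, timestep: int, latents) -> None:
    print("steps so far:", step)


progress = ChainProgress(report)
for stage_steps in (5, 3):        # two stages, each calling back with step 0..n-1
    for step in range(stage_steps):
        progress(step, 0, None)

print(progress.get_total())       # 6: max step of stage one (4) + max step of stage two (2)
```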


@ -1,12 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
origin: Size, origin: Size,
size: Size, size: Size,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
image = source.crop((origin.width, origin.height, size.width, size.height)) image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info( logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height "created thumbnail with dimensions: %sx%s", image.width, image.height
) )
outputs.append(image) outputs.append(image)
return outputs return StageResult(images=outputs)


@ -1,12 +1,12 @@
from logging import getLogger from logging import getLogger
from typing import List
from PIL import Image from PIL import Image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,15 +18,15 @@ class ReduceThumbnailStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
size: Size, size: Size,
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
image = source.copy() image = source.copy()
image = image.thumbnail((size.width, size.height)) image = image.thumbnail((size.width, size.height))
@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):
outputs.append(image) outputs.append(image)
return outputs return StageResult(images=outputs)


@@ -0,0 +1,73 @@
from typing import List, Optional

import numpy as np
from PIL import Image


class StageResult:
    """
    Chain pipeline stage result.
    Can contain PIL images or numpy arrays, with helpers to convert between them.
    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they are expected.
    """

    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]

    @staticmethod
    def empty():
        return StageResult(images=[])

    @staticmethod
    def from_arrays(arrays: List[np.ndarray]):
        return StageResult(arrays=arrays)

    @staticmethod
    def from_images(images: List[Image.Image]):
        return StageResult(images=images)

    def __init__(self, arrays=None, images=None) -> None:
        if arrays is not None and images is not None:
            raise ValueError("stages must only return one type of result")
        elif arrays is None and images is None:
            raise ValueError("stages must return results")

        self.arrays = arrays
        self.images = images

    def __len__(self) -> int:
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            return 0

    def as_numpy(self) -> List[np.ndarray]:
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_image(self) -> List[Image.Image]:
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []


def shape_mode(arr: np.ndarray) -> str:
    if len(arr.shape) != 3:
        raise ValueError("unknown array format")

    if arr.shape[-1] == 3:
        return "RGB"
    elif arr.shape[-1] == 4:
        return "RGBA"

    raise ValueError("unknown image format")
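A short usage sketch of the new `StageResult` container, converting between its PIL and numpy forms; the `onnx_web.chain.result` import path is assumed from the file layout in this commit:

```python
# Round-trip between the two StageResult forms shown above.
import numpy as np
from PIL import Image

from onnx_web.chain.result import StageResult

# start from PIL images, as most blend stages do
images = [Image.new("RGB", (64, 64), color="white")]
result = StageResult.from_images(images)

arrays = result.as_numpy()        # converted lazily with np.array(image)
print(arrays[0].shape)            # (64, 64, 3)

# start from numpy arrays, as BlendDenoiseStage does
result = StageResult.from_arrays(arrays)
print(len(result), result.as_image()[0].mode)  # 1 RGB
```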


@ -1,12 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import Callable, List from typing import Callable, Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, Size, StageParams from ..params import ImageParams, Size, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,25 +19,34 @@ class SourceNoiseStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
stage_source: Image.Image, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("generating image from noise source") logger.info("generating image from noise source")
if len(sources) > 0: if len(sources) > 0:
logger.warning( logger.info(
"source images were passed to a noise stage and will be discarded" "source images were passed to a source stage, new images will be appended"
) )
outputs = [] outputs = []
for source in sources:
# TODO: looping over sources and ignoring params does not make much sense for a source stage
for source in sources.as_image():
output = noise_source(source, (size.width, size.height), (0, 0)) output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1


@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,18 +21,23 @@ class SourceS3Stage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
_sources: List[Image.Image], sources: StageResult,
*, *,
source_keys: List[str], source_keys: List[str],
bucket: str, bucket: str,
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
outputs = [] if len(sources) > 0:
logger.info(
"source images were passed to a source stage, new images will be appended"
)
outputs = sources.as_image()
for key in source_keys: for key in source_keys:
try: try:
logger.info("loading image from s3://%s/%s", bucket, key) logger.info("loading image from s3://%s/%s", bucket, key)
@ -43,4 +49,11 @@ class SourceS3Stage(BaseStage):
except Exception: except Exception:
logger.exception("error loading image from S3") logger.exception("error loading image from S3")
return outputs return StageResult(outputs)
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1 # TODO: len(source_keys)


@ -3,26 +3,28 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image
from ..constants import LATENT_FACTOR
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
from ..diffusers.utils import ( from ..diffusers.utils import (
encode_prompt, encode_prompt,
get_latents_from_seed, get_latents_from_seed,
get_tile_latents, get_tile_latents,
parse_prompt, parse_prompt,
parse_reseed,
slice_prompt, slice_prompt,
) )
from ..params import ImageParams, Size, SizeChart, StageParams from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class SourceTxt2ImgStage(BaseStage): class SourceTxt2ImgStage(BaseStage):
max_tile = SizeChart.unlimited max_tile = SizeChart.max
def run( def run(
self, self,
@ -30,15 +32,15 @@ class SourceTxt2ImgStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
_source: Image.Image, sources: StageResult,
*, *,
dims: Tuple[int, int, int], dims: Tuple[int, int, int] = None,
size: Size, size: Size,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
latents: Optional[np.ndarray] = None, latents: Optional[np.ndarray] = None,
prompt_index: Optional[int] = None, prompt_index: Optional[int] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> StageResult:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
size = size.with_args(**kwargs) size = size.with_args(**kwargs)
@ -47,31 +49,58 @@ class SourceTxt2ImgStage(BaseStage):
params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index)) params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))
logger.info( logger.info(
"generating image using txt2img, %s steps: %s", params.steps, params.prompt "generating image using txt2img, %s steps of %s: %s",
params.steps,
params.model,
params.prompt,
) )
if "stage_source" in kwargs: if len(sources):
logger.warning( logger.info(
"a source image was passed to a txt2img stage, and will be discarded" "source images were passed to a source stage, new images will be appended"
) )
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params params
) )
if params.is_xl(): if params.is_panorama() or params.is_xl():
tile_size = max(stage.tile_size, params.tiles) tile_size = max(stage.tile_size, params.unet_tile)
else: else:
tile_size = params.tiles tile_size = params.unet_tile
# this works for panorama as well, because tile_size is already max(tile_size, *size) # this works for panorama as well, because tile_size is already max(tile_size, *size)
latent_size = size.min(tile_size, tile_size) latent_size = size.min(tile_size, tile_size)
# generate new latents or slice existing # generate new latents or slice existing
if latents is None: if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch) latents = get_latents_from_seed(int(params.seed), latent_size, params.batch)
else: else:
latents = get_tile_latents(latents, params.seed, latent_size, dims) latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
# reseed latents as needed
reseed_rng = np.random.RandomState(params.seed)
prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed:
if region_seed == -1:
region_seed = reseed_rng.random_integers(2**32 - 1)
logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
top,
left,
bottom,
right,
region_seed,
)
latents[
:,
:,
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
left // LATENT_FACTOR : right // LATENT_FACTOR,
] = get_latents_from_seed(
region_seed, Size(right - left, bottom - top), params.batch
)
pipe_type = params.get_valid_pipeline("txt2img") pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline( pipe = load_pipeline(
@ -79,7 +108,7 @@ class SourceTxt2ImgStage(BaseStage):
params, params,
pipe_type, pipe_type,
worker.get_device(), worker.get_device(),
inversions=inversions, embeddings=inversions,
loras=loras, loras=loras,
) )
@ -101,11 +130,14 @@ class SourceTxt2ImgStage(BaseStage):
) )
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt( if params.is_panorama() or params.is_xl():
pipe, prompt_pairs, params.batch, params.do_cfg() logger.debug(
) "prompt alternatives are not supported for panorama or SDXL"
)
if not params.is_xl(): else:
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds) pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@ -123,4 +155,21 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback, callback=callback,
) )
return result.images outputs = sources.as_image()
outputs.extend(result.images)
logger.debug("produced %s outputs", len(outputs))
return StageResult(images=outputs)
def steps(
self,
params: ImageParams,
size: Size,
) -> int:
return params.steps
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
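The reseed block above maps pixel-space region tokens onto latent indices. A small sketch of that mapping, assuming `LATENT_FACTOR` is the usual 8x VAE downscale (the constant's value is not shown in this diff):

```python
# Sketch of the pixel-to-latent slicing used when reseeding a region of the
# txt2img latents; LATENT_FACTOR = 8 is an assumption here.
import numpy as np

LATENT_FACTOR = 8
batch, channels = 1, 4
latents = np.zeros((batch, channels, 512 // LATENT_FACTOR, 512 // LATENT_FACTOR), dtype=np.float32)

top, left, bottom, right = 128, 128, 256, 384   # a pixel-space region from the prompt
region = latents[
    :,
    :,
    top // LATENT_FACTOR : bottom // LATENT_FACTOR,
    left // LATENT_FACTOR : right // LATENT_FACTOR,
]
print(region.shape)  # (1, 4, 16, 32): a 128x256 pixel region becomes 16x32 latent cells
```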


@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from typing import List from typing import List, Optional
import requests import requests
from PIL import Image from PIL import Image
@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,20 +21,20 @@ class SourceURLStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
source_urls: List[str], source_urls: List[str],
stage_source: Image.Image, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("loading image from URL source") logger.info("loading image from URL source")
if len(sources) > 0: if len(sources) > 0:
logger.warning( logger.info(
"a source image was passed to a source stage, and will be discarded" "source images were passed to a source stage, new images will be appended"
) )
outputs = [] outputs = sources.as_image()
for url in source_urls: for url in source_urls:
response = requests.get(url) response = requests.get(url)
output = Image.open(BytesIO(response.content)) output = Image.open(BytesIO(response.content))
@ -41,4 +42,11 @@ class SourceURLStage(BaseStage):
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1


@ -1,31 +0,0 @@
from typing import List, Optional
from PIL import Image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
raise NotImplementedError()
def steps(
self,
_params: ImageParams,
size: Size,
) -> int:
raise NotImplementedError()


@ -0,0 +1,64 @@
from logging import getLogger
from .base import BaseStage
from .blend_denoise import BlendDenoiseStage
from .blend_grid import BlendGridStage
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
logger = getLogger(__name__)
CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
def add_stage(name: str, stage: BaseStage) -> bool:
global CHAIN_STAGES
if name in CHAIN_STAGES:
logger.warning("cannot replace stage: %s", name)
return False
else:
CHAIN_STAGES[name] = stage
return True
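A minimal sketch of registering an extra stage through the new add_stage() helper; the import path for this registry module is assumed, since the diff does not show the file name.

from onnx_web.chain import CHAIN_STAGES, add_stage   # module path assumed
from onnx_web.chain.base import BaseStage


class NoopStage(BaseStage):
    """Hypothetical stage that passes its inputs through unchanged."""

    def run(self, worker, server, stage, params, sources, **kwargs):
        return sources


# add_stage refuses to overwrite an existing name and returns False instead
assert add_stage("blend-noop", NoopStage)
assert not add_stage("blend-noop", NoopStage)
print(sorted(CHAIN_STAGES.keys()))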
View File
@ -2,13 +2,14 @@ import itertools
from enum import Enum from enum import Enum
from logging import getLogger from logging import getLogger
from math import ceil from math import ceil
from typing import List, Optional, Protocol, Tuple from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ..image.noise_source import noise_source_histogram from ..image.noise_source import noise_source_histogram
from ..params import Size, TileOrder from ..params import Size, TileOrder
from .result import StageResult
# from skimage.exposure import match_histograms # from skimage.exposure import match_histograms
@ -16,12 +17,15 @@ from ..params import Size, TileOrder
logger = getLogger(__name__) logger = getLogger(__name__)
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
class TileCallback(Protocol): class TileCallback(Protocol):
""" """
Definition for a tile job function. Definition for a tile job function.
""" """
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image: def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
""" """
Run this stage against a single tile. Run this stage against a single tile.
""" """
@ -32,6 +36,9 @@ def complete_tile(
source: Image.Image, source: Image.Image,
tile: int, tile: int,
) -> Image.Image: ) -> Image.Image:
"""
TODO: clean up
"""
if source is None: if source is None:
return source return source
@ -50,6 +57,12 @@ def needs_tile(
source: Optional[Image.Image] = None, source: Optional[Image.Image] = None,
) -> bool: ) -> bool:
tile = min(max_tile, stage_tile) tile = min(max_tile, stage_tile)
logger.trace(
"checking image tile dimensions: %s, %s, %s",
tile,
source.width > tile or source.height > tile if source is not None else False,
size.width > tile or size.height > tile if size is not None else False,
)
if source is not None: if source is not None:
return source.width > tile or source.height > tile return source.width > tile or source.height > tile
@ -60,7 +73,7 @@ def needs_tile(
return False return False
def get_tile_grads( def make_tile_grads(
left: int, left: int,
top: int, top: int,
tile: int, tile: int,
@ -85,6 +98,60 @@ def get_tile_grads(
return (grad_x, grad_y) return (grad_x, grad_y)
def make_tile_mask(
shape: Any,
tile: Tuple[int, int],
overlap: float,
edges: Tuple[bool, bool, bool, bool],
) -> np.ndarray:
mask = np.ones(shape)
tile_h, tile_w = tile
adj_tile_h = int(float(tile_h) * (1.0 - overlap))
adj_tile_w = int(float(tile_w) * (1.0 - overlap))
# sort gradient points
p1_h = adj_tile_h - 1
p2_h = tile_h - adj_tile_h
points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]
p1_w = adj_tile_w - 1
p2_w = tile_w - adj_tile_w
points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]
# build gradients
edge_t, edge_l, edge_b, edge_r = edges
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
int(not edge_t),
1,
1,
int(not edge_b),
]
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)]
mask = ((mask * mult_x).T * mult_y).T
return mask
def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
if isinstance(image, np.ndarray):
return image.shape[-1]
if image.mode == "RGBA":
return 4
elif image.mode == "RGB":
return 3
elif image.mode == "L":
return 1
raise ValueError("unknown image format")
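A quick sanity check of the two helpers added above. The module path (onnx_web.chain.tile) is an assumption, and the exact feathering profile depends on the edge flags, so only shapes and value ranges are shown here.

import numpy as np
from PIL import Image

from onnx_web.chain.tile import get_channels, make_tile_mask   # path assumed

print(get_channels(Image.new("RGBA", (8, 8))))   # 4
print(get_channels(Image.new("L", (8, 8))))      # 1
print(get_channels(np.zeros((8, 8, 3))))         # 3

# a 256x256 blending mask for a tile with 25% overlap; the four edge flags
# control which sides of the tile are feathered
mask = make_tile_mask((256, 256), (256, 256), 0.25, (True, True, False, False))
print(mask.shape, float(mask.min()), float(mask.max()))   # (256, 256), values in [0, 1]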
def blend_tiles( def blend_tiles(
tiles: List[Tuple[int, int, Image.Image]], tiles: List[Tuple[int, int, Image.Image]],
scale: int, scale: int,
@ -98,23 +165,24 @@ def blend_tiles(
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
) )
scaled_size = (height * scale, width * scale, 3) channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
scaled_size = (height * scale, width * scale, channels)
count = np.zeros(scaled_size) count = np.zeros(scaled_size)
value = np.zeros(scaled_size) value = np.zeros(scaled_size)
for left, top, tile_image in tiles: for left, top, tile_image in tiles:
# histogram equalization
equalized = np.array(tile_image).astype(np.float32) equalized = np.array(tile_image).astype(np.float32)
mask = np.ones_like(equalized[:, :, 0]) mask = np.ones_like(equalized[:, :, 0])
if adj_tile < tile: if adj_tile < tile:
# sort gradient points # sort gradient points
p1 = adj_tile * scale p1 = (adj_tile * scale) - 1
p2 = (tile - adj_tile) * scale p2 = (tile - adj_tile - 1) * scale
points = [0, min(p1, p2), max(p1, p2), tile * scale] points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
# gradient blending # gradient blending
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height) grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y) logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)] mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
@ -169,7 +237,7 @@ def blend_tiles(
margin_left : equalized.shape[1] + margin_right, margin_left : equalized.shape[1] + margin_right,
np.newaxis, np.newaxis,
], ],
3, channels,
axis=2, axis=2,
) )
@ -178,60 +246,18 @@ def blend_tiles(
return Image.fromarray(np.uint8(pixels)) return Image.fromarray(np.uint8(pixels))
def process_tile_grid( def process_tile_stack(
source: Image.Image, stack: StageResult,
tile: int,
scale: int,
filters: List[TileCallback],
overlap: float = 0.0,
**kwargs,
) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None)
adj_tile = int(float(tile) * (1.0 - overlap))
tiles_x = ceil(width / adj_tile)
tiles_y = ceil(height / adj_tile)
total = tiles_x * tiles_y
logger.debug(
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
total,
tiles_x,
tiles_y,
adj_tile,
overlap,
)
tiles: List[Tuple[int, int, Image.Image]] = []
for y in range(tiles_y):
for x in range(tiles_x):
idx = (y * tiles_x) + x
left = x * adj_tile
top = y * adj_tile
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
tile_image = (
source.crop((left, top, left + tile, top + tile)) if source else None
)
tile_image = complete_tile(tile_image, tile)
for filter in filters:
tile_image = filter(tile_image, (left, top, tile))
tiles.append((left, top, tile_image))
return blend_tiles(tiles, scale, width, height, tile, overlap)
def process_tile_spiral(
source: Image.Image,
tile: int, tile: int,
scale: int, scale: int,
filters: List[TileCallback], filters: List[TileCallback],
tile_generator: TileGenerator,
overlap: float = 0.5, overlap: float = 0.5,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
width, height = kwargs.get("size", source.size if source else None) sources = stack.as_image()
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
mask = kwargs.get("mask", None) mask = kwargs.get("mask", None)
noise_source = kwargs.get("noise_source", noise_source_histogram) noise_source = kwargs.get("noise_source", noise_source_histogram)
fill_color = kwargs.get("fill_color", None) fill_color = kwargs.get("fill_color", None)
@ -239,18 +265,10 @@ def process_tile_spiral(
tile_mask = None tile_mask = None
tiles: List[Tuple[int, int, Image.Image]] = [] tiles: List[Tuple[int, int, Image.Image]] = []
tile_coords = tile_generator(width, height, tile, overlap)
single_tile = len(tile_coords) == 1
# tile tuples is source, multiply by scale for dest for counter, (left, top) in enumerate(tile_coords):
counter = 0
tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
if len(tile_coords) == 1:
single_tile = True
else:
single_tile = False
for left, top in tile_coords:
counter += 1
logger.info( logger.info(
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
) )
@ -274,26 +292,36 @@ def process_tile_spiral(
needs_margin = True needs_margin = True
bottom_margin = height - bottom bottom_margin = height - bottom
# if no source given, we don't have a source image if single_tile:
if not source: logger.debug("using single tile")
tile_image = None tile_stack = sources
if mask:
tile_mask = mask
elif needs_margin: elif needs_margin:
# in the special case where the image is smaller than the specified tile size, just use the image logger.debug(
if single_tile: "tiling with added margins: %s, %s, %s, %s",
logger.debug("creating and processing single-tile subtile") left_margin,
tile_image = source top_margin,
if mask: right_margin,
tile_mask = mask bottom_margin,
# otherwise use add histogram noise outside of the image border )
else: tile_stack = add_margin(
logger.debug( stack.as_image(),
"tiling and adding margins: %s, %s, %s, %s", left,
left_margin, top,
top_margin, right,
right_margin, bottom,
bottom_margin, left_margin,
) top_margin,
base_image = source.crop( right_margin,
bottom_margin,
tile,
noise_source,
fill_color,
)
if mask:
base_mask = mask.crop(
( (
left + left_margin, left + left_margin,
top + top_margin, top + top_margin,
@ -301,57 +329,60 @@ def process_tile_spiral(
bottom + bottom_margin, bottom + bottom_margin,
) )
) )
tile_image = noise_source( tile_mask = Image.new("L", (tile, tile), color=0)
base_image, (tile, tile), (0, 0), fill=fill_color tile_mask.paste(base_mask, (left_margin, top_margin))
)
tile_image.paste(base_image, (left_margin, top_margin))
if mask:
base_mask = mask.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
)
)
tile_mask = Image.new("L", (tile, tile), color=0)
tile_mask.paste(base_mask, (left_margin, top_margin))
else: else:
logger.debug("tiling normally") logger.debug("tiling normally")
tile_image = source.crop((left, top, right, bottom)) tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
if mask: if mask:
tile_mask = mask.crop((left, top, right, bottom)) tile_mask = mask.crop((left, top, right, bottom))
for image_filter in filters: for image_filter in filters:
tile_image = image_filter(tile_image, tile_mask, (left, top, tile)) tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
tiles.append((left, top, tile_image)) if isinstance(tile_stack, list):
tile_stack = StageResult.from_images(tile_stack)
if single_tile: tiles.append((left, top, tile_stack.as_image()))
return tile_image
else: lefts, tops, stacks = list(zip(*tiles))
return blend_tiles(tiles, scale, width, height, tile, overlap) coords = list(zip(lefts, tops))
stacks = list(zip(*stacks))
result = []
for stack in stacks:
stack_tiles = zip(coords, stack)
stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
return result
def process_tile_order( def process_tile_order(
order: TileOrder, order: TileOrder,
source: Image.Image, stack: StageResult,
tile: int, tile: int,
scale: int, scale: int,
filters: List[TileCallback], filters: List[TileCallback],
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
"""
TODO: needs to handle more than one image
"""
if order == TileOrder.grid: if order == TileOrder.grid:
logger.debug("using grid tile order with tile size: %s", tile) logger.debug("using grid tile order with tile size: %s", tile)
return process_tile_grid(source, tile, scale, filters, **kwargs) return process_tile_stack(
stack, tile, scale, filters, generate_tile_grid, **kwargs
)
elif order == TileOrder.kernel: elif order == TileOrder.kernel:
logger.debug("using kernel tile order with tile size: %s", tile) logger.debug("using kernel tile order with tile size: %s", tile)
raise NotImplementedError() raise NotImplementedError()
elif order == TileOrder.spiral: elif order == TileOrder.spiral:
logger.debug("using spiral tile order with tile size: %s", tile) logger.debug("using spiral tile order with tile size: %s", tile)
return process_tile_spiral(source, tile, scale, filters, **kwargs) return process_tile_stack(
stack, tile, scale, filters, generate_tile_spiral, **kwargs
)
else: else:
logger.warning("unknown tile order: %s", order) logger.warning("unknown tile order: %s", order)
raise ValueError() raise ValueError()
@ -445,3 +476,77 @@ def generate_tile_spiral(
height_tile_target -= abs(state.value[1]) height_tile_target -= abs(state.value[1])
return tile_coords return tile_coords
def generate_tile_grid(
width: int,
height: int,
tile: int,
overlap: float = 0.0,
) -> List[Tuple[int, int]]:
adj_tile = int(float(tile) * (1.0 - overlap))
tiles_x = ceil(width / adj_tile)
tiles_y = ceil(height / adj_tile)
total = tiles_x * tiles_y
logger.debug(
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
total,
tiles_x,
tiles_y,
adj_tile,
overlap,
)
tiles: List[Tuple[int, int, Image.Image]] = []
for y in range(tiles_y):
for x in range(tiles_x):
left = x * adj_tile
top = y * adj_tile
tiles.append((int(left), int(top)))
return tiles
def get_result_tile(
result: StageResult,
origin: Tuple[int, int],
tile: Size,
) -> List[Image.Image]:
top, left = origin
return [
layer.crop((top, left, top + tile.height, left + tile.width))
for layer in result.as_image()
]
def add_margin(
stack: List[Image.Image],
left: int,
top: int,
right: int,
bottom: int,
left_margin: int,
top_margin: int,
right_margin: int,
bottom_margin: int,
tile: int,
noise_source,
fill_color,
) -> List[Image.Image]:
results = []
for source in stack:
base_image = source.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
)
)
tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
tile_image.paste(base_image, (left_margin, top_margin))
results.append(tile_image)
return results
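The tile generators are now plain functions that return coordinates, which makes them easy to inspect on their own. A short sketch, again assuming the module is importable as onnx_web.chain.tile:

from onnx_web.chain.tile import generate_tile_grid, generate_tile_spiral   # path assumed

# a 1024x1024 canvas covered by 512px tiles with no overlap: a 2x2 grid
print(generate_tile_grid(1024, 1024, 512))
# [(0, 0), (512, 0), (0, 512), (512, 512)]

# with 50% overlap the effective stride drops to 256px, so 4x4 = 16 tiles
print(len(generate_tile_grid(1024, 1024, 512, overlap=0.5)))   # 16

# the spiral generator covers the same area from the outside in
print(len(generate_tile_spiral(1024, 1024, 512, overlap=0.5)))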
View File
@ -1,22 +1,30 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ..models.onnx import OnnxModel from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams from ..params import (
DeviceParams,
ImageParams,
Size,
SizeChart,
StageParams,
UpscaleParams,
)
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class UpscaleBSRGANStage(BaseStage): class UpscaleBSRGANStage(BaseStage):
max_tile = 64 max_tile = SizeChart.micro
def load( def load(
self, self,
@ -54,12 +62,12 @@ class UpscaleBSRGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None: if upscale.upscale_model is None:
@ -71,40 +79,38 @@ class UpscaleBSRGANStage(BaseStage):
bsrgan = self.load(server, stage, upscale, device) bsrgan = self.load(server, stage, upscale, device)
outputs = [] outputs = []
for source in sources: for source in sources.as_numpy():
image = np.array(source) / 255.0 image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0) image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape) logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
dest = np.zeros( logger.trace(
"BSRGAN output shape: %s",
( (
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
) ),
) )
logger.trace("BSRGAN output shape: %s", dest.shape)
dest = bsrgan(image) output = bsrgan(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) output = np.clip(np.squeeze(output, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) output = (output * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
logger.debug("output image shape: %s", output.shape)
outputs.append(output) outputs.append(output)
return outputs return StageResult(arrays=outputs)
def steps( def steps(
self, self,
params: ImageParams, params: ImageParams,
size: Size, size: Size,
) -> int: ) -> int:
tile = min(params.tiles, self.max_tile) tile = min(params.unet_tile, self.max_tile)
return size.width // tile * size.height // tile return size.width // tile * size.height // tile
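A back-of-the-envelope check of the steps() estimate above, assuming SizeChart.micro keeps the previous literal value of 64:

unet_tile = 512                      # params.unet_tile for a typical SD model
max_tile = 64                        # SizeChart.micro (assumed value)
tile = min(unet_tile, max_tile)      # BSRGAN runs on 64px tiles

width, height = 512, 512
print(width // tile * height // tile)   # 64 tiles for a 512x512 source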
View File
@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
@ -8,7 +8,8 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from ..worker.context import ProgressCallback from ..worker.context import ProgressCallback
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*args, *,
highres: HighresParams, highres: HighresParams,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
if highres.scale <= 1: if highres.scale <= 1:
return sources return sources
chain = stage_highres(stage, params, highres, upscale) chain = stage_highres(stage, params, highres, upscale)
return [ outputs = [
chain( chain(
worker, worker,
server, server,
@ -41,5 +42,7 @@ class UpscaleHighresStage(BaseStage):
source, source,
callback=callback, callback=callback,
) )
for source in sources for source in sources.as_image()
] ]
return StageResult(images=outputs)
View File
@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Callable, List, Optional, Tuple from typing import Callable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -18,13 +18,14 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class UpscaleOutpaintStage(BaseStage): class UpscaleOutpaintStage(BaseStage):
max_tile = SizeChart.unlimited max_tile = SizeChart.max
def run( def run(
self, self,
@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
border: Border, border: Border,
dims: Tuple[int, int, int], dims: Tuple[int, int, int],
@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage):
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params params
) )
@ -56,12 +57,12 @@ class UpscaleOutpaintStage(BaseStage):
params, params,
pipe_type, pipe_type,
worker.get_device(), worker.get_device(),
inversions=inversions, embeddings=inversions,
loras=loras, loras=loras,
) )
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
if is_debug(): if is_debug():
save_image(server, "tile-source.png", source) save_image(server, "tile-source.png", source)
save_image(server, "tile-mask.png", tile_mask) save_image(server, "tile-mask.png", tile_mask)
@ -71,7 +72,7 @@ class UpscaleOutpaintStage(BaseStage):
outputs.append(source) outputs.append(source)
continue continue
tile_size = params.tiles tile_size = params.unet_tile
size = Size(*source.size) size = Size(*source.size)
latent_size = size.min(tile_size, tile_size) latent_size = size.min(tile_size, tile_size)
@ -99,10 +100,11 @@ class UpscaleOutpaintStage(BaseStage):
) )
else: else:
# encode and record alternative prompts outside of LPW # encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt( if not params.is_xl():
pipe, prompt_pairs, params.batch, params.do_cfg() prompt_embeds = encode_prompt(
) pipe, prompt_pairs, params.batch, params.do_cfg()
pipe.unet.set_prompts(prompt_embeds) )
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
@ -121,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
outputs.extend(result.images) outputs.extend(result.images)
return outputs return StageResult(images=outputs)
View File
@ -1,8 +1,7 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np
from PIL import Image from PIL import Image
from ..onnx import OnnxRRDBNet from ..onnx import OnnxRRDBNet
@ -10,7 +9,8 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -77,25 +77,22 @@ class UpscaleRealESRGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
upsampler = self.load(
server, upscale, worker.get_device(), tile=stage.tile_size
)
outputs = [] outputs = []
for source in sources: for source in sources.as_numpy():
output = np.array(source) output, _ = upsampler.enhance(source, outscale=upscale.outscale)
upsampler = self.load( logger.info("final output image size: %s", output.shape)
server, upscale, worker.get_device(), tile=stage.tile_size
)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(arrays=outputs)
View File
@ -1,12 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
method: str, method: str,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
if upscale.scale <= 1: if upscale.scale <= 1:
logger.debug( logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale "simple upscale stage run with scale of %s, skipping", upscale.scale
@ -32,18 +33,20 @@ class UpscaleSimpleStage(BaseStage):
return sources return sources
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
scaled_size = (source.width * upscale.scale, source.height * upscale.scale) scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear": if method == "bilinear":
logger.debug("using bilinear interpolation for highres") logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) outputs.append(
source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
)
elif method == "lanczos": elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres") logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) outputs.append(
source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
)
else: else:
logger.warning("unknown upscaling method: %s", method) logger.warning("unknown upscaling method: %s", method)
outputs.append(source) return StageResult(images=outputs)
return outputs
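For reference, the simple stage now amounts to one PIL resize per source image; a standalone equivalent of the bilinear branch:

from PIL import Image

source = Image.new("RGB", (256, 256))
scale = 2
scaled_size = (source.width * scale, source.height * scale)
print(source.resize(scaled_size, resample=Image.Resampling.BILINEAR).size)   # (512, 512)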
View File
@ -1,8 +1,8 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import torch import numpy as np
from PIL import Image from PIL import Image
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
@ -10,7 +10,8 @@ from ..diffusers.utils import encode_prompt, parse_prompt
from ..params import ImageParams, StageParams, UpscaleParams from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
logger.info( logger.info(
@ -46,22 +47,23 @@ class UpscaleStableDiffusionStage(BaseStage):
worker.get_device(), worker.get_device(),
model=path.join(server.model_path, upscale.upscale_model), model=path.join(server.model_path, upscale.upscale_model),
) )
generator = torch.manual_seed(params.seed) rng = np.random.RandomState(params.seed)
prompt_embeds = encode_prompt( if not params.is_xl():
pipeline, prompt_embeds = encode_prompt(
prompt_pairs, pipeline,
num_images_per_prompt=params.batch, prompt_pairs,
do_classifier_free_guidance=params.do_cfg(), num_images_per_prompt=params.batch,
) do_classifier_free_guidance=params.do_cfg(),
pipeline.unet.set_prompts(prompt_embeds) )
pipeline.unet.set_prompts(prompt_embeds)
outputs = [] outputs = []
for source in sources: for source in sources.as_image():
result = pipeline( result = pipeline(
prompt, prompt,
source, source,
generator=generator, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
@ -71,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
) )
outputs.extend(result.images) outputs.extend(result.images)
return outputs return StageResult(images=outputs)
View File
@ -1,22 +1,23 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import List, Optional from typing import Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ..models.onnx import OnnxModel from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext from ..server import ModelTypes, ServerContext
from ..utils import run_gc from ..utils import run_gc
from ..worker import WorkerContext from ..worker import WorkerContext
from .stage import BaseStage from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
class UpscaleSwinIRStage(BaseStage): class UpscaleSwinIRStage(BaseStage):
max_tile = 64 max_tile = SizeChart.micro
def load( def load(
self, self,
@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image], sources: StageResult,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> StageResult:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None: if upscale.upscale_model is None:
@ -71,31 +72,30 @@ class UpscaleSwinIRStage(BaseStage):
swinir = self.load(server, stage, upscale, device) swinir = self.load(server, stage, upscale, device)
outputs = [] outputs = []
for source in sources: for source in sources.as_numpy():
# TODO: add support for grayscale (1-channel) images # TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0 image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0) image = np.expand_dims(image, axis=0)
logger.trace("SwinIR input shape: %s", image.shape) logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
dest = np.zeros( logger.trace(
"SwinIR output shape: %s",
( (
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
) ),
) )
logger.trace("SwinIR output shape: %s", dest.shape)
dest = swinir(image) output = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) output = np.clip(np.squeeze(output, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) output = (output * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB") logger.info("output image size: %s", output.shape)
logger.info("output image size: %s x %s", output.width, output.height)
outputs.append(output) outputs.append(output)
return outputs return StageResult(images=outputs)
View File
@ -1,2 +1,5 @@
ONNX_MODEL = "model.onnx" ONNX_MODEL = "model.onnx"
ONNX_WEIGHTS = "weights.pb" ONNX_WEIGHTS = "weights.pb"
LATENT_FACTOR = 8
LATENT_CHANNELS = 4
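A minimal sketch of how these constants relate image size to latent size for the diffusion models; the import path follows the ..constants imports used elsewhere in the package:

from onnx_web.constants import LATENT_CHANNELS, LATENT_FACTOR

width, height, batch = 512, 768, 1
latent_shape = (batch, LATENT_CHANNELS, height // LATENT_FACTOR, width // LATENT_FACTOR)
print(latent_shape)   # (1, 4, 96, 64)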
View File
@ -15,7 +15,8 @@ from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..utils import load_config from ..utils import load_config
from .correction.gfpgan import convert_correction_gfpgan from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers from .diffusion.diffusion import convert_diffusion_diffusers
from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
from .diffusion.lora import blend_loras from .diffusion.lora import blend_loras
from .diffusion.textual_inversion import blend_textual_inversions from .diffusion.textual_inversion import blend_textual_inversions
from .upscaling.bsrgan import convert_upscaling_bsrgan from .upscaling.bsrgan import convert_upscaling_bsrgan
@ -357,13 +358,23 @@ def convert_models(conversion: ConversionContext, args, models: Models):
conversion, name, model["source"], format=model_format conversion, name, model["source"], format=model_format
) )
converted, dest = convert_diffusion_diffusers( pipeline = model.get("pipeline", "txt2img")
conversion, if pipeline.endswith("-sdxl"):
model, converted, dest = convert_diffusion_diffusers_xl(
source, conversion,
model_format, model,
hf=hf, source,
) model_format,
hf=hf,
)
else:
converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
model_format,
hf=hf,
)
# make sure blending only happens once, not every run # make sure blending only happens once, not every run
if converted: if converted:
@ -588,7 +599,7 @@ def main(args=None) -> int:
logger.info("CLI arguments: %s", args) logger.info("CLI arguments: %s", args)
server = ConversionContext.from_environ() server = ConversionContext.from_environ()
server.half = args.half or "onnx-fp16" in server.optimizations server.half = args.half or server.has_optimization("onnx-fp16")
server.opset = args.opset server.opset = args.opset
server.token = args.token server.token = args.token
logger.info( logger.info(
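A sketch of the model entries this branch distinguishes between; the keys mirror the model.get(...) calls above, while the names and sources are only examples:

sd15_model = {
    "name": "diffusion-example-v15",
    "source": "runwayml/stable-diffusion-v1-5",
    "pipeline": "txt2img",        # default, converted by convert_diffusion_diffusers
}

sdxl_model = {
    "name": "diffusion-example-xl",
    "source": "stabilityai/stable-diffusion-xl-base-1.0",
    "pipeline": "txt2img-sdxl",   # the -sdxl suffix routes to convert_diffusion_diffusers_xl
}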
View File
@ -20,7 +20,6 @@ from diffusers import (
AutoencoderKL, AutoencoderKL,
OnnxRuntimeModel, OnnxRuntimeModel,
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInstructPix2PixPipeline, StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
@ -32,17 +31,25 @@ from onnx import load_model, save_model
from ...constants import ONNX_MODEL, ONNX_WEIGHTS from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc from ...utils import run_gc
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export from ..utils import (
RESOLVE_FORMATS,
ConversionContext,
check_ext,
is_torch_2_0,
load_tensor,
onnx_export,
)
from .checkpoint import convert_extract_checkpoint from .checkpoint import convert_extract_checkpoint
logger = getLogger(__name__) logger = getLogger(__name__)
available_pipelines = { CONVERT_PIPELINES = {
"controlnet": StableDiffusionControlNetPipeline, "controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline, "img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline, "inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline, "lpw": StableDiffusionPipeline,
@ -96,7 +103,6 @@ def get_model_version(
opts["prediction_type"] = "epsilon" opts["prediction_type"] = "epsilon"
except Exception: except Exception:
logger.debug("unable to load tensor for version check") logger.debug("unable to load tensor for version check")
pass
return (v2, opts) return (v2, opts)
@ -314,7 +320,7 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping") logger.info("ONNX model already exists, skipping")
return (False, dest_path) return (False, dest_path)
pipe_class = available_pipelines.get(pipe_type) pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version( v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version source, conversion.map_location, size=image_size, version=version
) )
@ -360,7 +366,6 @@ def convert_diffusion_diffusers(
source, source,
original_config_file=config_path, original_config_file=config_path,
pipeline_class=pipe_class, pipeline_class=pipe_class,
vae_path=replace_vae,
**pipe_args, **pipe_args,
).to(device, torch_dtype=dtype) ).to(device, torch_dtype=dtype)
elif hf: elif hf:
@ -374,6 +379,17 @@ def convert_diffusion_diffusers(
logger.warning("pipeline source not found or not recognized: %s", source) logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}") raise ValueError(f"pipeline source not found or not recognized: {source}")
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if check_ext(replace_vae, RESOLVE_FORMATS):
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.vae.set_attn_processor(AttnProcessor())
optimize_pipeline(conversion, pipeline) optimize_pipeline(conversion, pipeline)
output_path = Path(dest_path) output_path = Path(dest_path)
@ -424,9 +440,6 @@ def convert_diffusion_diffusers(
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool) unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
unet_in_channels = pipeline.unet.config.in_channels unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / ONNX_MODEL unet_path = output_path / "unet" / ONNX_MODEL
@ -526,19 +539,6 @@ def convert_diffusion_diffusers(
del unet del unet
run_gc() run_gc()
# VAE
if replace_vae is not None:
if replace_vae.startswith("."):
logger.debug(
"custom VAE appears to be a local path, making it relative to the model path"
)
replace_vae = path.join(conversion.model_path, replace_vae)
logger.info("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
pipeline.vae = vae
run_gc()
if single_vae: if single_vae:
logger.debug("VAE config: %s", pipeline.vae.config) logger.debug("VAE config: %s", pipeline.vae.config)
View File
@ -0,0 +1,116 @@
from logging import getLogger
from os import path
from typing import Dict, Optional, Tuple
import onnx
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from optimum.exporters.onnx import main_export
from ...constants import ONNX_MODEL
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
logger = getLogger(__name__)
@torch.no_grad()
def convert_diffusion_diffusers_xl(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
replace_vae = model.get("vae", None)
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
model_hash = path.join(dest_path, "hash.txt")
# diffusers go into a directory rather than .onnx file
logger.info(
"converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
)
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")
if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])
return (False, dest_path)
# safetensors -> diffusers directory with torch models
temp_path = path.join(conversion.cache_path, f"{name}-torch")
if format == "safetensors":
pipeline = StableDiffusionXLPipeline.from_single_file(
source, use_safetensors=True
)
else:
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if check_ext(replace_vae, RESOLVE_FORMATS):
logger.debug("loading VAE from single tensor file: %s", vae_path)
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
logger.debug("loading pretrained VAE from path: %s", vae_path)
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
if path.exists(temp_path):
logger.debug("torch model already exists for %s: %s", source, temp_path)
else:
logger.debug("exporting torch model for %s: %s", source, temp_path)
pipeline.save_pretrained(temp_path)
# directory -> onnx using optimum exporters
main_export(
temp_path,
output=dest_path,
task="stable-diffusion-xl",
device=device,
fp16=conversion.has_optimization(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
)
if "hash" in model:
logger.debug("adding hash file to ONNX model")
with open(model_hash, "w") as f:
f.write(model["hash"])
if conversion.half:
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
infer_shapes_path(unet_path)
unet = onnx.load(unet_path)
opt_model = convert_float_to_float16(
unet,
disable_shape_infer=True,
force_fp16_initializers=True,
keep_io_types=True,
op_block_list=["Attention", "MultiHeadAttention"],
)
onnx.save_model(
opt_model,
unet_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
)
return False, dest_path
View File
@ -1,22 +1,15 @@
from argparse import ArgumentParser
from logging import getLogger from logging import getLogger
from os import path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from onnx import ModelProto, load, numpy_helper from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
from onnx.checker import check_model from onnx.external_data_helper import set_external_data
from onnx.external_data_helper import ( from onnxruntime import OrtValue
convert_model_to_external_data,
set_external_data,
write_external_data_tensors,
)
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from scipy import interpolate from scipy import interpolate
from ...server.context import ServerContext from ...server.context import ServerContext
from ..utils import ConversionContext, load_tensor from ..utils import load_tensor
logger = getLogger(__name__) logger = getLogger(__name__)
@ -39,7 +32,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
lr = a lr = a
if kernel == (1, 1): if kernel == (1, 1):
lr = np.expand_dims(lr, axis=(2, 3)) lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis
return hr + lr return hr + lr
@ -78,13 +71,15 @@ def fix_node_name(key: str):
return fixed_name return fixed_name
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
fixed = {} fixed = {}
names = [fix_node_name(node.name) for node in nodes]
for key, value in keys.items(): for key, value in keys.items():
root, *rest = key.split(".") root, *_rest = key.split(".")
logger.debug("fixing XL node name: %s -> %s", key, root) # TODO: move to trace logger.trace("fixing XL node name: %s -> %s", key, root)
simple = False
if root.startswith("input"): if root.startswith("input"):
block = "down_blocks" block = "down_blocks"
elif root.startswith("middle"): elif root.startswith("middle"):
@ -93,6 +88,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
block = "up_blocks" block = "up_blocks"
elif root.startswith("text_model"): elif root.startswith("text_model"):
block = "text_model" block = "text_model"
elif root.startswith("down_blocks"):
block = "down_blocks"
simple = True
elif root.startswith("mid_block"):
block = "mid_block"
simple = True
elif root.startswith("up_blocks"):
block = "up_blocks"
simple = True
else: else:
logger.warning("unknown XL key name: %s", key) logger.warning("unknown XL key name: %s", key)
fixed[key] = value fixed[key] = value
@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
suffix = None suffix = None
for s in [ for s in [
"conv",
"conv_shortcut",
"conv1",
"conv2",
"fc1", "fc1",
"fc2", "fc2",
"ff_net_0_proj", "ff_net_0_proj",
@ -119,18 +127,21 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
logger.warning("new XL key type: %s", root) logger.warning("new XL key type: %s", root)
continue continue
logger.debug("searching for XL node: /%s/*/%s", block, suffix) logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
match = None match: Optional[str] = None
if block == "text_model": if "conv" in suffix:
match = next( match = next(node for node in names if node == f"{root}_Conv")
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul" elif "time_emb_proj" in root:
) match = next(node for node in names if node == f"{root}_Gemm")
elif block == "text_model" or simple:
match = next(node for node in names if node == f"{root}_MatMul")
else: else:
# search in order. one side has sparse indices, so they will not match.
match = next( match = next(
node node
for node in nodes for node in names
if node.name.startswith(f"/{block}") if node.startswith(block)
and fix_node_name(node.name).endswith( and node.endswith(
f"{suffix}_MatMul" f"{suffix}_MatMul"
) # needs to be fixed because some places use to_out.0 ) # needs to be fixed because some places use to_out.0
) )
@ -138,18 +149,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
if match is None: if match is None:
logger.warning("no matches for XL key: %s", root) logger.warning("no matches for XL key: %s", root)
continue continue
else:
logger.trace("matched key: %s -> %s", key, match)
name: str = match.name name = match
name = fix_node_name(name.rstrip("/MatMul")) if name.endswith("_MatMul"):
name = name[:-7]
elif name.endswith("_Gemm"):
name = name[:-5]
elif name.endswith("_Conv"):
name = name[:-5]
if name.endswith("proj_o"): logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)
# wtf
name = f"{name}ut"
logger.debug("matching XL key with node: %s -> %s", key, match.name)
fixed[name] = value fixed[name] = value
nodes.remove(match) names.remove(match)
logger.debug(
"SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
len(fixed.keys()),
len(keys.keys()),
len(names),
)
return fixed return fixed
@ -161,6 +182,245 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
) )
def blend_weights_loha(
key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
t1_key = key.replace("hada_w1_a", "hada_t1")
t2_key = key.replace("hada_w1_a", "hada_t2")
w1b_key = key.replace("hada_w1_a", "hada_w1_b")
w2a_key = key.replace("hada_w1_a", "hada_w2_a")
w2b_key = key.replace("hada_w1_a", "hada_w2_b")
alpha_key = key[: key.index("hada_w1_a")] + "alpha"
logger.trace(
"blending weights for LoHA keys: %s, %s, %s, %s, %s",
key,
w1b_key,
w2a_key,
w2b_key,
alpha_key,
)
w1a_weight = lora_model[key].to(dtype=dtype)
w1b_weight = lora_model[w1b_key].to(dtype=dtype)
w2a_weight = lora_model[w2a_key].to(dtype=dtype)
w2b_weight = lora_model[w2b_key].to(dtype=dtype)
t1_weight = lora_model.get(t1_key, None)
t2_weight = lora_model.get(t2_key, None)
dim = w1b_weight.size()[0]
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
if t1_weight is not None and t2_weight is not None:
t1_weight = t1_weight.to(dtype=dtype)
t2_weight = t2_weight.to(dtype=dtype)
logger.trace(
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
t1_weight.shape,
w1a_weight.shape,
w1b_weight.shape,
t2_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights_1 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t1_weight,
w1b_weight,
w1a_weight,
)
weights_2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t2_weight,
w2b_weight,
w2a_weight,
)
weights = weights_1 * weights_2
np_weights = weights.numpy() * (alpha / dim)
else:
logger.trace(
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
w1a_weight.shape,
w1b_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
np_weights = weights.numpy() * (alpha / dim)
return base_key, np_weights
def blend_weights_lora(
key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
mid_key = key.replace("lora_down", "lora_mid")
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
down_weight = lora_model[key].to(dtype=dtype)
up_weight = lora_model[up_key].to(dtype=dtype)
mid_weight = None
if mid_key in lora_model:
mid_weight = lora_model[mid_key].to(dtype=dtype)
dim = down_weight.size()[0]
alpha = lora_model.get(alpha_key, dim)
if not isinstance(alpha, int):
alpha = alpha.to(dtype).numpy()
kernel = down_weight.shape[-2:]
if mid_weight is not None:
kernel = mid_weight.shape[-2:]
if len(down_weight.size()) == 2:
# blend for nn.Linear
logger.trace(
"blending weights for Linear node: (%s @ %s) * %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = up_weight @ down_weight
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
1,
1,
):
# blend for nn.Conv2d 1x1
logger.trace(
"blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = (
(up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
.unsqueeze(2)
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
3,
3,
):
if mid_weight is not None:
# blend for nn.Conv2d 3x3 with CP decomp
logger.trace(
"composing weights for Conv 3x3 node: %s, %s, %s, %s",
down_weight.shape,
up_weight.shape,
mid_weight.shape,
alpha,
)
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = (
up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
) @ down_weight.squeeze(3).squeeze(2)
np_weights = weights.numpy() * (alpha / dim)
else:
# blend for nn.Conv2d 3x3
logger.trace(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
for w in range(kernel[0]):
for h in range(kernel[1]):
down_w, down_h = kernel_slice(w, h, down_weight.shape)
up_w, up_h = kernel_slice(w, h, up_weight.shape)
weights[:, :, w, h] = (
up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
)
np_weights = weights.numpy() * (alpha / dim)
else:
logger.warning(
"unknown LoRA node type at %s: %s",
base_key,
up_weight.shape[-2:],
)
# TODO: should this be None?
np_weights = np.zeros((1, 1, 1, 1))
return base_key, np_weights
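A tiny numeric check of the nn.Linear branch above: the merged delta is lora_up @ lora_down scaled by alpha / rank, before the per-LoRA weight is applied. The shapes below are arbitrary examples.

import numpy as np

rank, n_in, n_out = 4, 8, 8
down = np.random.randn(rank, n_in).astype(np.float32)   # lora_down
up = np.random.randn(n_out, rank).astype(np.float32)    # lora_up
alpha = 2.0

delta = (up @ down) * (alpha / rank)
print(delta.shape)   # (8, 8), same shape as the base Linear weight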
def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
# blending
onnx_weights = numpy_helper.to_array(weight_node)
logger.trace(
"found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape,
)
if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3))
else:
if onnx_weights.shape != weights.shape:
logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
# TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else:
blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
onnx_weights = numpy_helper.to_array(matmul_node)
logger.trace(
"found blended weights for matmul: %s, %s",
weights.shape,
onnx_weights.shape,
)
t_weights = weights.transpose()
if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
logger.warning(
"weight shapes do not match for %s: %s vs %s",
matmul_key,
weights.shape,
onnx_weights.shape,
)
t_weights = interp_to_match(weights, onnx_weights).transpose()
blended = onnx_weights + t_weights
logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
# replace the original initializer
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
def blend_loras( def blend_loras(
_conversion: ServerContext, _conversion: ServerContext,
base_name: Union[str, ModelProto], base_name: Union[str, ModelProto],
@ -184,246 +444,77 @@ def blend_loras(
else: else:
lora_prefix = f"lora_{model_type}_" lora_prefix = f"lora_{model_type}_"
blended: Dict[str, np.ndarray] = {} layers = []
for (lora_name, lora_weight), lora_model in zip(loras, lora_models): for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight) logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
if lora_model is None: if lora_model is None:
logger.warning("unable to load tensor for LoRA") logger.warning("unable to load tensor for LoRA")
continue continue
blended: Dict[str, np.ndarray] = {}
layers.append(blended)
for key in lora_model.keys(): for key in lora_model.keys():
if ".hada_w1_a" in key and lora_prefix in key: if ".hada_w1_a" in key and lora_prefix in key:
# LoHA # LoHA
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "") base_key, np_weights = blend_weights_loha(
key, lora_prefix, lora_model, dtype
t1_key = key.replace("hada_w1_a", "hada_t1")
t2_key = key.replace("hada_w1_a", "hada_t2")
w1b_key = key.replace("hada_w1_a", "hada_w1_b")
w2a_key = key.replace("hada_w1_a", "hada_w2_a")
w2b_key = key.replace("hada_w1_a", "hada_w2_b")
alpha_key = key[: key.index("hada_w1_a")] + "alpha"
logger.trace(
"blending weights for LoHA keys: %s, %s, %s, %s, %s",
key,
w1b_key,
w2a_key,
w2b_key,
alpha_key,
) )
np_weights = np_weights * lora_weight
w1a_weight = lora_model[key].to(dtype=dtype) logger.trace(
w1b_weight = lora_model[w1b_key].to(dtype=dtype) "adding LoHA weights: %s",
w2a_weight = lora_model[w2a_key].to(dtype=dtype) np_weights.shape,
w2b_weight = lora_model[w2b_key].to(dtype=dtype) )
blended[base_key] = np_weights
t1_weight = lora_model.get(t1_key, None)
t2_weight = lora_model.get(t2_key, None)
dim = w1b_weight.size()[0]
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
if t1_weight is not None and t2_weight is not None:
t1_weight = t1_weight.to(dtype=dtype)
t2_weight = t2_weight.to(dtype=dtype)
logger.trace(
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
t1_weight.shape,
w1a_weight.shape,
w1b_weight.shape,
t2_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights_1 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t1_weight,
w1b_weight,
w1a_weight,
)
weights_2 = torch.einsum(
"i j k l, j r, i p -> p r k l",
t2_weight,
w2b_weight,
w2a_weight,
)
weights = weights_1 * weights_2
np_weights = weights.numpy() * (alpha / dim)
else:
logger.trace(
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
w1a_weight.shape,
w1b_weight.shape,
w2a_weight.shape,
w2b_weight.shape,
)
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
np_weights = weights.numpy() * (alpha / dim)
np_weights *= lora_weight
if base_key in blended:
logger.trace(
"summing LoHA weights: %s + %s",
blended[base_key].shape,
np_weights.shape,
)
blended[base_key] += sum_weights(blended[base_key], np_weights)
else:
blended[base_key] = np_weights
elif ".lora_down" in key and lora_prefix in key: elif ".lora_down" in key and lora_prefix in key:
# LoRA or LoCON # LoRA or LoCON
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "") base_key, np_weights = blend_weights_lora(
key, lora_prefix, lora_model, dtype
mid_key = key.replace("lora_down", "lora_mid")
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.trace(
"blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
) )
np_weights = np_weights * lora_weight
logger.trace(
"adding LoRA weights: %s",
np_weights.shape,
)
blended[base_key] = np_weights
down_weight = lora_model[key].to(dtype=dtype) # rewrite node names for XL and flatten layers
up_weight = lora_model[up_key].to(dtype=dtype) weights: Dict[str, np.ndarray] = {}
mid_weight = None for blended in layers:
if mid_key in lora_model: if xl:
mid_weight = lora_model[mid_key].to(dtype=dtype) nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
dim = down_weight.size()[0] for key, value in blended.items():
alpha = lora_model.get(alpha_key, dim) if key in weights:
weights[key] = sum_weights(weights[key], value)
if not isinstance(alpha, int): else:
alpha = alpha.to(dtype).numpy() weights[key] = value
kernel = down_weight.shape[-2:]
if mid_weight is not None:
kernel = mid_weight.shape[-2:]
if len(down_weight.size()) == 2:
# blend for nn.Linear
logger.trace(
"blending weights for Linear node: (%s @ %s) * %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = up_weight @ down_weight
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
1,
1,
):
# blend for nn.Conv2d 1x1
logger.trace(
"blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = (
(
up_weight.squeeze(3).squeeze(2)
@ down_weight.squeeze(3).squeeze(2)
)
.unsqueeze(2)
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(down_weight.size()) == 4 and kernel == (
3,
3,
):
if mid_weight is not None:
# blend for nn.Conv2d 3x3 with CP decomp
logger.trace(
"composing weights for Conv 3x3 node: %s, %s, %s, %s",
down_weight.shape,
up_weight.shape,
mid_weight.shape,
alpha,
)
weights = torch.zeros(
(up_weight.shape[0], down_weight.shape[1], *kernel)
)
for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = (
up_weight.squeeze(3).squeeze(2)
@ mid_weight[:, :, w, h]
) @ down_weight.squeeze(3).squeeze(2)
np_weights = weights.numpy() * (alpha / dim)
else:
# blend for nn.Conv2d 3x3
logger.trace(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.zeros(
(up_weight.shape[0], down_weight.shape[1], *kernel)
)
for w in range(kernel[0]):
for h in range(kernel[1]):
down_w, down_h = kernel_slice(w, h, down_weight.shape)
up_w, up_h = kernel_slice(w, h, up_weight.shape)
weights[:, :, w, h] = (
up_weight[:, :, up_w, up_h]
@ down_weight[:, :, down_w, down_h]
)
np_weights = weights.numpy() * (alpha / dim)
else:
logger.warning(
"unknown LoRA node type at %s: %s",
base_key,
up_weight.shape[-2:],
)
continue
np_weights *= lora_weight
if base_key in blended:
logger.trace(
"summing weights: %s + %s",
blended[base_key].shape,
np_weights.shape,
)
blended[base_key] = sum_weights(blended[base_key], np_weights)
else:
blended[base_key] = np_weights
# rewrite node names for XL
if xl:
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
logger.trace(
"updating %s of %s initializers",
len(blended.keys()),
len(base_model.graph.initializer),
)
# fix node names once
fixed_initializer_names = [ fixed_initializer_names = [
fix_initializer_name(node.name) for node in base_model.graph.initializer fix_initializer_name(node.name) for node in base_model.graph.initializer
] ]
logger.trace("fixed initializer names: %s", fixed_initializer_names)
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node] fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
logger.trace("fixed node names: %s", fixed_node_names)
logger.debug(
"updating %s of %s initializers",
len(weights.keys()),
len(base_model.graph.initializer),
)
unmatched_keys = [] unmatched_keys = []
for base_key, weights in blended.items(): for base_key, weights in weights.items():
conv_key = base_key + "_Conv" conv_key = base_key + "_Conv"
gemm_key = base_key + "_Gemm" gemm_key = base_key + "_Gemm"
matmul_key = base_key + "_MatMul" matmul_key = base_key + "_MatMul"
logger.trace( logger.trace(
"key %s has conv: %s, matmul: %s", "key %s has conv: %s, gemm: %s, matmul: %s",
base_key, base_key,
conv_key in fixed_node_names, conv_key in fixed_node_names,
gemm_key in fixed_node_names,
matmul_key in fixed_node_names, matmul_key in fixed_node_names,
) )
@@ -449,38 +540,9 @@ def blend_loras
weight_node = base_model.graph.initializer[weight_idx] weight_node = base_model.graph.initializer[weight_idx]
logger.trace("found weight initializer: %s", weight_node.name) logger.trace("found weight initializer: %s", weight_node.name)
# blending # replace the previous node
onnx_weights = numpy_helper.to_array(weight_node) updated_node = blend_node_conv_gemm(weight_node, weights)
logger.trace(
"found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape,
)
if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3))
else:
if onnx_weights.shape != weights.shape:
logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else:
blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer
updated_node = numpy_helper.from_array(
blended.astype(onnx_weights.dtype), weight_node.name
)
del base_model.graph.initializer[weight_idx] del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, updated_node) base_model.graph.initializer.insert(weight_idx, updated_node)
elif matmul_key in fixed_node_names: elif matmul_key in fixed_node_names:
@@ -497,42 +559,15 @@ def blend_loras
matmul_node = base_model.graph.initializer[matmul_idx] matmul_node = base_model.graph.initializer[matmul_idx]
logger.trace("found matmul initializer: %s", matmul_node.name) logger.trace("found matmul initializer: %s", matmul_node.name)
# blending # replace the previous node
onnx_weights = numpy_helper.to_array(matmul_node) updated_node = blend_node_matmul(matmul_node, weights, matmul_key)
logger.trace(
"found blended weights for matmul: %s, %s",
weights.shape,
onnx_weights.shape,
)
t_weights = weights.transpose()
if (
weights.shape != onnx_weights.shape
and t_weights.shape != onnx_weights.shape
):
logger.warning(
"weight shapes do not match for %s: %s vs %s",
matmul_key,
weights.shape,
onnx_weights.shape,
)
t_weights = interp_to_match(weights, onnx_weights).transpose()
blended = onnx_weights + t_weights
logger.debug(
"blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
)
# replace the original initializer
updated_node = numpy_helper.from_array(
blended.astype(onnx_weights.dtype), matmul_node.name
)
del base_model.graph.initializer[matmul_idx] del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node) base_model.graph.initializer.insert(matmul_idx, updated_node)
else: else:
unmatched_keys.append(base_key) unmatched_keys.append(base_key)
logger.debug( logger.trace(
"node counts: %s -> %s, %s -> %s", "node counts: %s -> %s, %s -> %s",
len(fixed_initializer_names), len(fixed_initializer_names),
len(base_model.graph.initializer), len(base_model.graph.initializer),
@@ -541,10 +576,7 @@ def blend_loras
) )
if len(unmatched_keys) > 0: if len(unmatched_keys) > 0:
logger.warning("could not find nodes for some keys: %s", unmatched_keys) logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)
# if model_type == "unet":
# save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
return base_model return base_model
@@ -568,63 +600,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
logger.debug("weights after interpolation: %s", output.shape) logger.debug("weights after interpolation: %s", output.shape)
return output return output
if __name__ == "__main__":
context = ConversionContext.from_environ()
parser = ArgumentParser()
parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args()
logger.info(
"merging %s with %s with weights: %s",
args.lora_models,
args.base,
args.lora_weights,
)
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras(
context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
)
if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(
bare_model.SerializeToString(),
sess_options=opts,
providers=["CPUExecutionProvider"],
)
logger.info(
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
)
else:
convert_model_to_external_data(
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
)
bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")
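# For callers inside the server, the same blending is driven through blend_loras
# directly; a hedged sketch of such a call, matching the signature shown above
# (both paths below are placeholders):
example_context = ConversionContext.from_environ()
example_blended_unet = blend_loras(
    example_context,
    "/models/stable-diffusion-v1-5/unet/model.onnx",
    [("/models/lora/example-lora.safetensors", 0.75)],
    "unet",
)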


@@ -14,19 +14,155 @@ from ..utils import ConversionContext, load_tensor
logger = getLogger(__name__) logger = getLogger(__name__)
def detect_embedding_format(loaded_embeds) -> str:
keys: List[str] = list(loaded_embeds.keys())
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
logger.debug("detected Textual Inversion concept: %s", keys)
return "concept"
elif "emb_params" in keys:
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
return "parameters"
elif "string_to_token" in keys and "string_to_param" in keys:
logger.debug("detected Textual Inversion token embeddings: %s", keys)
return "embeddings"
else:
logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
return None
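# The three branches correspond to the common Textual Inversion file layouts;
# a quick sketch of inputs that would hit each one (dummy tensors, for
# illustration only):
import torch

example_concept = {"<my-style>": torch.zeros(768)}
example_parameters = {"emb_params": torch.zeros(4, 768)}
example_embeddings = {
    "string_to_token": {"*": 265},
    "string_to_param": {"*": torch.zeros(4, 768)},
}

assert detect_embedding_format(example_concept) == "concept"
assert detect_embedding_format(example_parameters) == "parameters"
assert detect_embedding_format(example_embeddings) == "embeddings"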
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
# separate token and the embeds
token = list(loaded_embeds.keys())[0]
layer = loaded_embeds[token].numpy().astype(dtype)
layer *= weight
if base_token in embeds:
embeds[base_token] += layer
else:
embeds[base_token] = layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
emb_params = loaded_embeds["emb_params"]
num_tokens = emb_params.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
sum_layer = np.zeros(emb_params[0, :].shape)
for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = emb_params[i, :].numpy().astype(dtype)
layer *= weight
sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer
sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
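# Each row of emb_params becomes its own numbered token, plus a summed "-all"
# token and the base token itself; a sketch with a dummy 4-vector embedding
# (the token name "goblin" is arbitrary):
import numpy as np
import torch

example_embeds = {}
blend_embedding_parameters(
    example_embeds, {"emb_params": torch.ones(4, 768)}, np.float32, "goblin", 1.0
)
assert set(example_embeds) == {
    "goblin-0", "goblin-1", "goblin-2", "goblin-3", "goblin", "goblin-all"
}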
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]
# separate token and embeds
token = list(string_to_token.keys())[0]
trained_embeds = string_to_param[token]
num_tokens = trained_embeds.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
sum_layer = np.zeros(trained_embeds[0, :].shape)
for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = trained_embeds[i, :].numpy().astype(dtype)
layer *= weight
sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer
sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
# resize the token embeddings
# text_encoder.resize_token_embeddings(len(tokenizer))
embedding_node = [
n
for n in text_encoder.graph.initializer
if n.name == "text_model.embeddings.token_embedding.weight"
][0]
base_weights = numpy_helper.to_array(embedding_node)
weights_dim = base_weights.shape[1]
zero_weights = np.zeros((num_added_tokens, weights_dim))
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
for token, weights in embeds.items():
token_id = tokenizer.convert_tokens_to_ids(token)
logger.trace("embedding %s weights for token %s", weights.shape, token)
embedding_weights[token_id] = weights
# replace embedding_node
for i in range(len(text_encoder.graph.initializer)):
if (
text_encoder.graph.initializer[i].name
== "text_model.embeddings.token_embedding.weight"
):
new_initializer = numpy_helper.from_array(
embedding_weights.astype(base_weights.dtype), embedding_node.name
)
logger.trace("new initializer data type: %s", new_initializer.data_type)
del text_encoder.graph.initializer[i]
text_encoder.graph.initializer.insert(i, new_initializer)
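# Size sketch (assuming the SD v1.5 text encoder's (49408, 768) embedding
# matrix): the initializer is widened by one zero row per added token before
# the trained vectors are written into the new rows.
import numpy as np

example_base = np.zeros((49408, 768), dtype=np.float32)  # CLIP vocab x hidden size
example_added = 6
example_widened = np.concatenate(
    (example_base, np.zeros((example_added, 768), dtype=np.float32)), axis=0
)
assert example_widened.shape == (49414, 768)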
@torch.no_grad() @torch.no_grad()
def blend_textual_inversions( def blend_textual_inversions(
server: ServerContext, server: ServerContext,
text_encoder: ModelProto, text_encoder: ModelProto,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
inversions: List[Tuple[str, float, Optional[str], Optional[str]]], embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]: ) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending # always load to CPU for blending
device = torch.device("cpu") device = torch.device("cpu")
dtype = np.float32 dtype = np.float32
embeds = {} embeds = {}
for name, weight, base_token, inversion_format in inversions: for name, weight, base_token, format in embeddings:
if base_token is None: if base_token is None:
logger.debug("no base token provided, using name: %s", name) logger.debug("no base token provided, using name: %s", name)
base_token = name base_token = name
@@ -43,153 +179,28 @@ def blend_textual_inversions(
logger.warning("unable to load tensor") logger.warning("unable to load tensor")
continue continue
if inversion_format is None: if format is None:
keys: List[str] = list(loaded_embeds.keys()) format = detect_embedding_format(loaded_embeds)
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
logger.debug("detected Textual Inversion concept: %s", keys)
inversion_format = "concept"
elif "emb_params" in keys:
logger.debug(
"detected Textual Inversion parameter embeddings: %s", keys
)
inversion_format = "parameters"
elif "string_to_token" in keys and "string_to_param" in keys:
logger.debug("detected Textual Inversion token embeddings: %s", keys)
inversion_format = "embeddings"
else:
logger.error(
"unknown Textual Inversion format, no recognized keys: %s", keys
)
continue
if inversion_format == "concept": if format == "concept":
# separate token and the embeds blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
token = list(loaded_embeds.keys())[0] elif format == "parameters":
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
layer = loaded_embeds[token].numpy().astype(dtype) elif format == "embeddings":
layer *= weight blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
if base_token in embeds:
embeds[base_token] += layer
else:
embeds[base_token] = layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
elif inversion_format == "parameters":
emb_params = loaded_embeds["emb_params"]
num_tokens = emb_params.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
sum_layer = np.zeros(emb_params[0, :].shape)
for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = emb_params[i, :].numpy().astype(dtype)
layer *= weight
sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer
sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
elif inversion_format == "embeddings":
string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]
# separate token and embeds
token = list(string_to_token.keys())[0]
trained_embeds = string_to_param[token]
num_tokens = trained_embeds.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
sum_layer = np.zeros(trained_embeds[0, :].shape)
for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = trained_embeds[i, :].numpy().astype(dtype)
layer *= weight
sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer
sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
else: else:
raise ValueError(f"unknown Textual Inversion format: {inversion_format}") raise ValueError(f"unknown Textual Inversion format: {format}")
# add the tokens to the tokenizer # add the tokens to the tokenizer
logger.debug( num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
"found embeddings for %s tokens: %s", if num_added_tokens == 0:
len(embeds.keys()), raise ValueError(
list(embeds.keys()), "The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
) )
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)
logger.trace("added %s tokens", num_added_tokens) logger.trace("added %s tokens", num_added_tokens)
# resize the token embeddings blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
# text_encoder.resize_token_embeddings(len(tokenizer))
embedding_node = [
n
for n in text_encoder.graph.initializer
if n.name == "text_model.embeddings.token_embedding.weight"
][0]
base_weights = numpy_helper.to_array(embedding_node)
weights_dim = base_weights.shape[1]
zero_weights = np.zeros((num_added_tokens, weights_dim))
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
for token, weights in embeds.items():
token_id = tokenizer.convert_tokens_to_ids(token)
logger.trace("embedding %s weights for token %s", weights.shape, token)
embedding_weights[token_id] = weights
# replace embedding_node
for i in range(len(text_encoder.graph.initializer)):
if (
text_encoder.graph.initializer[i].name
== "text_model.embeddings.token_embedding.weight"
):
new_initializer = numpy_helper.from_array(
embedding_weights.astype(base_weights.dtype), embedding_node.name
)
logger.trace("new initializer data type: %s", new_initializer.data_type)
del text_encoder.graph.initializer[i]
text_encoder.graph.initializer.insert(i, new_initializer)
return (text_encoder, tokenizer) return (text_encoder, tokenizer)


@@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
class ConversionContext(ServerContext): class ConversionContext(ServerContext):
def __init__( def __init__(
self, self,
model_path: Optional[str] = None, model_path: str = ".",
cache_path: Optional[str] = None, cache_path: Optional[str] = None,
device: Optional[str] = None, device: Optional[str] = None,
half: bool = False, half: bool = False,
@@ -69,7 +69,7 @@ class ConversionContext(ServerContext):
def from_environ(cls): def from_environ(cls):
context = super().from_environ() context = super().from_environ()
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True) context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", True) context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False)
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True) context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True) context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET)) context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
@@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):
def tuple_to_source(model: Union[ModelDict, LegacyModel]): def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *_rest = model
return { return {
"name": name, "name": name,
@@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
def tuple_to_correction(model: Union[ModelDict, LegacyModel]): def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1 scale = rest.pop(0) if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False half = rest.pop(0) if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None opset = rest.pop(0) if len(rest) > 0 else None
return { return {
"name": name, "name": name,
@@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]): def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model
single_vae = rest[0] if len(rest) > 0 else False single_vae = rest.pop(0) if len(rest) > 0 else False
half = rest[0] if len(rest) > 0 else False half = rest.pop(0) if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None opset = rest.pop(0) if len(rest) > 0 else None
return { return {
"name": name, "name": name,
@@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1 scale = rest.pop(0) if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False half = rest.pop(0) if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None opset = rest.pop(0) if len(rest) > 0 else None
return { return {
"name": name, "name": name,
@@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"] RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
_name, ext = path.splitext(name)
ext = ext.strip(".")
return (ext in exts, ext)
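# Example behaviour, using the MODEL_FORMATS list defined above:
assert check_ext("model.safetensors", MODEL_FORMATS) == (True, "safetensors")
assert check_ext("notes.txt", MODEL_FORMATS) == (False, "txt")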
def source_format(model: Dict) -> Optional[str]: def source_format(model: Dict) -> Optional[str]:
@@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]:
return model["format"] return model["format"]
if "source" in model: if "source" in model:
_name, ext = path.splitext(model["source"]) valid, ext = check_ext(model["source"], MODEL_FORMATS)
if ext in MODEL_FORMATS: if valid:
return ext return ext
return None return None
@@ -298,6 +305,7 @@ def onnx_export(
half=False, half=False,
external_data=False, external_data=False,
v2=False, v2=False,
op_block_list=None,
): ):
""" """
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -316,8 +324,7 @@ def onnx_export(
opset_version=opset, opset_version=opset,
) )
op_block_list = None if v2 and op_block_list is None:
if v2:
op_block_list = ["Attention", "MultiHeadAttention"] op_block_list = ["Attention", "MultiHeadAttention"]
if half: if half:


@@ -1,16 +1,15 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Any, List, Optional, Tuple from typing import Any, List, Literal, Optional, Tuple
from onnx import load_model from onnx import load_model
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline, from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
ORTStableDiffusionXLImg2ImgPipeline, ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline, ORTStableDiffusionXLPipeline,
) )
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL from ..constants import LATENT_FACTOR, ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
@@ -24,6 +23,7 @@ from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from .version_safe_diffusers import ( from .version_safe_diffusers import (
DDIMScheduler, DDIMScheduler,
@@ -38,6 +38,7 @@ from .version_safe_diffusers import (
KarrasVeScheduler, KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler, KDPM2DiscreteScheduler,
LCMScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
OnnxRuntimeModel, OnnxRuntimeModel,
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,
@@ -58,6 +59,7 @@ available_pipelines = {
# "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline, # "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline, "lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
"panorama": OnnxStableDiffusionPanoramaPipeline, "panorama": OnnxStableDiffusionPanoramaPipeline,
"panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline,
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline, "pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
"txt2img-sdxl": ORTStableDiffusionXLPipeline, "txt2img-sdxl": ORTStableDiffusionXLPipeline,
"txt2img": OnnxStableDiffusionPipeline, "txt2img": OnnxStableDiffusionPipeline,
@@ -77,12 +79,25 @@ pipeline_schedulers = {
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler, "k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
"k-dpm-2": KDPM2DiscreteScheduler, "k-dpm-2": KDPM2DiscreteScheduler,
"karras-ve": KarrasVeScheduler, "karras-ve": KarrasVeScheduler,
"lcm": LCMScheduler,
"lms-discrete": LMSDiscreteScheduler, "lms-discrete": LMSDiscreteScheduler,
"pndm": PNDMScheduler, "pndm": PNDMScheduler,
"unipc-multi": UniPCMultistepScheduler, "unipc-multi": UniPCMultistepScheduler,
} }
def add_pipeline(name: str, pipeline: Any) -> bool:
global available_pipelines
if name in available_pipelines:
# TODO: decide if this should be allowed or not
logger.warning("cannot replace existing pipeline: %s", name)
return False
else:
available_pipelines[name] = pipeline
return True
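# Usage sketch: plugins can register additional pipeline classes by name without
# editing the built-in table (MyCustomPipeline is a placeholder):
class MyCustomPipeline(OnnxStableDiffusionPipeline):
    pass

if not add_pipeline("txt2img-custom", MyCustomPipeline):
    logger.warning("pipeline name is already registered")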
def get_available_pipelines() -> List[str]: def get_available_pipelines() -> List[str]:
return list(available_pipelines.keys()) return list(available_pipelines.keys())
@@ -99,16 +114,19 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
return None return None
VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]
def load_pipeline( def load_pipeline(
server: ServerContext, server: ServerContext,
params: ImageParams, params: ImageParams,
pipeline: str, pipeline: str,
device: DeviceParams, device: DeviceParams,
inversions: Optional[List[Tuple[str, float]]] = None, embeddings: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None, loras: Optional[List[Tuple[str, float]]] = None,
model: Optional[str] = None, model: Optional[str] = None,
): ):
inversions = inversions or [] embeddings = embeddings or []
loras = loras or [] loras = loras or []
model = model or params.model model = model or params.model
@@ -122,7 +140,7 @@ def load_pipeline(
device.device, device.device,
device.provider, device.provider,
control_key, control_key,
inversions, embeddings,
loras, loras,
) )
scheduler_key = (params.scheduler, model) scheduler_key = (params.scheduler, model)
@@ -159,211 +177,376 @@ def load_pipeline(
run_gc([device]) run_gc([device])
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler_type.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="scheduler",
torch_dtype=torch_dtype,
)
components = { components = {
"scheduler": scheduler_type.from_pretrained( "scheduler": scheduler,
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="scheduler",
torch_dtype=torch_dtype,
)
} }
# shared components # shared components
text_encoder = None
unet_type = "unet" unet_type = "unet"
# ControlNet component # ControlNet component
if params.is_control() and params.control is not None: if params.is_control() and params.control is not None:
cnet_path = path.join( logger.debug("loading ControlNet components")
server.model_path, "control", f"{params.control.name}.onnx" control_components = load_controlnet(server, device, params)
) components.update(control_components)
logger.debug("loading ControlNet weights from %s", cnet_path)
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
cnet_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
)
unet_type = "cnet" unet_type = "cnet"
# Textual Inversion blending # load various pipeline components
if inversions is not None and len(inversions) > 0: encoder_components = load_text_encoders(
logger.debug("blending Textual Inversions from %s", inversions) server, device, model, embeddings, loras, torch_dtype, params
inversion_names, inversion_weights = zip(*inversions) )
components.update(encoder_components)
inversion_models = [ unet_components = load_unet(server, device, model, loras, unet_type, params)
path.join(server.model_path, "inversion", name) components.update(unet_components)
for name in inversion_names
] vae_components = load_vae(server, device, model, params)
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL)) components.update(vae_components)
tokenizer = CLIPTokenizer.from_pretrained(
model, pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
subfolder="tokenizer",
torch_dtype=torch_dtype, if params.is_xl():
logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class(
components["vae_decoder_session"],
components["text_encoder_session"],
components["unet_session"],
{
"force_zeros_for_empty_prompt": True,
"requires_aesthetics_score": False,
},
components["tokenizer"],
scheduler,
vae_encoder_session=components.get("vae_encoder_session", None),
text_encoder_2_session=components.get("text_encoder_2_session", None),
tokenizer_2=components.get("tokenizer_2", None),
) )
text_encoder, tokenizer = blend_textual_inversions( else:
if "vae" in components:
# upscale uses a single VAE
logger.debug(
"assembling SD pipeline for %s with single VAE",
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae"],
components["text_encoder"],
components["tokenizer"],
components["unet"],
scheduler,
scheduler,
)
else:
logger.debug(
"assembling SD pipeline for %s with VAE codec",
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae_encoder"],
components["vae_decoder"],
components["text_encoder"],
components["tokenizer"],
components["unet"],
scheduler,
None,
None,
requires_safety_checker=False,
)
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)
optimize_pipeline(server, pipe)
patch_pipeline(server, pipe, pipeline_class, params)
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
for vae in VAE_COMPONENTS:
if hasattr(pipe, vae):
vae_model = getattr(pipe, vae)
if isinstance(vae_model, VAEWrapper):
vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(
params.vae_tile // LATENT_FACTOR, params.vae_overlap
)
# update panorama params
if params.is_panorama():
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
logger.debug(
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
params.unet_tile,
unet_stride,
params.vae_tile,
params.vae_overlap,
)
pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)
run_gc([device])
return pipe
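# Worked numbers for the panorama window setup above, assuming LATENT_FACTOR == 8
# (the previous code divided by a literal 8): with unet_tile = 768 px and
# unet_overlap = 0.25, the UNet runs on 96-latent-pixel windows with a stride of
# 72 latent pixels between them.
example_tile, example_overlap, example_factor = 768, 0.25, 8
example_window = example_tile // example_factor                             # 96
example_stride = (example_tile * (1 - example_overlap)) // example_factor   # 72.0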
def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
logger.debug("loading ControlNet weights from %s", cnet_path)
components = {}
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
cnet_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
)
return components
def load_text_encoders(
server: ServerContext,
device: DeviceParams,
model: str,
embeddings: Optional[List[Tuple[str, float]]],
loras: Optional[List[Tuple[str, float]]],
torch_dtype,
params: ImageParams,
):
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
tokenizer = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer",
torch_dtype=torch_dtype,
)
components = {
"tokenizer": tokenizer,
}
if params.is_xl():
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
tokenizer_2 = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer_2",
torch_dtype=torch_dtype,
)
components["tokenizer_2"] = tokenizer_2
# blend embeddings, if any
if embeddings is not None and len(embeddings) > 0:
embedding_names, embedding_weights = zip(*embeddings)
embedding_models = [
path.join(server.model_path, "inversion", name) for name in embedding_names
]
logger.debug(
"blending base model %s with embeddings from %s", model, embedding_models
)
# TODO: blend text_encoder_2 as well
text_encoder, tokenizer = blend_textual_inversions(
server,
text_encoder,
tokenizer,
list(
zip(
embedding_models,
embedding_weights,
embedding_names,
[None] * len(embedding_models),
)
),
)
components["tokenizer"] = tokenizer
if params.is_xl():
text_encoder_2, tokenizer_2 = blend_textual_inversions(
server, server,
text_encoder, text_encoder_2,
tokenizer, tokenizer_2,
list( list(
zip( zip(
inversion_models, embedding_models,
inversion_weights, embedding_weights,
inversion_names, embedding_names,
[None] * len(inversion_models), [None] * len(embedding_models),
) )
), ),
) )
components["tokenizer_2"] = tokenizer_2
components["tokenizer"] = tokenizer # blend LoRAs, if any
if loras is not None and len(loras) > 0:
lora_names, lora_weights = zip(*loras)
lora_models = [
path.join(server.model_path, "lora", name) for name in lora_names
]
logger.info("blending base model %s with LoRAs from %s", model, lora_models)
# should be pretty small and should not need external data # blend and load text encoder
if loras is None or len(loras) == 0: text_encoder = blend_loras(
# TODO: handle XL encoders server,
components["text_encoder"] = OnnxRuntimeModel( text_encoder,
OnnxRuntimeModel.load_model( list(zip(lora_models, lora_weights)),
text_encoder.SerializeToString(), "text_encoder",
provider=device.ort_provider("text-encoder"), 1 if params.is_xl() else None,
sess_options=device.sess_options(), params.is_xl(),
) )
)
# LoRA blending if params.is_xl():
if loras is not None and len(loras) > 0: text_encoder_2 = blend_loras(
lora_names, lora_weights = zip(*loras)
lora_models = [
path.join(server.model_path, "lora", name) for name in lora_names
]
logger.info(
"blending base model %s with LoRA models: %s", model, lora_models
)
# blend and load text encoder
text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
text_encoder = blend_loras(
server, server,
text_encoder, text_encoder_2,
list(zip(lora_models, lora_weights)), list(zip(lora_models, lora_weights)),
"text_encoder", "text_encoder",
1 if params.is_xl() else None, 2,
params.is_xl(), params.is_xl(),
) )
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
text_encoder # prepare external data for sessions
(text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = device.sess_options(cache=False)
text_encoder_opts.add_external_initializers(
list(text_encoder_names), list(text_encoder_values)
)
if params.is_xl():
# encoder 2 only exists in XL
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
text_encoder_2
)
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
text_encoder_2_opts = device.sess_options(cache=False)
text_encoder_2_opts.add_external_initializers(
list(text_encoder_2_names), list(text_encoder_2_values)
)
# session for te1
text_encoder_session = InferenceSession(
text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts,
)
text_encoder_session._model_path = path.join(model, "text_encoder")
components["text_encoder_session"] = text_encoder_session
# session for te2
text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_2_opts,
)
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2_session"] = text_encoder_2_session
else:
# session for te
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"),
sess_options=text_encoder_opts,
) )
text_encoder_names, text_encoder_values = zip(*text_encoder_data) )
text_encoder_opts = device.sess_options(cache=False)
text_encoder_opts.add_external_initializers( return components
list(text_encoder_names), list(text_encoder_values)
def load_unet(
server: ServerContext,
device: DeviceParams,
model: str,
loras: List[Tuple[str, float]],
unet_type: Literal["cnet", "unet"],
params: ImageParams,
):
components = {}
unet = load_model(path.join(model, unet_type, ONNX_MODEL))
# LoRA blending
if loras is not None and len(loras) > 0:
lora_names, lora_weights = zip(*loras)
lora_models = [
path.join(server.model_path, "lora", name) for name in lora_names
]
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
# blend and load unet
unet = blend_loras(
server,
unet,
list(zip(lora_models, lora_weights)),
"unet",
xl=params.is_xl(),
)
(unet_model, unet_data) = buffer_external_data_tensors(unet)
unet_names, unet_values = zip(*unet_data)
unet_opts = device.sess_options(cache=False)
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
if params.is_xl():
unet_session = InferenceSession(
unet_model.SerializeToString(),
providers=[device.ort_provider("unet")],
sess_options=unet_opts,
)
unet_session._model_path = path.join(model, "unet")
components["unet_session"] = unet_session
else:
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
unet_model.SerializeToString(),
provider=device.ort_provider("unet"),
sess_options=unet_opts,
) )
)
if params.is_xl(): return components
text_encoder_session = InferenceSession(
text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts,
)
text_encoder_session._model_path = path.join(model, "text_encoder")
components["text_encoder_session"] = text_encoder_session
else:
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"),
sess_options=text_encoder_opts,
)
)
if params.is_xl():
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder_2 = blend_loras(
server,
text_encoder_2,
list(zip(lora_models, lora_weights)),
"text_encoder",
2,
params.is_xl(),
)
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
text_encoder_2
)
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
text_encoder_2_opts = device.sess_options(cache=False)
text_encoder_2_opts.add_external_initializers(
list(text_encoder_2_names), list(text_encoder_2_values)
)
text_encoder_2_session = InferenceSession( def load_vae(
text_encoder_2.SerializeToString(), _server: ServerContext, device: DeviceParams, model: str, params: ImageParams
providers=[device.ort_provider("text-encoder")], ):
sess_options=text_encoder_2_opts, # one or more VAE models need to be loaded
) vae = path.join(model, "vae", ONNX_MODEL)
text_encoder_2_session._model_path = path.join(model, "text_encoder_2") vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
components["text_encoder_2_session"] = text_encoder_2_session vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
# blend and load unet components = {}
unet = path.join(model, unet_type, ONNX_MODEL) if not params.is_xl() and path.exists(vae):
blended_unet = blend_loras( logger.debug("loading VAE from %s", vae)
server, components["vae"] = OnnxRuntimeModel(
unet, OnnxRuntimeModel.load_model(
list(zip(lora_models, lora_weights)), vae,
"unet", provider=device.ort_provider("vae"),
xl=params.is_xl(), sess_options=device.sess_options(),
) )
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet) )
unet_names, unet_values = zip(*unet_data) elif path.exists(vae_decoder) and path.exists(vae_encoder):
unet_opts = device.sess_options(cache=False) if params.is_xl():
unet_opts.add_external_initializers(list(unet_names), list(unet_values)) logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
if params.is_xl(): vae_decoder,
unet_session = InferenceSession( provider=device.ort_provider("vae"),
unet_model.SerializeToString(), sess_options=device.sess_options(),
providers=[device.ort_provider("unet")],
sess_options=unet_opts,
)
unet_session._model_path = path.join(model, "unet")
components["unet_session"] = unet_session
else:
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
unet_model.SerializeToString(),
provider=device.ort_provider("unet"),
sess_options=unet_opts,
)
)
# make sure a UNet has been loaded
if not params.is_xl() and "unet" not in components:
unet = path.join(model, unet_type, ONNX_MODEL)
logger.debug("loading UNet (%s) from %s", unet_type, unet)
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
unet,
provider=device.ort_provider("unet"),
sess_options=device.sess_options(),
)
) )
components["vae_decoder_session"]._model_path = vae_decoder
# one or more VAE models need to be loaded logger.debug("loading VAE encoder from %s", vae_encoder)
vae = path.join(model, "vae", ONNX_MODEL) components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) vae_encoder,
vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL) provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
if not params.is_xl() and path.exists(vae):
logger.debug("loading VAE from %s", vae)
components["vae"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
vae,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
) )
elif ( components["vae_encoder_session"]._model_path = vae_encoder
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
): else:
logger.debug("loading VAE decoder from %s", vae_decoder) logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder"] = OnnxRuntimeModel( components["vae_decoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model( OnnxRuntimeModel.load_model(
@@ -382,119 +565,44 @@ def load_pipeline(
) )
) )
# additional options for panorama pipeline return components
if params.is_panorama():
components["window"] = params.tiles // 8
components["stride"] = params.stride // 8
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
safety_checker=None,
torch_dtype=torch_dtype,
**components,
)
# make sure XL models are actually being used
# TODO: why is this needed?
if "text_encoder_session" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder.session == components["text_encoder_session"],
type(pipe.text_encoder),
)
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
if "text_encoder_2_session" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2.session == components["text_encoder_2_session"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = ORTModelTextEncoder(
text_encoder_2_session, text_encoder_2
)
if "unet_session" in components:
logger.info(
"unet matches: %s, %s",
pipe.unet.session == components["unet_session"],
type(pipe.unet),
)
pipe.unet = ORTModelUnet(unet_session, unet_model)
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)
optimize_pipeline(server, pipe)
if not params.is_xl():
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
if not params.is_xl() and hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
if not params.is_xl() and hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
# update panorama params
if params.is_panorama():
latent_window = params.tiles // 8
latent_stride = params.stride // 8
pipe.set_window_size(latent_window, latent_stride)
if hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_window_size(latent_window, params.overlap)
if hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_window_size(latent_window, params.overlap)
run_gc([device])
return pipe
def optimize_pipeline( def optimize_pipeline(
server: ServerContext, server: ServerContext,
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
) -> None: ) -> None:
if ( if server.has_optimization(
"diffusers-attention-slicing" in server.optimizations "diffusers-attention-slicing"
or "diffusers-attention-slicing-auto" in server.optimizations ) or server.has_optimization("diffusers-attention-slicing-auto"):
):
logger.debug("enabling auto attention slicing on SD pipeline") logger.debug("enabling auto attention slicing on SD pipeline")
try: try:
pipe.enable_attention_slicing(slice_size="auto") pipe.enable_attention_slicing(slice_size="auto")
except Exception as e: except Exception as e:
logger.warning("error while enabling auto attention slicing: %s", e) logger.warning("error while enabling auto attention slicing: %s", e)
if "diffusers-attention-slicing-max" in server.optimizations: if server.has_optimization("diffusers-attention-slicing-max"):
logger.debug("enabling max attention slicing on SD pipeline") logger.debug("enabling max attention slicing on SD pipeline")
try: try:
pipe.enable_attention_slicing(slice_size="max") pipe.enable_attention_slicing(slice_size="max")
except Exception as e: except Exception as e:
logger.warning("error while enabling max attention slicing: %s", e) logger.warning("error while enabling max attention slicing: %s", e)
if "diffusers-vae-slicing" in server.optimizations: if server.has_optimization("diffusers-vae-slicing"):
logger.debug("enabling VAE slicing on SD pipeline") logger.debug("enabling VAE slicing on SD pipeline")
try: try:
pipe.enable_vae_slicing() pipe.enable_vae_slicing()
except Exception as e: except Exception as e:
logger.warning("error while enabling VAE slicing: %s", e) logger.warning("error while enabling VAE slicing: %s", e)
if "diffusers-cpu-offload-sequential" in server.optimizations: if server.has_optimization("diffusers-cpu-offload-sequential"):
logger.debug("enabling sequential CPU offload on SD pipeline") logger.debug("enabling sequential CPU offload on SD pipeline")
try: try:
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()
except Exception as e: except Exception as e:
logger.warning("error while enabling sequential CPU offload: %s", e) logger.warning("error while enabling sequential CPU offload: %s", e)
elif "diffusers-cpu-offload-model" in server.optimizations: elif server.has_optimization("diffusers-cpu-offload-model"):
# TODO: check for accelerate # TODO: check for accelerate
logger.debug("enabling model CPU offload on SD pipeline") logger.debug("enabling model CPU offload on SD pipeline")
try: try:
@@ -502,7 +610,7 @@ def optimize_pipeline(
except Exception as e: except Exception as e:
logger.warning("error while enabling model CPU offload: %s", e) logger.warning("error while enabling model CPU offload: %s", e)
if "diffusers-memory-efficient-attention" in server.optimizations: if server.has_optimization("diffusers-memory-efficient-attention"):
# TODO: check for xformers # TODO: check for xformers
logger.debug("enabling memory efficient attention for SD pipeline") logger.debug("enabling memory efficient attention for SD pipeline")
try: try:
@@ -514,17 +622,17 @@
def patch_pipeline( def patch_pipeline(
server: ServerContext, server: ServerContext,
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
pipe_type: str,
pipeline: Any, pipeline: Any,
params: ImageParams, params: ImageParams,
) -> None: ) -> None:
logger.debug("patching SD pipeline") logger.debug("patching SD pipeline")
if pipe_type != "lpw": if not params.is_lpw() and not params.is_xl():
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
original_unet = pipe.unet original_unet = pipe.unet
pipe.unet = UNetWrapper(server, original_unet) pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
logger.debug("patched UNet with wrapper")
if hasattr(pipe, "vae_decoder"): if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder original_decoder = pipe.vae_decoder
@@ -532,18 +640,21 @@ def patch_pipeline(
server, server,
original_decoder, original_decoder,
decoder=True, decoder=True,
window=params.tiles, window=params.unet_tile,
overlap=params.overlap, overlap=params.vae_overlap,
) )
logger.debug("patched VAE decoder with wrapper")
if hasattr(pipe, "vae_encoder"):
original_encoder = pipe.vae_encoder original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper( pipe.vae_encoder = VAEWrapper(
server, server,
original_encoder, original_encoder,
decoder=False, decoder=False,
window=params.tiles, window=params.unet_tile,
overlap=params.overlap, overlap=params.vae_overlap,
) )
elif hasattr(pipe, "vae"): logger.debug("patched VAE encoder with wrapper")
pass # TODO: current wrapper does not work with upscaling VAE
else: if hasattr(pipe, "vae"):
logger.debug("no VAE found to patch") logger.warning("not patching single VAE, tiled VAE may not work")


@@ -14,20 +14,23 @@ class UNetWrapper(object):
prompt_index: int = 0 prompt_index: int = 0
server: ServerContext server: ServerContext
wrapped: OnnxRuntimeModel wrapped: OnnxRuntimeModel
xl: bool
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
wrapped: OnnxRuntimeModel, wrapped: OnnxRuntimeModel,
xl: bool,
): ):
self.server = server self.server = server
self.wrapped = wrapped self.wrapped = wrapped
self.xl = xl
def __call__( def __call__(
self, self,
sample: np.ndarray = None, sample: Optional[np.ndarray] = None,
timestep: np.ndarray = None, timestep: Optional[np.ndarray] = None,
encoder_hidden_states: np.ndarray = None, encoder_hidden_states: Optional[np.ndarray] = None,
**kwargs, **kwargs,
): ):
logger.trace( logger.trace(
@@ -43,13 +46,21 @@ class UNetWrapper(object):
        encoder_hidden_states = self.prompt_embeds[step_index]
        self.prompt_index += 1
-       if sample.dtype != timestep.dtype:
-           logger.trace("converting UNet sample to timestep dtype")
-           sample = sample.astype(timestep.dtype)
+       if self.xl:
+           if sample.dtype != encoder_hidden_states.dtype:
+               logger.trace(
+                   "converting UNet sample to hidden state dtype for XL: %s",
+                   encoder_hidden_states.dtype,
+               )
+               sample = sample.astype(encoder_hidden_states.dtype)
+       else:
+           if sample.dtype != timestep.dtype:
+               logger.trace("converting UNet sample to timestep dtype")
+               sample = sample.astype(timestep.dtype)
        if encoder_hidden_states.dtype != timestep.dtype:
            logger.trace("converting UNet hidden states to timestep dtype")
            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
        return self.wrapped(
            sample=sample,
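The branch added here reflects how the exported UNets differ: the SDXL ONNX export keys the sample dtype off the encoder hidden states (often float16), while the SD v1.5 export keys it off the timestep. A small standalone sketch of the same cast, with illustrative shapes:

import numpy as np

def align_sample_dtype(sample, timestep, hidden_states, xl):
    # XL keys off the hidden-state dtype, SD v1.5 off the timestep dtype
    target = hidden_states.dtype if xl else timestep.dtype
    return sample if sample.dtype == target else sample.astype(target)

sample = np.zeros((2, 4, 64, 64), dtype=np.float32)
timestep = np.array([999], dtype=np.float32)
hidden = np.zeros((2, 77, 2048), dtype=np.float16)
print(align_sample_dtype(sample, timestep, hidden, xl=True).dtype)  # float16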


@ -12,8 +12,6 @@ from ...server import ServerContext
logger = getLogger(__name__)
-LATENT_CHANNELS = 4
class VAEWrapper(object):
    def __init__(
@ -39,11 +37,17 @@ class VAEWrapper(object):
        self.tile_overlap_factor = overlap
    def __call__(self, latent_sample=None, sample=None, **kwargs):
+       model = (
+           self.wrapped.model
+           if hasattr(self.wrapped, "model")
+           else self.wrapped.session
+       )
        # set timestep dtype to input type
        sample_dtype = next(
            (
                input.type
-               for input in self.wrapped.model.get_inputs()
+               for input in model.get_inputs()
                if input.name == "sample" or input.name == "latent_sample"
            ),
            "tensor(float)",


@ -13,8 +13,8 @@ import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
-from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging


@ -13,25 +13,36 @@
# limitations under the License.
import inspect
-from typing import Callable, List, Optional, Union
+from math import ceil
+from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
-from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer
+from ...chain.tile import make_tile_mask
+from ...constants import LATENT_CHANNELS, LATENT_FACTOR
+from ...params import Size
+from ..utils import (
+    expand_latents,
+    parse_regions,
+    random_seed,
+    repair_nan,
+    resize_latent_shape,
+)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# inpaint constants
NUM_UNET_INPUT_CHANNELS = 9
-NUM_LATENT_CHANNELS = 4
DEFAULT_WINDOW = 32
DEFAULT_STRIDE = 8
@ -346,13 +357,16 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
def get_views(self, panorama_height, panorama_width, window_size, stride): def get_views(
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8 panorama_height /= 8
panorama_width /= 8 panorama_width /= 8
num_blocks_height = abs((panorama_height - window_size) // stride) + 1 num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
num_blocks_width = abs((panorama_width - window_size) // stride) + 1 num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
total_num_blocks = int(num_blocks_height * num_blocks_width) total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug( logger.debug(
"panorama generated %s views, %s by %s blocks", "panorama generated %s views, %s by %s blocks",
@ -369,7 +383,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            w_end = w_start + window_size
            views.append((h_start, h_end, w_start, w_end))
-       return views
+       return (views, (h_end * 8, w_end * 8))
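Worked example of the view math, assuming the defaults above: for a 512x1024 request with window 32 and stride 8 (both in latent units), the latent canvas is 64x128, giving ceil(|64-32|/8)+1 = 5 rows and ceil(|128-32|/8)+1 = 13 columns, so 65 overlapping views; the last view ends at latent (64, 128), so the returned resize matches the request exactly.

from math import ceil

def count_views(height, width, window=32, stride=8):
    # same arithmetic as get_views above, in latent units (1/8 of pixels)
    h, w = height / 8, width / 8
    rows = ceil(abs((h - window) / stride)) + 1
    cols = ceil(abs((w - window) / stride)) + 1
    return rows, cols, rows * cols

print(count_views(512, 1024))  # (5, 13, 65)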
    @torch.no_grad()
    def text2img(
@ -479,6 +493,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
+       prompt, regions = parse_regions(prompt)
        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
@ -488,9 +504,30 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            negative_prompt_embeds=negative_prompt_embeds,
        )
# 3.b. Encode region prompts
region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
region_prompt_embeds = self._encode_prompt(
region_prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
region_embeds.append(region_prompt_embeds)
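Region prompts that end with "+" are concatenated with the base prompt before encoding, so a region can refine the global description instead of replacing it. An illustration of that merge only (parse_regions itself lives in ..utils and is not shown here):

# Illustration of the "+" suffix handling above; prompts are examples only.
base_prompt = "a wide mountain valley at sunrise"
region_prompt = "a small wooden cabin+"

if region_prompt.endswith("+"):
    region_prompt = region_prompt[:-1] + " " + base_prompt

print(region_prompt)  # "a small wooden cabin a wide mountain valley at sunrise"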
        # get the initial random noise unless the user supplied it
        latents_dtype = prompt_embeds.dtype
-       latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
+       latents_shape = (
+           batch_size * num_images_per_prompt,
+           LATENT_CHANNELS,
+           height // LATENT_FACTOR,
+           width // LATENT_FACTOR,
+       )
        if latents is None:
            latents = generator.randn(*latents_shape).astype(latents_dtype)
        elif latents.shape != latents_shape:
@ -525,11 +562,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        # panorama additions
-       views = self.get_views(height, width, self.window, self.stride)
-       count = np.zeros_like(latents)
-       value = np.zeros_like(latents)
+       views, resize = self.get_views(height, width, self.window, self.stride)
+       logger.trace("panorama resized latents to %s", resize)
+       count = np.zeros(resize_latent_shape(latents, resize))
+       value = np.zeros(resize_latent_shape(latents, resize))
+       # adjust latents
+       latents = expand_latents(
+           latents,
+           random_seed(generator),
+           Size(resize[1], resize[0]),
+           sigma=self.scheduler.init_noise_sigma,
+       )
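expand_latents is imported from ..utils; from its use here it should pad the initial noise out to the enlarged view grid (resize is (height, width), so the Size argument flips the order) and scale the new noise by the scheduler's init_noise_sigma. A rough sketch of that assumed behaviour, not the actual helper:

import numpy as np

def expand_latents_sketch(latents, seed, height, width, sigma=1.0):
    # assumed behaviour: keep the existing noise, pad the rest with fresh
    # scaled gaussian noise so every panorama view has valid latents
    batch, channels, old_h, old_w = latents.shape
    rng = np.random.RandomState(seed)
    expanded = rng.randn(batch, channels, height // 8, width // 8) * sigma
    expanded = expanded.astype(latents.dtype)
    expanded[:, :, :old_h, :old_w] = latents
    return expanded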
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+           last = i == (len(self.scheduler.timesteps) - 1)
            count.fill(0)
            value.fill(0)
@ -576,13 +624,115 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
if not last:
for r, region in enumerate(regions):
top, left, bottom, right, weight, feather, prompt = region
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = (
region_noise_pred_uncond
+ guidance_scale
* (region_noise_pred_text - region_noise_pred_uncond)
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 100.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = np.where(count > 0, value / count, value)
+           latents = repair_nan(latents)
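Each view writes its denoised latents into value and bumps count, so overlapping pixels are averaged per Eq. 5 of MultiDiffusion; repair_nan then backfills any NaNs that can appear with half-precision UNets. A tiny numeric illustration of the averaging:

import numpy as np

# two overlapping views contribute to the same latent pixels; the
# MultiDiffusion step averages them wherever count > 0
value = np.zeros((1, 1, 1, 4))
count = np.zeros_like(value)

value[..., 0:3] += np.array([1.0, 2.0, 3.0])  # view A
count[..., 0:3] += 1
value[..., 1:4] += np.array([4.0, 5.0, 6.0])  # view B
count[..., 1:4] += 1

latents = np.where(count > 0, value / count, value)
print(latents.ravel())  # [1. 3. 4. 6.]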
            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
+       # remove extra margins
+       latents = latents[
+           :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
+       ]
+       latents = np.clip(latents, -4, +4)
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result for using half-precision vae decoder if batchsize>1
@ -828,9 +978,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        # panorama additions
-       views = self.get_views(height, width, self.window, self.stride)
-       count = np.zeros_like(latents)
-       value = np.zeros_like(latents)
+       views, resize = self.get_views(height, width, self.window, self.stride)
+       logger.trace("panorama resized latents to %s", resize)
+       count = np.zeros(resize_latent_shape(latents, resize))
+       value = np.zeros(resize_latent_shape(latents, resize))
+       # adjust latents
+       latents = expand_latents(
+           latents,
+           random_seed(generator),
+           Size(resize[1], resize[0]),
+           sigma=self.scheduler.init_noise_sigma,
+       )
        for i, t in enumerate(self.progress_bar(timesteps)):
            count.fill(0)
@ -886,6 +1046,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
+       # remove extra margins
+       latents = latents[
+           :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
+       ]
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result for using half-precision vae decoder if batchsize>1
@ -1053,12 +1218,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            negative_prompt_embeds=negative_prompt_embeds,
        )
-       num_channels_latents = NUM_LATENT_CHANNELS
+       num_channels_latents = LATENT_CHANNELS
        latents_shape = (
            batch_size * num_images_per_prompt,
            num_channels_latents,
-           height // 8,
-           width // 8,
+           height // LATENT_FACTOR,
+           width // LATENT_FACTOR,
        )
        latents_dtype = prompt_embeds.dtype
        if latents is None:
@ -1136,9 +1301,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        # panorama additions
-       views = self.get_views(height, width, self.window, self.stride)
-       count = np.zeros_like(latents)
-       value = np.zeros_like(latents)
+       views, resize = self.get_views(height, width, self.window, self.stride)
+       logger.trace("panorama resized latents to %s", resize)
+       count = np.zeros(resize_latent_shape(latents, resize))
+       value = np.zeros(resize_latent_shape(latents, resize))
+       # adjust latents
+       latents = expand_latents(
+           latents,
+           random_seed(generator),
+           Size(resize[1], resize[0]),
+           sigma=self.scheduler.init_noise_sigma,
+       )
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            count.fill(0)
@ -1201,6 +1376,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
+       # remove extra margins
+       latents = latents[
+           :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
+       ]
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result for using half-precision vae decoder if batchsize>1


@ -0,0 +1,955 @@
import inspect
import logging
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
StableDiffusionXLImg2ImgPipelineMixin,
)
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from ...chain.tile import make_tile_mask
from ...constants import LATENT_FACTOR
from ...params import Size
from ..utils import (
expand_latents,
parse_regions,
random_seed,
repair_nan,
resize_latent_shape,
)
logger = logging.getLogger(__name__)
DEFAULT_WINDOW = 64
DEFAULT_STRIDE = 16
class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
def __init__(
self,
*args,
window: int = DEFAULT_WINDOW,
stride: int = DEFAULT_STRIDE,
**kwargs,
):
super().__init__(self, *args, **kwargs)
self.window = window
self.stride = stride
def set_window_size(self, window: int, stride: int):
self.window = window
self.stride = stride
def get_views(
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug(
"panorama generated %s views, %s by %s blocks",
total_num_blocks,
num_blocks_height,
num_blocks_width,
)
views = []
for i in range(total_num_blocks):
h_start = int((i // num_blocks_width) * stride)
h_end = h_start + window_size
w_start = int((i % num_blocks_width) * stride)
w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end))
return (views, (h_end * 8, w_end * 8))
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_img2img(
self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None
):
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae_encoder(sample=image)[
0
] * self.vae_decoder.config.get("scaling_factor", 0.18215)
if (
batch_size > init_latents.shape[0]
and batch_size % init_latents.shape[0] == 0
):
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = np.concatenate(
[init_latents] * additional_image_per_prompt, axis=0
)
elif (
batch_size > init_latents.shape[0]
and batch_size % init_latents.shape[0] != 0
):
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = np.concatenate([init_latents], axis=0)
# add noise to latents using the timesteps
noise = generator.randn(*init_latents.shape).astype(dtype)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents),
torch.from_numpy(noise),
torch.from_numpy(timestep),
)
return init_latents.numpy()
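The img2img path encodes the source image and then noises it at the timestep selected by strength. With the usual diffusers get_timesteps logic (assumed here, since get_timesteps itself comes from the optimum mixin), strength 0.3 on a 50-step schedule keeps roughly the last 15 timesteps, so the image is only partially re-noised:

# Rough sketch of the usual strength-to-steps mapping feeding
# prepare_latents_img2img above; the real logic lives in get_timesteps.
def steps_for_strength(num_inference_steps: int, strength: float) -> int:
    # strength decides how far back in the schedule denoising starts
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start  # steps that will actually run

print(steps_for_strength(50, 0.3))  # 15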
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_text2img(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = generator.randn(*shape).astype(dtype)
elif latents.shape != shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * np.float64(self.scheduler.init_noise_sigma)
return latents
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
extra_step_kwargs = {}
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_eta:
extra_step_kwargs["eta"] = eta
return extra_step_kwargs
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
def text2img(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`Optional[int]`, defaults to None):
The height in pixels of the generated image.
width (`Optional[int]`, defaults to None):
The width in pixels of the generated image.
num_inference_steps (`int`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to 5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`Optional[Union[str, list]]`):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, defaults to 1):
The number of images to generate per prompt.
eta (`float`, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`Optional[np.random.RandomState]`, defaults to `None`):
A np.random.RandomState to make generation deterministic.
latents (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
plain tuple.
callback (Optional[Callable], defaults to `None`):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
guidance_rescale (`float`, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf), where `guidance_rescale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
1.0,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if generator is None:
generator = np.random
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
prompt, regions = parse_regions(prompt)
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self._encode_prompt(
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
# 3.b. Encode region prompts
region_embeds: List[np.ndarray] = []
add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
(
region_prompt_embeds,
region_negative_prompt_embeds,
region_pooled_prompt_embeds,
region_negative_pooled_prompt_embeds,
) = self._encode_prompt(
region_prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
if do_classifier_free_guidance:
region_prompt_embeds = np.concatenate(
(region_negative_prompt_embeds, region_prompt_embeds), axis=0
)
add_region_embeds.append(
np.concatenate(
(
region_negative_pooled_prompt_embeds,
region_pooled_prompt_embeds,
),
axis=0,
)
)
region_embeds.append(region_prompt_embeds)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
latents = self.prepare_latents_text2img(
batch_size * num_images_per_prompt,
self.unet.config.get("in_channels", 4),
height,
width,
prompt_embeds.dtype,
generator,
latents,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = (original_size + crops_coords_top_left + target_size,)
add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)
if do_classifier_free_guidance:
prompt_embeds = np.concatenate(
(negative_prompt_embeds, prompt_embeds), axis=0
)
add_text_embeds = np.concatenate(
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
)
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
add_time_ids = np.repeat(
add_time_ids, batch_size * num_images_per_prompt, axis=0
)
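SDXL conditions on the original size, crop offset, and target size as one flat vector of six values; the code above builds that vector once, repeats it per image, and doubles it when classifier-free guidance is on. A small illustration of the construction, with example sizes:

import numpy as np

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)

# same construction as above: one row of six micro-conditioning values
add_time_ids = np.array(
    [original_size + crops_coords_top_left + target_size], dtype=np.float32
)
print(add_time_ids)        # [[1024. 1024.    0.    0. 1024. 1024.]]
print(add_time_ids.shape)  # (1, 6)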
# Adapted from diffusers to extend it for other runtimes than ORT
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
# 8. Panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
last = i == (len(timesteps) - 1)
count.fill(0)
value.fill(0)
for h_start, h_end, w_start, w_end in views:
# get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents_for_view] * 2)
if do_classifier_free_guidance
else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_embeds=add_text_embeds,
time_ids=add_time_ids,
)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents_for_view),
**extra_step_kwargs,
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
if not last:
for r, region in enumerate(regions):
top, left, bottom, right, weight, feather, prompt = region
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[r],
time_ids=add_time_ids,
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = (
region_noise_pred_uncond
+ guidance_scale
* (region_noise_pred_text - region_noise_pred_uncond)
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
region_noise_pred = rescale_noise_cfg(
region_noise_pred,
region_noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 100.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)
latents = repair_nan(latents)
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
if output_type == "latent":
image = latents
else:
latents = np.clip(latents, -4, +4)
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems like there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = self.watermark.apply_watermark(image)
# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
def img2img(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.3,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`Union[np.ndarray, PIL.Image.Image]`):
`Image`, or tensor representing an image batch which will be upscaled.
strength (`float`, defaults to 0.3):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
num_inference_steps (`int`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to 5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`Optional[Union[str, list]]`):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, defaults to 1):
The number of images to generate per prompt.
eta (`float`, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`Optional[np.random.RandomState]`, defaults to `None`):
A np.random.RandomState to make generation deterministic.
latents (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
plain tuple.
callback (Optional[Callable], defaults to `None`):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
guidance_rescale (`float`, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf), where `guidance_rescale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 1. Define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if generator is None:
generator = np.random
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 2. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self._encode_prompt(
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
# 3. Preprocess image
processor = VaeImageProcessor()
image = processor.preprocess(image).cpu().numpy()
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps, strength
)
latent_timestep = np.repeat(
timesteps[:1], batch_size * num_images_per_prompt, axis=0
)
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype)
# 5. Prepare latent variables
latents = self.prepare_latents_img2img(
image,
latent_timestep,
batch_size,
num_images_per_prompt,
latents_dtype,
generator,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = {}
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_eta:
extra_step_kwargs["eta"] = eta
height, width = latents.shape[-2:]
height = height * self.vae_scale_factor
width = width * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 8. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
if do_classifier_free_guidance:
prompt_embeds = np.concatenate(
(negative_prompt_embeds, prompt_embeds), axis=0
)
add_text_embeds = np.concatenate(
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
)
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
add_time_ids = np.repeat(
add_time_ids, batch_size * num_images_per_prompt, axis=0
)
# 8. Panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
count.fill(0)
value.fill(0)
for h_start, h_end, w_start, w_end in views:
# get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents_for_view] * 2)
if do_classifier_free_guidance
else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_embeds=add_text_embeds,
time_ids=add_time_ids,
)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents_for_view),
**extra_step_kwargs,
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
if output_type == "latent":
image = latents
else:
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems like there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = self.watermark.apply_watermark(image)
# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
def __call__(
self,
*args,
**kwargs,
):
if "image" in kwargs or (
len(args) > 1
and (
isinstance(args[1], np.ndarray) or isinstance(args[1], PIL.Image.Image)
)
):
logger.debug("running img2img panorama XL pipeline")
return self.img2img(*args, **kwargs)
else:
logger.debug("running txt2img panorama XL pipeline")
return self.text2img(*args, **kwargs)
class ORTStableDiffusionXLPanoramaPipeline(
ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin
):
def __call__(self, *args, **kwargs):
return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
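A hypothetical usage sketch of the combined class; in onnx-web the pipeline is normally constructed through the server's pipeline loader rather than directly, and the model directory name below is an assumption about a local ONNX export:

import numpy as np

pipe = ORTStableDiffusionXLPanoramaPipeline.from_pretrained(
    "./models/stable-diffusion-xl-base-1.0-onnx"  # assumed local SDXL ONNX export
)
pipe.set_window_size(window=64, stride=16)

result = pipe(
    prompt="a sweeping coastal panorama at golden hour",
    height=768,
    width=1792,
    num_inference_steps=30,
    guidance_scale=5.0,
    generator=np.random.RandomState(42),
)
result.images[0].save("panorama.png")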


@ -32,7 +32,7 @@ except ImportError:
}
from diffusers import OnnxRuntimeModel
-from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,


@ -1,63 +1,25 @@
###
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
# See also: https://github.com/huggingface/diffusers/pull/2158
###
from logging import getLogger
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List
-import numpy as np
-import PIL
-import torch
-from diffusers.pipeline_utils import ImagePipelineOutput
-from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
+from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
+from diffusers.pipelines.stable_diffusion import (
+    OnnxStableDiffusionUpscalePipeline as BasePipeline,
+)
from diffusers.schedulers import DDPMScheduler
logger = getLogger(__name__)
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7
ORT_TO_PT_TYPE = {
"float16": torch.float16,
"float32": torch.float32,
}
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class FakeConfig:
+   block_out_channels: List[int]
    scaling_factor: float
    def __init__(self) -> None:
+       self.block_out_channels = [128, 256, 512]
        self.scaling_factor = 0.08333
-class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
+class OnnxStableDiffusionUpscalePipeline(BasePipeline):
    def __init__(
        self,
        vae: OnnxRuntimeModel,
@ -80,260 +42,3 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
            scheduler,
            max_noise_level=max_noise_level,
        )
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
# device, device only needed for Torch pipelines
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
# 4. Preprocess image
image = preprocess(image)
image = image.cpu()
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(
image.shape, generator=generator, device=device, dtype=latents_dtype
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
noise_level = np.concatenate([noise_level] * image.shape[0])
# 6. Prepare latent variables
height, width = image.shape[2:]
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
NUM_LATENT_CHANNELS,
height,
width,
latents_dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
raise ValueError(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
f" `num_channels_image`: {num_channels_image} "
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
timestep_dtype = next(
(
input.type
for input in self.unet.model.get_inputs()
if input.name == "timestep"
),
"tensor(float)",
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
# timestep to tensor
timestep = np.array([t], dtype=timestep_dtype)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
class_labels=noise_level.astype(np.int64),
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents.float())
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = self.vae(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
return image
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
# no positional arguments to text_encoder
text_embeddings = self.text_encoder(
input_ids=text_input_ids.int().to(device),
)
text_embeddings = text_embeddings[0]
bs_embed, seq_len, _ = text_embeddings.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
text_embeddings = text_embeddings.reshape(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.int().to(device),
)
uncond_embeddings = uncond_embeddings[0]
seq_len = uncond_embeddings.shape[1]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.reshape(
batch_size * num_images_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
return text_embeddings
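The classifier-free guidance step inside the denoising loop above is easy to miss among the ONNX plumbing. As a rough, self-contained sketch (not part of the pipeline; the arrays and scale below are illustrative stand-ins): the batched prediction is split into unconditional and text-conditioned halves, and the result is the unconditional half pushed toward the text half by `guidance_scale`.

```python
import numpy as np

# stand-ins for the UNet output and guidance scale used in the loop above
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)  # [uncond, text] batch
guidance_scale = 7.5

noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# a scale of 1.0 reduces to the plain text-conditioned prediction
assert np.allclose(
    noise_pred_uncond + 1.0 * (noise_pred_text - noise_pred_uncond), noise_pred_text
)
```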

View File

@@ -4,15 +4,16 @@ from typing import Any, List, Optional
 from PIL import Image, ImageOps

-from onnx_web.chain.highres import stage_highres
 from ..chain import (
+    BlendDenoiseStage,
     BlendImg2ImgStage,
     BlendMaskStage,
     ChainPipeline,
     SourceTxt2ImgStage,
     UpscaleOutpaintStage,
 )
+from ..chain.highres import stage_highres
+from ..chain.result import StageResult
 from ..chain.upscale import split_upscale, stage_upscale_correction
 from ..image import expand_image
 from ..output import save_image
@@ -33,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt

 logger = getLogger(__name__)


+def get_base_tile(params: ImageParams, size: Size) -> int:
+    if params.is_panorama():
+        tile = max(params.unet_tile, size.width, size.height)
+        logger.debug("adjusting tile size for panorama to %s", tile)
+        return tile
+
+    return params.unet_tile
+
+
+def get_highres_tile(
+    server: ServerContext, params: ImageParams, highres: HighresParams, tile: int
+) -> int:
+    if params.is_panorama() and server.has_feature("panorama-highres"):
+        return tile * highres.scale
+
+    return params.unet_tile
+
+
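A quick sketch of how these helpers behave, using plain values instead of the real `ImageParams`/`ServerContext` objects (illustrative only; the real inputs come from the request parser):

```python
# standalone mimic of get_base_tile / get_highres_tile above
def base_tile(unet_tile: int, is_panorama: bool, width: int, height: int) -> int:
    if is_panorama:
        return max(unet_tile, width, height)
    return unet_tile

def highres_tile(tile: int, scale: int, is_panorama: bool, panorama_highres: bool, unet_tile: int) -> int:
    if is_panorama and panorama_highres:
        return tile * scale
    return unet_tile

# a 1024x512 panorama with a 512px UNet tile runs as a single 1024px view, and
# highres at scale=2 would tile at 2048px when the panorama-highres flag is on
assert base_tile(512, True, 1024, 512) == 1024
assert highres_tile(1024, 2, True, True, 512) == 2048
```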
 def run_txt2img_pipeline(
     worker: WorkerContext,
     server: ServerContext,
@@ -43,10 +62,7 @@ def run_txt2img_pipeline(
     highres: HighresParams,
 ) -> None:
     # if using panorama, the pipeline will tile itself (views)
-    if params.is_panorama() or params.is_xl():
-        tile_size = max(params.tiles, size.width, size.height)
-    else:
-        tile_size = params.tiles
+    tile_size = get_base_tile(params, size)
     # prepare the chain pipeline and first stage
     chain = ChainPipeline()
@@ -57,15 +73,21 @@ def run_txt2img_pipeline(
         ),
         size=size,
         prompt_index=0,
-        overlap=params.overlap,
+        overlap=params.vae_overlap,
     )

     # apply upscaling and correction, before highres
-    stage = StageParams(tile_size=params.tiles)
+    highres_size = get_highres_tile(server, params, highres, tile_size)
+    if params.is_panorama():
+        chain.stage(
+            BlendDenoiseStage(),
+            StageParams(tile_size=highres_size),
+        )
+
     first_upscale, after_upscale = split_upscale(upscale)
     if first_upscale:
         stage_upscale_correction(
-            stage,
+            StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
             params,
             chain=chain,
             upscale=first_upscale,
@@ -73,7 +95,7 @@ def run_txt2img_pipeline(
     # apply highres
     stage_highres(
-        stage,
+        StageParams(outscale=highres.scale, tile_size=highres_size),
         params,
         highres,
         upscale,
@@ -83,7 +105,7 @@ def run_txt2img_pipeline(
     # apply upscaling and correction, after highres
     stage_upscale_correction(
-        stage,
+        StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
         params,
         chain=chain,
         upscale=after_upscale,
@@ -92,11 +114,14 @@ def run_txt2img_pipeline(
     # run and save
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = worker.get_progress_callback()
-    images = chain.run(worker, server, params, [], callback=progress, latents=latents)
+    images = chain.run(
+        worker, server, params, StageResult.empty(), callback=progress, latents=latents
+    )

     _pairs, loras, inversions, _rest = parse_prompt(params)
     for image, output in zip(images, outputs):
+        logger.trace("saving output image %s: %s", output, image.size)
         dest = save_image(
             server,
             output,
@@ -136,23 +161,26 @@ def run_img2img_pipeline(
             source = f(server, source)

     # prepare the chain pipeline and first stage
+    tile_size = get_base_tile(params, Size(*source.size))
     chain = ChainPipeline()
-    stage = StageParams(
-        tile_size=params.tiles,
-    )
     chain.stage(
         BlendImg2ImgStage(),
-        stage,
+        StageParams(
+            tile_size=tile_size,
+        ),
         prompt_index=0,
         strength=strength,
-        overlap=params.overlap,
+        overlap=params.vae_overlap,
     )

     # apply upscaling and correction, before highres
     first_upscale, after_upscale = split_upscale(upscale)
     if first_upscale:
         stage_upscale_correction(
-            stage,
+            StageParams(
+                outscale=first_upscale.outscale,
+                tile_size=tile_size,
+            ),
             params,
             upscale=first_upscale,
             chain=chain,
@@ -162,13 +190,16 @@ def run_img2img_pipeline(
     for _i in range(params.loopback):
         chain.stage(
             BlendImg2ImgStage(),
-            stage,
+            StageParams(
+                tile_size=tile_size,
+            ),
             strength=strength,
         )

     # highres, if selected
+    highres_size = get_highres_tile(server, params, highres, tile_size)
     stage_highres(
-        stage,
+        StageParams(tile_size=highres_size, outscale=highres.scale),
         params,
         highres,
         upscale,
@@ -178,7 +209,7 @@ def run_img2img_pipeline(
     # apply upscaling and correction, after highres
     stage_upscale_correction(
-        stage,
+        StageParams(tile_size=tile_size, outscale=after_upscale.scale),
         params,
         upscale=after_upscale,
         chain=chain,
@@ -186,7 +217,9 @@ def run_img2img_pipeline(
     # run and append the filtered source
     progress = worker.get_progress_callback()
-    images = chain(worker, server, params, [source], callback=progress)
+    images = chain.run(
+        worker, server, params, StageResult(images=[source]), callback=progress
+    )

     if source_filter is not None and source_filter != "none":
         images.append(source)
@@ -235,7 +268,7 @@ def run_inpaint_pipeline(
     full_res_inpaint_padding: float,
 ) -> None:
     logger.debug("building inpaint pipeline")
-    tile_size = params.tiles
+    tile_size = get_base_tile(params, size)

     if mask is None:
         # if no mask was provided, keep the full source image
@@ -264,8 +297,12 @@ def run_inpaint_pipeline(
     logger.debug("border zero: %s", border.isZero())
     full_res_inpaint = full_res_inpaint and border.isZero()
     if full_res_inpaint:
-        mask_left, mask_top, mask_right, mask_bottom = mask.getbbox()
-        logger.debug("mask bbox: %s", mask.getbbox())
+        bbox = mask.getbbox()
+        if bbox is None:
+            bbox = (0, 0, source.width, source.height)
+
+        logger.debug("mask bounding box: %s", bbox)
+        mask_left, mask_top, mask_right, mask_bottom = bbox
         mask_width = mask_right - mask_left
         mask_height = mask_bottom - mask_top
         # ensure we have some padding around the mask when we do the inpaint (and that the region size is even)
@ -322,16 +359,15 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_order=tile_order, tile_size=tile_size)
chain.stage( chain.stage(
UpscaleOutpaintStage(), UpscaleOutpaintStage(),
stage, StageParams(tile_order=tile_order, tile_size=tile_size),
border=border, border=border,
mask=mask, mask=mask,
fill_color=fill_color, fill_color=fill_color,
mask_filter=mask_filter, mask_filter=mask_filter,
noise_source=noise_source, noise_source=noise_source,
overlap=params.overlap, overlap=params.vae_overlap,
prompt_index=0, prompt_index=0,
) )
@ -339,15 +375,16 @@ def run_inpaint_pipeline(
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
params, params,
upscale=first_upscale, upscale=first_upscale,
chain=chain, chain=chain,
) )
# apply highres # apply highres
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres( stage_highres(
stage, StageParams(outscale=highres.scale, tile_size=highres_size),
params, params,
highres, highres,
upscale, upscale,
@ -357,7 +394,7 @@ def run_inpaint_pipeline(
# apply upscaling and correction # apply upscaling and correction
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=after_upscale.outscale),
params, params,
upscale=after_upscale, upscale=after_upscale,
chain=chain, chain=chain,
@ -366,7 +403,14 @@ def run_inpaint_pipeline(
# run and save # run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress, latents=latents) images = chain.run(
worker,
server,
params,
StageResult(images=[source]),
callback=progress,
latents=latents,
)
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
@ -409,21 +453,22 @@ def run_upscale_pipeline(
) -> None: ) -> None:
# set up the chain pipeline, no base stage for upscaling # set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_size=params.tiles) tile_size = get_base_tile(params, size)
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
params, params,
upscale=first_upscale, upscale=first_upscale,
chain=chain, chain=chain,
) )
# apply highres # apply highres
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres( stage_highres(
stage, StageParams(outscale=highres.scale, tile_size=highres_size),
params, params,
highres, highres,
upscale, upscale,
@ -433,7 +478,7 @@ def run_upscale_pipeline(
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=after_upscale.outscale, tile_size=tile_size),
params, params,
upscale=after_upscale, upscale=after_upscale,
chain=chain, chain=chain,
@ -441,7 +486,9 @@ def run_upscale_pipeline(
# run and save # run and save
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress) images = chain.run(
worker, server, params, StageResult(images=[source]), callback=progress
)
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
@ -478,12 +525,18 @@ def run_blend_pipeline(
) -> None: ) -> None:
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() tile_size = get_base_tile(params, size)
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
chain.stage(
BlendMaskStage(),
StageParams(tile_size=tile_size),
stage_source=sources[1],
stage_mask=mask,
)
# apply upscaling and correction # apply upscaling and correction
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=upscale.outscale),
params, params,
upscale=upscale, upscale=upscale,
chain=chain, chain=chain,
@ -491,7 +544,9 @@ def run_blend_pipeline(
# run and save # run and save
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, sources, callback=progress) images = chain.run(
worker, server, params, StageResult(images=sources), callback=progress
)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale) dest = save_image(server, output, image, params, size, upscale=upscale)
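Across all of these pipelines, `chain.run` now takes a `StageResult` container instead of a bare list of images. A minimal sketch of the two entry points used above (`StageResult.empty()` for txt2img, a source-seeded result for img2img and friends), using a toy stand-in rather than the real `onnx_web.chain.result.StageResult`, which carries additional metadata:

```python
from typing import List, Optional

from PIL import Image

# toy stand-in for StageResult, illustrative only
class StageResultSketch:
    def __init__(self, images: Optional[List[Image.Image]] = None) -> None:
        self.images = images or []

    @staticmethod
    def empty() -> "StageResultSketch":
        return StageResultSketch(images=[])

source = Image.new("RGB", (64, 64))
print(len(StageResultSketch.empty().images))            # 0, txt2img starts empty
print(len(StageResultSketch(images=[source]).images))   # 1, img2img seeds the chain
```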

View File

@@ -3,23 +3,27 @@ from copy import deepcopy
 from logging import getLogger
 from math import ceil
 from re import Pattern, compile
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import torch
 from diffusers import OnnxStableDiffusionPipeline

+from ..constants import LATENT_CHANNELS, LATENT_FACTOR
 from ..params import ImageParams, Size

 logger = getLogger(__name__)

-LATENT_CHANNELS = 4
-LATENT_FACTOR = 8
 MAX_TOKENS_PER_GROUP = 77

+ANY_TOKEN = compile(r"\<([^\>]*)\>")
 CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
 INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
 LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
+REGION_TOKEN = compile(
+    r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
+)
+RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
 WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")

 INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
@@ -84,8 +88,8 @@ def expand_prompt(
     negative_prompt: Optional[str] = None,
     prompt_embeds: Optional[np.ndarray] = None,
     negative_prompt_embeds: Optional[np.ndarray] = None,
-    skip_clip_states: Optional[int] = 0,
-) -> "np.NDArray":
+    skip_clip_states: int = 0,
+) -> np.ndarray:
     # self provides:
     #   tokenizer: CLIPTokenizer
     #   encoder: OnnxRuntimeModel
@@ -140,6 +144,7 @@ def expand_prompt(
     last_state, _pooled_output, *hidden_states = text_result

     if skip_clip_states > 0:
+        # TODO: why is this normalized?
         layer_norm = torch.nn.LayerNorm(last_state.shape[2])
         norm_state = layer_norm(
             torch.from_numpy(
@@ -219,20 +224,25 @@ def expand_prompt(
     return prompt_embeds


+def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
+    name, weight = group
+    return (name, float(weight))
+
+
 def get_tokens_from_prompt(
-    prompt: str, pattern: Pattern
+    prompt: str,
+    pattern: Pattern,
+    parser=parse_float_group,
 ) -> Tuple[str, List[Tuple[str, float]]]:
+    """
+    TODO: replace with Arpeggio
+    """
     remaining_prompt = prompt

     tokens = []
     next_match = pattern.search(remaining_prompt)
     while next_match is not None:
         logger.debug("found token in prompt: %s", next_match)
-        name, weight = next_match.groups()
-        tokens.append((name, float(weight)))
+        group = next_match.groups()
+        tokens.append(parser(group))

         # remove this match and look for another
         remaining_prompt = (
             remaining_prompt[: next_match.start()]
@@ -251,6 +261,13 @@ def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]
     return get_tokens_from_prompt(prompt, INVERSION_TOKEN)


+def random_seed(generator=None) -> int:
+    if generator is None:
+        generator = np.random
+
+    return generator.randint(np.iinfo(np.int32).max)
+
+
 def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
     """
     From https://www.travelneil.com/stable-diffusion-updates.html.
@@ -266,6 +283,25 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
     return image_latents
+def expand_latents(
+    latents: np.ndarray,
+    seed: int,
+    size: Size,
+    sigma: float = 1.0,
+) -> np.ndarray:
+    batch, _channels, height, width = latents.shape
+    extra_latents = get_latents_from_seed(seed, size, batch=batch)
+    extra_latents[:, :, 0:height, 0:width] = latents
+    return extra_latents * np.float64(sigma)
+
+
+def resize_latent_shape(
+    latents: np.ndarray,
+    size: Tuple[int, int],
+) -> Tuple[int, int, int, int]:
+    return (latents.shape[0], latents.shape[1], *size)
+
+
 def get_tile_latents(
     full_latents: np.ndarray,
     seed: int,
@@ -290,14 +326,8 @@ def get_tile_latents(
     tile_latents = full_latents[:, :, y:yt, x:xt]

-    if tile_latents.shape != full_latents.shape and (
-        tile_latents.shape[2] < t or tile_latents.shape[3] < t
-    ):
-        extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
-        extra_latents[
-            :, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
-        ] = tile_latents
-        tile_latents = extra_latents
+    if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
+        tile_latents = expand_latents(tile_latents, seed, size)

     return tile_latents
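To make the padding behaviour concrete: when a tile slice near the image edge is smaller than the requested tile, `expand_latents` pads it with fresh seeded noise and keeps the original values in the top-left corner. A rough standalone illustration, using plain numpy in place of the seeded helper above:

```python
import numpy as np

tile = 32  # requested latent tile size
partial = np.ones((1, 4, 32, 20), dtype=np.float32)  # edge tile, only 20 wide

rng = np.random.default_rng(42)  # stand-in for get_latents_from_seed
padded = rng.standard_normal((1, 4, tile, tile)).astype(np.float32)
padded[:, :, : partial.shape[2], : partial.shape[3]] = partial

assert padded.shape == (1, 4, 32, 32)
assert np.all(padded[:, :, :, :20] == 1.0)  # original latents preserved
```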
@@ -369,12 +399,15 @@ def encode_prompt(
     num_images_per_prompt: int = 1,
     do_classifier_free_guidance: bool = True,
 ) -> List[np.ndarray]:
+    """
+    TODO: does not work with SDXL, fix or turn into a pipeline patch
+    """
     return [
         pipe._encode_prompt(
-            prompt,
+            remove_tokens(prompt),
             num_images_per_prompt=num_images_per_prompt,
             do_classifier_free_guidance=do_classifier_free_guidance,
-            negative_prompt=neg_prompt,
+            negative_prompt=remove_tokens(neg_prompt),
         )
         for prompt, neg_prompt in prompt_pairs
     ]
@@ -444,3 +477,71 @@ def slice_prompt(prompt: str, slice: int) -> str:
         return parts[min(slice, len(parts) - 1)]
     else:
         return prompt
+
+
+Region = Tuple[
+    int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
+]
+
+
+def parse_region_group(group: Tuple[str, ...]) -> Region:
+    top, left, bottom, right, weight, feather, prompt = group
+
+    # break down the feather section
+    feather_radius, *feather_edges = feather.split("_")
+    if len(feather_edges) == 0:
+        feather_edges = "TLBR"
+    else:
+        feather_edges = "".join(feather_edges)
+
+    return (
+        int(top),
+        int(left),
+        int(bottom),
+        int(right),
+        float(weight),
+        (
+            float(feather_radius),
+            (
+                "T" in feather_edges,
+                "L" in feather_edges,
+                "B" in feather_edges,
+                "R" in feather_edges,
+            ),
+        ),
+        prompt,
+    )
+
+
+def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
+    return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
+
+
+Reseed = Tuple[int, int, int, int, int]
+
+
+def parse_reseed_group(group) -> Region:
+    top, left, bottom, right, seed = group
+    return (
+        int(top),
+        int(left),
+        int(bottom),
+        int(right),
+        int(seed),
+    )
+
+
+def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
+    return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
+
+
+def skip_group(group) -> Any:
+    return group
+
+
+def remove_tokens(prompt: Optional[str]) -> Optional[str]:
+    if prompt is None:
+        return prompt
+
+    remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN, parser=skip_group)
+    return remainder
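The new region and reseed tokens follow the same `<name:args>` shape as the existing LoRA and inversion tokens. As a rough illustration of what `REGION_TOKEN` matches (the example prompt below is made up; the authoritative syntax is the regex itself):

```python
import re

REGION_TOKEN = re.compile(
    r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
)

prompt = "a wide beach <region:0:0:512:512:0.75:10.0_TL:a lighthouse at sunset>"
match = REGION_TOKEN.search(prompt)
if match is not None:
    top, left, bottom, right, weight, feather, region_prompt = match.groups()
    print(top, left, bottom, right, weight, feather, region_prompt)
    # 0 0 512 512 0.75 10.0_TL a lighthouse at sunset
```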

View File

@@ -12,6 +12,11 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler

+try:
+    from diffusers import LCMScheduler
+except ImportError:
+    from ..diffusers.stub_scheduler import StubScheduler as LCMScheduler
+
 try:
     from diffusers import UniPCMultistepScheduler
 except ImportError:

View File

@@ -8,7 +8,7 @@ def mask_filter_none(
 ) -> Image.Image:
     width, height = dims

-    noise = Image.new("RGB", (width, height), fill)
+    noise = Image.new(mask.mode, (width, height), fill)
     noise.paste(mask, origin)

     return noise

View File

@@ -17,21 +17,21 @@ def noise_source_fill_edge(
     """
     width, height = dims

-    noise = Image.new("RGB", (width, height), fill)
+    noise = Image.new(source.mode, (width, height), fill)
     noise.paste(source, origin)

     return noise


 def noise_source_fill_mask(
-    _source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
+    source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
 ) -> Image.Image:
     """
     Fill the whole canvas, no source or noise.
     """
     width, height = dims

-    noise = Image.new("RGB", (width, height), fill)
+    noise = Image.new(source.mode, (width, height), fill)

     return noise
@@ -52,7 +52,7 @@ def noise_source_gaussian(
 def noise_source_uniform(
-    _source: Image.Image, dims: Point, _origin: Point, **kw
+    source: Image.Image, dims: Point, _origin: Point, **kw
 ) -> Image.Image:
     width, height = dims
     size = width * height
@@ -61,6 +61,7 @@ def noise_source_uniform(
     noise_g = random.uniform(0, 256, size=size)
     noise_b = random.uniform(0, 256, size=size)

+    # needs to be RGB for pixel manipulation
     noise = Image.new("RGB", (width, height))

     for x in range(width):
@@ -68,11 +69,11 @@ def noise_source_uniform(
             i = get_pixel_index(x, y, width)
             noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))

-    return noise
+    return noise.convert(source.mode)


 def noise_source_normal(
-    _source: Image.Image, dims: Point, _origin: Point, **kw
+    source: Image.Image, dims: Point, _origin: Point, **kw
 ) -> Image.Image:
     width, height = dims
     size = width * height
@@ -81,6 +82,7 @@ def noise_source_normal(
     noise_g = random.normal(128, 32, size=size)
     noise_b = random.normal(128, 32, size=size)

+    # needs to be RGB for pixel manipulation
     noise = Image.new("RGB", (width, height))

     for x in range(width):
@@ -88,13 +90,13 @@ def noise_source_normal(
             i = get_pixel_index(x, y, width)
             noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))

-    return noise
+    return noise.convert(source.mode)


 def noise_source_histogram(
     source: Image.Image, dims: Point, _origin: Point, **kw
 ) -> Image.Image:
-    r, g, b = source.split()
+    r, g, b, *_a = source.split()
     width, height = dims
     size = width * height
@@ -112,6 +114,7 @@ def noise_source_histogram(
         256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
     )

+    # needs to be RGB for pixel manipulation
     noise = Image.new("RGB", (width, height))

     for x in range(width):
@@ -119,4 +122,4 @@ def noise_source_histogram(
             i = get_pixel_index(x, y, width)
             noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))

-    return noise
+    return noise.convert(source.mode)
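Since the noise sources now convert back to the source image's mode, they can be applied to RGBA sources without a mode mismatch when blending or pasting. A small standalone sketch of the same pattern (build noise as RGB for pixel access, then convert), using plain numpy rather than the helpers above:

```python
import numpy as np
from PIL import Image, ImageChops

source = Image.new("RGBA", (64, 64), (32, 64, 96, 255))

# pixel manipulation happens in RGB, then the result is converted back to the
# source mode so blend/paste operations line up
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
noise = Image.fromarray(pixels, mode="RGB").convert(source.mode)

blended = ImageChops.blend(source, noise, 0.5)
assert noise.mode == blended.mode == source.mode == "RGBA"
```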

View File

@@ -47,7 +47,7 @@ def source_filter_noise(
     source: Image.Image,
     strength: float = 0.5,
 ):
-    noise = noise_source_histogram(source, source.size)
+    noise = noise_source_histogram(source, source.size, (0, 0))

     return ImageChops.blend(source, noise, strength)

View File

@@ -1,3 +1,5 @@
+from typing import Tuple
+
 from PIL import Image, ImageChops

 from ..params import Border, Size
@@ -13,12 +15,12 @@ def expand_image(
     fill="white",
     noise_source=noise_source_histogram,
     mask_filter=mask_filter_none,
-):
+) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
     size = Size(*source.size).add_border(expand)
     size = tuple(size)
     origin = (expand.left, expand.top)

-    full_source = Image.new("RGB", size, fill)
+    full_source = Image.new(source.mode, size, fill)
     full_source.paste(source, origin)

     # new mask pixels need to be filled with white so they will be replaced
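For outpainting, `expand_image` grows the canvas by a `Border` and keeps the original image at the offset origin. A rough standalone sketch of just the geometry (the real helper also builds the noise and mask layers), assuming an even 64px border on every side:

```python
from PIL import Image

source = Image.new("RGB", (512, 512), "gray")
left = top = right = bottom = 64  # roughly Border.even(64) in onnx-web terms

new_size = (source.width + left + right, source.height + top + bottom)
origin = (left, top)

full_source = Image.new(source.mode, new_size, "white")
full_source.paste(source, origin)

assert full_source.size == (640, 640)
```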

View File

@@ -23,6 +23,7 @@ from .server.load import (
     load_platforms,
     load_wildcards,
 )
+from .server.plugin import load_plugins, register_plugins
 from .server.static import register_static_routes
 from .server.utils import check_paths
 from .utils import is_debug
@@ -43,15 +44,32 @@ def main():
     server = ServerContext.from_environ()
     apply_patches(server)
     check_paths(server)

+    # debug options
+    if server.debug:
+        import debugpy
+
+        debugpy.listen(5678)
+        logger.warning("waiting for debugger")
+        debugpy.wait_for_client()
+        gc.set_debug(gc.DEBUG_STATS)
+
+    # register plugins
+    exports = load_plugins(server)
+    success = register_plugins(exports)
+    if success:
+        logger.info("all plugins loaded successfully")
+    else:
+        logger.warning("error loading plugins")
+
+    # load additional resources
     load_extras(server)
     load_models(server)
     load_params(server)
     load_platforms(server)
     load_wildcards(server)

-    if is_debug():
-        gc.set_debug(gc.DEBUG_STATS)
-
+    # misc server options
     if not server.show_progress:
         disable_progress_bar()
         disable_progress_bars()
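The new debug block only runs when the server context has `debug` enabled, which is read from `ONNX_WEB_DEBUG` later in this commit. A minimal sketch of turning it on, assuming `debugpy` is installed and the default port 5678 from the code above:

```python
from os import environ

# enable the debugpy listener and gc stats before launching the server;
# the exact truthy values accepted are defined by onnx_web's get_boolean helper
environ["ONNX_WEB_DEBUG"] = "TRUE"

# then launch the API however you normally do and attach a debugger to
# localhost:5678 once the "waiting for debugger" message appears
```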

View File

@@ -1,18 +1,21 @@
-from typing import Literal
+from typing import List, Literal

-NetworkType = Literal["inversion", "lora"]
+NetworkType = Literal["control", "inversion", "lora"]


 class NetworkModel:
     name: str
+    tokens: List[str]
     type: NetworkType

-    def __init__(self, name: str, type: NetworkType) -> None:
+    def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
         self.name = name
+        self.tokens = tokens or []
         self.type = type

     def tojson(self):
         return {
             "name": self.name,
+            "tokens": self.tokens,
             "type": self.type,
         }
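With the new `tokens` field, the client can surface trigger words for each network. A quick sketch of the serialized shape, with the class re-declared so the snippet runs on its own and a made-up LoRA name:

```python
from typing import List, Literal

NetworkType = Literal["control", "inversion", "lora"]

# same shape as the class above, re-declared for a self-contained example
class NetworkModel:
    def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
        self.name = name
        self.tokens = tokens or []
        self.type = type

    def tojson(self):
        return {"name": self.name, "tokens": self.tokens, "type": self.type}

print(NetworkModel("detail-tweaker", "lora", tokens=["add-detail"]).tojson())
# {'name': 'detail-tweaker', 'tokens': ['add-detail'], 'type': 'lora'}
```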

View File

@@ -57,7 +57,7 @@ def json_params(
     upscale: Optional[UpscaleParams] = None,
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
-    parent: Dict = None,
+    parent: Optional[Dict] = None,
 ) -> Any:
     json = {
         "input_size": size.tojson(),
@@ -158,6 +158,7 @@ def make_output_name(
     size: Size,
     extras: Optional[List[Optional[Param]]] = None,
     count: Optional[int] = None,
+    offset: int = 0,
 ) -> List[str]:
     count = count or params.batch
     now = int(time())
@@ -183,7 +184,7 @@ def make_output_name(
     return [
         f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
-        for i in range(count)
+        for i in range(offset, count + offset)
     ]

View File

@@ -14,7 +14,7 @@ Point = Tuple[int, int]


 class SizeChart(IntEnum):
-    unlimited = 0
+    micro = 64
     mini = 128  # small tile for very expensive models
     half = 256  # half tile for outpainting
     auto = 512  # auto tile size
@@ -25,6 +25,7 @@ class SizeChart(IntEnum):
     hd16k = 2**14
     hd32k = 2**15
     hd64k = 2**16
+    max = 2**32  # should be a reasonable upper limit for now


 class TileOrder:
@@ -140,7 +141,7 @@ class DeviceParams:
         if self.options is None:
             return self.provider
         else:
-            return self.provider  # (self.provider, self.options)
+            return (self.provider, self.options)
def sess_options(self, cache=True) -> SessionOptions: def sess_options(self, cache=True) -> SessionOptions:
if cache and self.sess_options_cache is not None: if cache and self.sess_options_cache is not None:
@ -201,11 +202,14 @@ class ImageParams:
batch: int batch: int
control: Optional[NetworkModel] control: Optional[NetworkModel]
input_prompt: str input_prompt: str
input_negative_prompt: str input_negative_prompt: Optional[str]
loopback: int loopback: int
tiled_vae: bool tiled_vae: bool
tiles: int unet_tile: int
overlap: float unet_overlap: float
vae_tile: int
vae_overlap: float
denoise: int
def __init__( def __init__(
self, self,
@ -224,9 +228,11 @@ class ImageParams:
input_negative_prompt: Optional[str] = None, input_negative_prompt: Optional[str] = None,
loopback: int = 0, loopback: int = 0,
tiled_vae: bool = False, tiled_vae: bool = False,
tiles: int = 512, unet_overlap: float = 0.25,
overlap: float = 0.25, unet_tile: int = 512,
stride: int = 64, vae_overlap: float = 0.25,
vae_tile: int = 512,
denoise: int = 3,
) -> None: ) -> None:
self.model = model self.model = model
self.pipeline = pipeline self.pipeline = pipeline
@ -243,14 +249,16 @@ class ImageParams:
self.input_negative_prompt = input_negative_prompt or negative_prompt self.input_negative_prompt = input_negative_prompt or negative_prompt
self.loopback = loopback self.loopback = loopback
self.tiled_vae = tiled_vae self.tiled_vae = tiled_vae
self.tiles = tiles self.unet_overlap = unet_overlap
self.overlap = overlap self.unet_tile = unet_tile
self.stride = stride self.vae_overlap = vae_overlap
self.vae_tile = vae_tile
self.denoise = denoise
def do_cfg(self): def do_cfg(self):
return self.cfg > 1.0 return self.cfg > 1.0
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str: def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
pipeline = pipeline or self.pipeline pipeline = pipeline or self.pipeline
# if the correct pipeline was already requested, simply use that # if the correct pipeline was already requested, simply use that
@ -259,7 +267,14 @@ class ImageParams:
# otherwise, check for additional allowed pipelines # otherwise, check for additional allowed pipelines
if group == "img2img": if group == "img2img":
if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "pix2pix"]: if pipeline in [
"controlnet",
"img2img-sdxl",
"lpw",
"panorama",
"panorama-sdxl",
"pix2pix",
]:
return pipeline return pipeline
elif pipeline == "txt2img-sdxl": elif pipeline == "txt2img-sdxl":
return "img2img-sdxl" return "img2img-sdxl"
@ -267,7 +282,7 @@ class ImageParams:
if pipeline in ["controlnet", "lpw", "panorama"]: if pipeline in ["controlnet", "lpw", "panorama"]:
return pipeline return pipeline
elif group == "txt2img": elif group == "txt2img":
if pipeline in ["lpw", "panorama", "txt2img-sdxl"]: if pipeline in ["lpw", "panorama", "panorama-sdxl", "txt2img-sdxl"]:
return pipeline return pipeline
logger.debug("pipeline %s is not valid for %s", pipeline, group) logger.debug("pipeline %s is not valid for %s", pipeline, group)
@ -280,7 +295,7 @@ class ImageParams:
return self.pipeline == "lpw" return self.pipeline == "lpw"
def is_panorama(self): def is_panorama(self):
return self.pipeline == "panorama" return self.pipeline in ["panorama", "panorama-sdxl"]
def is_pix2pix(self): def is_pix2pix(self):
return self.pipeline == "pix2pix" return self.pipeline == "pix2pix"
@ -305,9 +320,11 @@ class ImageParams:
"input_negative_prompt": self.input_negative_prompt, "input_negative_prompt": self.input_negative_prompt,
"loopback": self.loopback, "loopback": self.loopback,
"tiled_vae": self.tiled_vae, "tiled_vae": self.tiled_vae,
"tiles": self.tiles, "unet_overlap": self.unet_overlap,
"overlap": self.overlap, "unet_tile": self.unet_tile,
"stride": self.stride, "vae_overlap": self.vae_overlap,
"vae_tile": self.vae_tile,
"denoise": self.denoise,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
@ -327,9 +344,11 @@ class ImageParams:
kwargs.get("input_negative_prompt", self.input_negative_prompt), kwargs.get("input_negative_prompt", self.input_negative_prompt),
kwargs.get("loopback", self.loopback), kwargs.get("loopback", self.loopback),
kwargs.get("tiled_vae", self.tiled_vae), kwargs.get("tiled_vae", self.tiled_vae),
kwargs.get("tiles", self.tiles), kwargs.get("unet_overlap", self.unet_overlap),
kwargs.get("overlap", self.overlap), kwargs.get("unet_tile", self.unet_tile),
kwargs.get("stride", self.stride), kwargs.get("vae_overlap", self.vae_overlap),
kwargs.get("vae_tile", self.vae_tile),
kwargs.get("denoise", self.denoise),
) )
@@ -351,6 +370,17 @@ class StageParams:
         self.tile_order = tile_order
         self.tile_size = tile_size

+    def with_args(
+        self,
+        **kwargs,
+    ):
+        return StageParams(
+            name=kwargs.get("name", self.name),
+            outscale=kwargs.get("outscale", self.outscale),
+            tile_order=kwargs.get("tile_order", self.tile_order),
+            tile_size=kwargs.get("tile_size", self.tile_size),
+        )
+

 class UpscaleParams:
     def __init__(
@@ -459,10 +489,14 @@ class HighresParams:
         self.method = method
         self.iterations = iterations

+    def outscale(self) -> int:
+        return self.scale**self.iterations
+
     def resize(self, size: Size) -> Size:
+        outscale = self.outscale()
         return Size(
-            size.width * (self.scale**self.iterations),
-            size.height * (self.scale**self.iterations),
+            size.width * outscale,
+            size.height * outscale,
         )

     def tojson(self):
View File

@ -1,12 +1,14 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Any, Dict
from flask import Flask, jsonify, make_response, request, url_for from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate from jsonschema import validate
from PIL import Image from PIL import Image
from ..chain import CHAIN_STAGES, ChainPipeline from ..chain import CHAIN_STAGES, ChainPipeline
from ..chain.result import StageResult
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.run import ( from ..diffusers.run import (
run_blend_pipeline, run_blend_pipeline,
@ -17,7 +19,7 @@ from ..diffusers.run import (
) )
from ..diffusers.utils import replace_wildcards from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name from ..output import json_params, make_output_name
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline from ..transformers.run import run_txt2txt_pipeline
from ..utils import ( from ..utils import (
base_join, base_join,
@ -49,10 +51,11 @@ from .load import (
get_wildcard_data, get_wildcard_data,
) )
from .params import ( from .params import (
border_from_request, build_border,
highres_from_request, build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request, pipeline_from_request,
upscale_from_request,
) )
from .utils import wrap_route from .utils import wrap_route
@ -167,8 +170,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
size = Size(source.width, source.height) size = Size(source.width, source.height)
device, params, _size = pipeline_from_request(server, "img2img") device, params, _size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
source_filter = get_from_list( source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys()) request.args, "sourceFilter", list(get_source_filters().keys())
) )
@ -216,12 +219,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor): def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img") device, params, size = pipeline_from_request(server, "txt2img")
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
output = make_output_name(server, "txt2img", params, size) output = make_output_name(server, "txt2img", params, size, count=params.batch)
job_name = output[0] job_name = output[0]
pool.submit( pool.submit(
@ -250,7 +253,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
if mask_file is None: if mask_file is None:
return error_reply("mask image is required") return error_reply("mask image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGBA")
size = Size(source.width, source.height) size = Size(source.width, source.height)
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA") mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
@ -270,9 +273,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
) )
device, params, _size = pipeline_from_request(server, "inpaint") device, params, _size = pipeline_from_request(server, "inpaint")
expand = border_from_request() expand = build_border()
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
fill_color = get_not_empty(request.args, "fillColor", "white") fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -340,8 +343,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server) device, params, size = pipeline_from_request(server)
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
@ -366,47 +369,70 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
# keys that are specially parsed by params and should not show up in with_args
CHAIN_POP_KEYS = ["model", "control"]
def chain(server: ServerContext, pool: DevicePoolExecutor): def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.debug( if request.is_json:
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys() logger.debug("chain pipeline request with JSON body")
) data = request.get_json()
body = request.form.get("chain") or request.files.get("chain") else:
if body is None: logger.debug(
return error_reply("chain pipeline must have a body") "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply("chain pipeline must have a body")
data = load_config_str(body)
data = load_config_str(body)
schema = load_config("./schemas/chain.yaml") schema = load_config("./schemas/chain.yaml")
logger.debug("validating chain request: %s against %s", data, schema) logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema) validate(data, schema)
# get defaults from the regular parameters device, base_params, base_size = pipeline_from_json(
device, params, size = pipeline_from_request(server) server, data=data.get("defaults")
output = make_output_name(server, "chain", params, size) )
job_name = output[0]
replace_wildcards(params, get_wildcard_data())
# start building the pipeline
pipeline = ChainPipeline() pipeline = ChainPipeline()
for stage_data in data.get("stages", []): for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")] stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs = stage_data.get("params", {}) kwargs: Dict[str, Any] = stage_data.get("params", {})
logger.info("request stage: %s, %s", stage_class.__name__, kwargs) logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
# TODO: combine base params with stage params
_device, params, size = pipeline_from_json(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())
# remove parsed keys, like model names (which become paths)
for pop_key in CHAIN_POP_KEYS:
if pop_key in kwargs:
kwargs.pop(pop_key)
if "seed" in kwargs and kwargs["seed"] == -1:
kwargs.pop("seed")
# replace kwargs with parsed versions
kwargs["params"] = params
kwargs["size"] = size
border = build_border(kwargs)
kwargs["border"] = border
upscale = build_upscale(kwargs)
kwargs["upscale"] = upscale
# prepare the stage metadata
stage = StageParams( stage = StageParams(
stage_data.get("name", stage_class.__name__), stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")), tile_size=get_size(kwargs.get("tiles")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
) )
if "border" in kwargs: # load any images related to this stage
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
stage_source_name = "source:%s" % (stage.name) stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name) stage_mask_name = "mask:%s" % (stage.name)
@ -436,20 +462,25 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages)) logger.info("running chain pipeline with %s stages", len(pipeline.stages))
output = make_output_name(
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
)
job_name = output[0]
# build and run chain pipeline # build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
pool.submit( pool.submit(
job_name, job_name,
pipeline, pipeline,
server, server,
params, base_params,
empty_source, StageResult.empty(),
output=output[0], output=output,
size=size, size=base_size,
needs_device=device, needs_device=device,
) )
return jsonify(json_params(output, params, size)) step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
return jsonify(json_params(output, step_params, base_size))
def blend(server: ServerContext, pool: DevicePoolExecutor): def blend(server: ServerContext, pool: DevicePoolExecutor):
@ -471,7 +502,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
sources.append(source) sources.append(source)
device, params, size = pipeline_from_request(server) device, params, size = pipeline_from_request(server)
upscale = upscale_from_request() upscale = build_upscale()
output = make_output_name(server, "upscale", params, size) output = make_output_name(server, "upscale", params, size)
job_name = output[0] job_name = output[0]
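Elsewhere in this file, the chain endpoint can now accept a JSON body instead of a form field. Based on the handler shown above, a request would look roughly like the sketch below; the stage types and per-stage parameter names are only placeholders here, since the authoritative list comes from `CHAIN_STAGES` and `schemas/chain.yaml`:

```python
import json

# hypothetical body for the JSON chain endpoint; "defaults", "stages", "name",
# "type", and "params" are the keys read by the handler, the values are made up
chain_request = {
    "defaults": {
        "prompt": "an astronaut eating a hamburger",
        "steps": 25,
    },
    "stages": [
        {
            "name": "base",
            "type": "source-txt2img",
            "params": {"width": 512, "height": 512},
        },
    ],
}

print(json.dumps(chain_request, indent=2))
# POST this with Content-Type: application/json so request.is_json is true
```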

View File

@ -5,18 +5,44 @@ from typing import List, Optional
import torch import torch
from ..utils import get_boolean from ..utils import get_boolean, get_list
from .model_cache import ModelCache from .model_cache import ModelCache
logger = getLogger(__name__) logger = getLogger(__name__)
DEFAULT_ANY_PLATFORM = True
DEFAULT_CACHE_LIMIT = 5 DEFAULT_CACHE_LIMIT = 5
DEFAULT_JOB_LIMIT = 10 DEFAULT_JOB_LIMIT = 10
DEFAULT_IMAGE_FORMAT = "png" DEFAULT_IMAGE_FORMAT = "png"
DEFAULT_SERVER_VERSION = "v0.10.0" DEFAULT_SERVER_VERSION = "v0.10.0"
DEFAULT_SHOW_PROGRESS = True
DEFAULT_WORKER_RETRIES = 3
class ServerContext: class ServerContext:
bundle_path: str
model_path: str
output_path: str
params_path: str
cors_origin: str
any_platform: bool
block_platforms: List[str]
default_platform: str
image_format: str
cache_limit: int
cache_path: str
show_progress: bool
optimizations: List[str]
extra_models: List[str]
job_limit: int
memory_limit: int
admin_token: str
server_version: str
worker_retries: int
feature_flags: List[str]
plugins: List[str]
debug: bool
def __init__( def __init__(
self, self,
bundle_path: str = ".", bundle_path: str = ".",
@ -24,19 +50,23 @@ class ServerContext:
output_path: str = ".", output_path: str = ".",
params_path: str = ".", params_path: str = ".",
cors_origin: str = "*", cors_origin: str = "*",
any_platform: bool = True, any_platform: bool = DEFAULT_ANY_PLATFORM,
block_platforms: Optional[List[str]] = None, block_platforms: Optional[List[str]] = None,
default_platform: Optional[str] = None, default_platform: Optional[str] = None,
image_format: str = DEFAULT_IMAGE_FORMAT, image_format: str = DEFAULT_IMAGE_FORMAT,
cache_limit: int = DEFAULT_CACHE_LIMIT, cache_limit: int = DEFAULT_CACHE_LIMIT,
cache_path: Optional[str] = None, cache_path: Optional[str] = None,
show_progress: bool = True, show_progress: bool = DEFAULT_SHOW_PROGRESS,
optimizations: Optional[List[str]] = None, optimizations: Optional[List[str]] = None,
extra_models: Optional[List[str]] = None, extra_models: Optional[List[str]] = None,
job_limit: int = DEFAULT_JOB_LIMIT, job_limit: int = DEFAULT_JOB_LIMIT,
memory_limit: Optional[int] = None, memory_limit: Optional[int] = None,
admin_token: Optional[str] = None, admin_token: Optional[str] = None,
server_version: Optional[str] = DEFAULT_SERVER_VERSION, server_version: Optional[str] = DEFAULT_SERVER_VERSION,
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
feature_flags: Optional[List[str]] = None,
plugins: Optional[List[str]] = None,
debug: bool = False,
) -> None: ) -> None:
self.bundle_path = bundle_path self.bundle_path = bundle_path
self.model_path = model_path self.model_path = model_path
@ -56,6 +86,10 @@ class ServerContext:
self.memory_limit = memory_limit self.memory_limit = memory_limit
self.admin_token = admin_token or token_urlsafe() self.admin_token = admin_token or token_urlsafe()
self.server_version = server_version self.server_version = server_version
self.worker_retries = worker_retries
self.feature_flags = feature_flags or []
self.plugins = plugins or []
self.debug = debug
self.cache = ModelCache(self.cache_limit) self.cache = ModelCache(self.cache_limit)
@ -72,26 +106,41 @@ class ServerContext:
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
# others cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), any_platform=get_boolean(
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), ),
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True), show_progress=get_boolean(
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","), ),
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
memory_limit=memory_limit, memory_limit=memory_limit,
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
server_version=environ.get( server_version=environ.get(
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
), ),
worker_retries=int(
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
),
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
) )
def has_feature(self, flag: str) -> bool:
return flag in self.feature_flags
def has_optimization(self, opt: str) -> bool:
return opt in self.optimizations
def torch_dtype(self): def torch_dtype(self):
if "torch-fp16" in self.optimizations: if self.has_optimization("torch-fp16"):
return torch.float16 return torch.float16
else: else:
return torch.float32 return torch.float32
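The new environment variables line up with the added fields: worker retries, feature flags, plugins, and the debug flag. A rough sketch of configuring them before the context is built (the variable names are the ones read by `from_environ` above; the values are examples, and `panorama-highres` is the feature flag checked earlier in this commit):

```python
from os import environ

environ["ONNX_WEB_WORKER_RETRIES"] = "5"
environ["ONNX_WEB_FEATURE_FLAGS"] = "panorama-highres"
environ["ONNX_WEB_DEBUG"] = "FALSE"

# inside the server process, roughly:
# server = ServerContext.from_environ()
# server.has_feature("panorama-highres")  # True with the flag set above
```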

View File

@@ -134,25 +134,44 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:

 def apply_patch_basicsr(server: ServerContext):
     logger.debug("patching BasicSR module")
-    import basicsr.utils.download_util
-
-    basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
-    basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
+    try:
+        import basicsr.utils.download_util
+
+        basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
+        basicsr.utils.download_util.load_file_from_url = partial(
+            patch_cache_path, server
+        )
+    except ImportError:
+        logger.info("unable to import basicsr utils for patching")
+    except AttributeError:
+        logger.warning("unable to patch basicsr utils")


 def apply_patch_codeformer(server: ServerContext):
     logger.debug("patching CodeFormer module")
-    import codeformer.facelib.utils.misc
-
-    codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
-    codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
+    try:
+        import codeformer.facelib.utils.misc
+
+        codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
+        codeformer.facelib.utils.misc.load_file_from_url = partial(
+            patch_cache_path, server
+        )
+    except ImportError:
+        logger.info("unable to import codeformer utils for patching")
+    except AttributeError:
+        logger.warning("unable to patch codeformer utils")


 def apply_patch_facexlib(server: ServerContext):
     logger.debug("patching Facexlib module")
-    import facexlib.utils
-
-    facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
+    try:
+        import facexlib.utils
+
+        facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
+    except ImportError:
+        logger.info("unable to import facexlib for patching")
+    except AttributeError:
+        logger.warning("unable to patch facexlib utils")


 def apply_patches(server: ServerContext):

View File

@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
# Loaded from extra_models # Loaded from extra_models
extra_hashes: Dict[str, str] = {} extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {} extra_strings: Dict[str, Any] = {}
extra_tokens: Dict[str, List[str]] = {}
def get_config_params(): def get_config_params():
@ -160,9 +161,10 @@ def load_extras(server: ServerContext):
""" """
global extra_hashes global extra_hashes
global extra_strings global extra_strings
global extra_tokens
labels = {} labels: Dict[str, str] = {}
strings = {} strings: Dict[str, Any] = {}
extra_schema = load_config("./schemas/extras.yaml") extra_schema = load_config("./schemas/extras.yaml")
@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
else: else:
labels[model_name] = model["label"] labels[model_name] = model["label"]
if "tokens" in model:
logger.debug(
"collecting tokens for model %s from %s",
model_name,
file,
)
extra_tokens[model_name] = model["tokens"]
if "inversions" in model: if "inversions" in model:
for inversion in model["inversions"]: for inversion in model["inversions"]:
if "label" in inversion: if "label" in inversion:
@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
) )
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models) logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend( network_models.extend(
[NetworkModel(model, "inversion") for model in inversion_models] [
NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
for model in inversion_models
]
) )
lora_models = list_model_globs( lora_models = list_model_globs(
@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
base_path=path.join(server.model_path, "lora"), base_path=path.join(server.model_path, "lora"),
) )
logger.debug("loaded LoRA models from disk: %s", lora_models) logger.debug("loaded LoRA models from disk: %s", lora_models)
network_models.extend([NetworkModel(model, "lora") for model in lora_models]) network_models.extend(
[
NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
for model in lora_models
]
)
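With the new `tokens` metadata, trigger words declared in the extras file are attached to each network entry. A rough sketch of the resulting objects; the model name and token are illustrative, and the `NetworkModel` import path and `tokens` attribute are assumptions:

```python
from onnx_web.models.meta import NetworkModel  # assumed import path

# Tokens collected from the extras file, keyed by model name.
extra_tokens = {"detail-tweaker": ["add_detail"]}

model = NetworkModel(
    "detail-tweaker", "lora", tokens=extra_tokens.get("detail-tweaker", [])
)
print(model.tokens)  # ["add_detail"], assuming the kwarg is stored as-is
```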
def load_params(server: ServerContext) -> None: def load_params(server: ServerContext) -> None:
@ -397,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
): ):
if potential == "cuda" or potential == "rocm": if potential == "cuda" or potential == "rocm":
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
options = { options: Dict[str, Union[int, str]] = {
"device_id": i, "device_id": i,
} }


@ -51,7 +51,7 @@ class ModelCache:
return return
for i in range(len(cache)): for i in range(len(cache)):
t, k, v = cache[i] t, k, _v = cache[i]
if tag == t and key != k: if tag == t and key != k:
logger.debug("updating model cache: %s %s", tag, key) logger.debug("updating model cache: %s %s", tag, key)
cache[i] = (tag, key, value) cache[i] = (tag, key, value)


@ -1,10 +1,10 @@
from logging import getLogger from logging import getLogger
from typing import Tuple from typing import Dict, Optional, Tuple
import numpy as np
from flask import request from flask import request
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.utils import random_seed
from ..params import ( from ..params import (
Border, Border,
DeviceParams, DeviceParams,
@ -34,143 +34,122 @@ from .utils import get_model_path
logger = getLogger(__name__) logger = getLogger(__name__)
def pipeline_from_request( def build_device(
server: ServerContext, _server: ServerContext,
default_pipeline: str = "txt2img", data: Dict[str, str],
) -> Tuple[DeviceParams, ImageParams, Size]: ) -> Optional[DeviceParams]:
user = request.remote_addr
# platform stuff # platform stuff
device = None device = None
device_name = request.args.get("platform") device_name = data.get("platform")
if device_name is not None and device_name != "any": if device_name is not None and device_name != "any":
for platform in get_available_platforms(): for platform in get_available_platforms():
if platform.device == device_name: if platform.device == device_name:
device = platform device = platform
return device
def build_params(
server: ServerContext,
default_pipeline: str,
data: Dict[str, str],
) -> ImageParams:
# diffusion model # diffusion model
model = get_not_empty(request.args, "model", get_config_value("model")) model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model) model_path = get_model_path(server, model)
control = None control = None
control_name = request.args.get("control") control_name = data.get("control")
for network in get_network_models(): for network in get_network_models():
if network.name == control_name: if network.name == control_name:
control = network control = network
# pipeline stuff # pipeline stuff
pipeline = get_from_list( pipeline = get_from_list(
request.args, "pipeline", get_available_pipelines(), default_pipeline data, "pipeline", get_available_pipelines(), default_pipeline
) )
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers()) scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
if scheduler is None: if scheduler is None:
scheduler = get_config_value("scheduler") scheduler = get_config_value("scheduler")
# prompt does not come from config # prompt does not come from config
prompt = request.args.get("prompt", "") prompt = data.get("prompt", "")
negative_prompt = request.args.get("negativePrompt", None) negative_prompt = data.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "": if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None negative_prompt = None
# image params # image params
batch = get_and_clamp_int( batch = get_and_clamp_int(
request.args, data,
"batch", "batch",
get_config_value("batch"), get_config_value("batch"),
get_config_value("batch", "max"), get_config_value("batch", "max"),
get_config_value("batch", "min"), get_config_value("batch", "min"),
) )
cfg = get_and_clamp_float( cfg = get_and_clamp_float(
request.args, data,
"cfg", "cfg",
get_config_value("cfg"), get_config_value("cfg"),
get_config_value("cfg", "max"), get_config_value("cfg", "max"),
get_config_value("cfg", "min"), get_config_value("cfg", "min"),
) )
eta = get_and_clamp_float( eta = get_and_clamp_float(
request.args, data,
"eta", "eta",
get_config_value("eta"), get_config_value("eta"),
get_config_value("eta", "max"), get_config_value("eta", "max"),
get_config_value("eta", "min"), get_config_value("eta", "min"),
) )
loopback = get_and_clamp_int( loopback = get_and_clamp_int(
request.args, data,
"loopback", "loopback",
get_config_value("loopback"), get_config_value("loopback"),
get_config_value("loopback", "max"), get_config_value("loopback", "max"),
get_config_value("loopback", "min"), get_config_value("loopback", "min"),
) )
steps = get_and_clamp_int( steps = get_and_clamp_int(
request.args, data,
"steps", "steps",
get_config_value("steps"), get_config_value("steps"),
get_config_value("steps", "max"), get_config_value("steps", "max"),
get_config_value("steps", "min"), get_config_value("steps", "min"),
) )
height = get_and_clamp_int( tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae"))
request.args, unet_overlap = get_and_clamp_float(
"height", data,
get_config_value("height"), "unet_overlap",
get_config_value("height", "max"), get_config_value("unet_overlap"),
get_config_value("height", "min"), get_config_value("unet_overlap", "max"),
get_config_value("unet_overlap", "min"),
) )
width = get_and_clamp_int( unet_tile = get_and_clamp_int(
request.args, data,
"width", "unet_tile",
get_config_value("width"), get_config_value("unet_tile"),
get_config_value("width", "max"), get_config_value("unet_tile", "max"),
get_config_value("width", "min"), get_config_value("unet_tile", "min"),
) )
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE")) vae_overlap = get_and_clamp_float(
tiles = get_and_clamp_int( data,
request.args, "vae_overlap",
"tiles", get_config_value("vae_overlap"),
get_config_value("tiles"), get_config_value("vae_overlap", "max"),
get_config_value("tiles", "max"), get_config_value("vae_overlap", "min"),
get_config_value("tiles", "min"),
) )
overlap = get_and_clamp_float( vae_tile = get_and_clamp_int(
request.args, data,
"overlap", "vae_tile",
get_config_value("overlap"), get_config_value("vae_tile"),
get_config_value("overlap", "max"), get_config_value("vae_tile", "max"),
get_config_value("overlap", "min"), get_config_value("vae_tile", "min"),
)
stride = get_and_clamp_int(
request.args,
"stride",
get_config_value("stride"),
get_config_value("stride", "max"),
get_config_value("stride", "min"),
) )
if stride > tiles: seed = int(data.get("seed", -1))
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
stride = tiles
seed = int(request.args.get("seed", -1))
if seed == -1: if seed == -1:
# this one can safely use np.random because it produces a single value seed = random_seed()
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams( params = ImageParams(
model_path, model_path,
@ -186,38 +165,65 @@ def pipeline_from_request(
control=control, control=control,
loopback=loopback, loopback=loopback,
tiled_vae=tiled_vae, tiled_vae=tiled_vae,
tiles=tiles, unet_overlap=unet_overlap,
overlap=overlap, unet_tile=unet_tile,
stride=stride, vae_overlap=vae_overlap,
vae_tile=vae_tile,
) )
size = Size(width, height)
return (device, params, size) return params
def border_from_request() -> Border: def build_size(
_server: ServerContext,
data: Dict[str, str],
) -> Size:
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
return Size(width, height)
def build_border(
data: Dict[str, str] = None,
) -> Border:
if data is None:
data = request.args
left = get_and_clamp_int( left = get_and_clamp_int(
request.args, data,
"left", "left",
get_config_value("left"), get_config_value("left"),
get_config_value("left", "max"), get_config_value("left", "max"),
get_config_value("left", "min"), get_config_value("left", "min"),
) )
right = get_and_clamp_int( right = get_and_clamp_int(
request.args, data,
"right", "right",
get_config_value("right"), get_config_value("right"),
get_config_value("right", "max"), get_config_value("right", "max"),
get_config_value("right", "min"), get_config_value("right", "min"),
) )
top = get_and_clamp_int( top = get_and_clamp_int(
request.args, data,
"top", "top",
get_config_value("top"), get_config_value("top"),
get_config_value("top", "max"), get_config_value("top", "max"),
get_config_value("top", "min"), get_config_value("top", "min"),
) )
bottom = get_and_clamp_int( bottom = get_and_clamp_int(
request.args, data,
"bottom", "bottom",
get_config_value("bottom"), get_config_value("bottom"),
get_config_value("bottom", "max"), get_config_value("bottom", "max"),
@ -227,46 +233,51 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom) return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams: def build_upscale(
data: Dict[str, str] = None,
) -> UpscaleParams:
if data is None:
data = request.args
denoise = get_and_clamp_float( denoise = get_and_clamp_float(
request.args, data,
"denoise", "denoise",
get_config_value("denoise"), get_config_value("denoise"),
get_config_value("denoise", "max"), get_config_value("denoise", "max"),
get_config_value("denoise", "min"), get_config_value("denoise", "min"),
) )
scale = get_and_clamp_int( scale = get_and_clamp_int(
request.args, data,
"scale", "scale",
get_config_value("scale"), get_config_value("scale"),
get_config_value("scale", "max"), get_config_value("scale", "max"),
get_config_value("scale", "min"), get_config_value("scale", "min"),
) )
outscale = get_and_clamp_int( outscale = get_and_clamp_int(
request.args, data,
"outscale", "outscale",
get_config_value("outscale"), get_config_value("outscale"),
get_config_value("outscale", "max"), get_config_value("outscale", "max"),
get_config_value("outscale", "min"), get_config_value("outscale", "min"),
) )
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) upscaling = get_from_list(data, "upscaling", get_upscaling_models())
correction = get_from_list(request.args, "correction", get_correction_models()) correction = get_from_list(data, "correction", get_correction_models())
faces = get_not_empty(request.args, "faces", "false") == "true" faces = get_not_empty(data, "faces", "false") == "true"
face_outscale = get_and_clamp_int( face_outscale = get_and_clamp_int(
request.args, data,
"faceOutscale", "faceOutscale",
get_config_value("faceOutscale"), get_config_value("faceOutscale"),
get_config_value("faceOutscale", "max"), get_config_value("faceOutscale", "max"),
get_config_value("faceOutscale", "min"), get_config_value("faceOutscale", "min"),
) )
face_strength = get_and_clamp_float( face_strength = get_and_clamp_float(
request.args, data,
"faceStrength", "faceStrength",
get_config_value("faceStrength"), get_config_value("faceStrength"),
get_config_value("faceStrength", "max"), get_config_value("faceStrength", "max"),
get_config_value("faceStrength", "min"), get_config_value("faceStrength", "min"),
) )
upscale_order = request.args.get("upscaleOrder", "correction-first") upscale_order = data.get("upscaleOrder", "correction-first")
return UpscaleParams( return UpscaleParams(
upscaling, upscaling,
@ -282,37 +293,43 @@ def upscale_from_request() -> UpscaleParams:
) )
def highres_from_request() -> HighresParams: def build_highres(
enabled = get_boolean(request.args, "highres", get_config_value("highres")) data: Dict[str, str] = None,
) -> HighresParams:
if data is None:
data = request.args
enabled = get_boolean(data, "highres", get_config_value("highres"))
iterations = get_and_clamp_int( iterations = get_and_clamp_int(
request.args, data,
"highresIterations", "highresIterations",
get_config_value("highresIterations"), get_config_value("highresIterations"),
get_config_value("highresIterations", "max"), get_config_value("highresIterations", "max"),
get_config_value("highresIterations", "min"), get_config_value("highresIterations", "min"),
) )
method = get_from_list(request.args, "highresMethod", get_highres_methods()) method = get_from_list(data, "highresMethod", get_highres_methods())
scale = get_and_clamp_int( scale = get_and_clamp_int(
request.args, data,
"highresScale", "highresScale",
get_config_value("highresScale"), get_config_value("highresScale"),
get_config_value("highresScale", "max"), get_config_value("highresScale", "max"),
get_config_value("highresScale", "min"), get_config_value("highresScale", "min"),
) )
steps = get_and_clamp_int( steps = get_and_clamp_int(
request.args, data,
"highresSteps", "highresSteps",
get_config_value("highresSteps"), get_config_value("highresSteps"),
get_config_value("highresSteps", "max"), get_config_value("highresSteps", "max"),
get_config_value("highresSteps", "min"), get_config_value("highresSteps", "min"),
) )
strength = get_and_clamp_float( strength = get_and_clamp_float(
request.args, data,
"highresStrength", "highresStrength",
get_config_value("highresStrength"), get_config_value("highresStrength"),
get_config_value("highresStrength", "max"), get_config_value("highresStrength", "max"),
get_config_value("highresStrength", "min"), get_config_value("highresStrength", "min"),
) )
return HighresParams( return HighresParams(
enabled, enabled,
scale, scale,
@ -321,3 +338,50 @@ def highres_from_request() -> HighresParams:
method=method, method=method,
iterations=iterations, iterations=iterations,
) )
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
def pipeline_from_json(
server: ServerContext,
data: Dict[str, str],
default_pipeline: str = "txt2img",
) -> PipelineParams:
"""
Like pipeline_from_request but expects a nested structure.
"""
device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data))
return (device, params, size)
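A sketch of calling the new JSON variant with a flat dict; the keys mirror the query parameters handled above, the values are illustrative, and `server` is assumed to be an existing ServerContext:

```python
# Values may be strings, as they would be after parsing a request body.
payload = {
    "model": "stable-diffusion-onnx-v1-5",
    "pipeline": "txt2img",
    "scheduler": "ddim",
    "prompt": "a giant muffin",
    "steps": "25",
    "width": "512",
    "height": "512",
    "seed": "0",
}

device, params, size = pipeline_from_json(server, payload)
```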
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> PipelineParams:
user = request.remote_addr
device = build_device(server, request.args)
params = build_params(server, default_pipeline, request.args)
size = build_size(server, request.args)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
params.steps,
params.scheduler,
params.model,
params.pipeline,
device or "any device",
size.width,
size.height,
params.cfg,
params.seed,
params.prompt,
)
return (device, params, size)


@ -0,0 +1,61 @@
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict
from onnx_web.chain.stages import add_stage
from onnx_web.diffusers.load import add_pipeline
from onnx_web.server.context import ServerContext
logger = getLogger(__name__)
class PluginExports:
pipelines: Dict[str, Any]
stages: Dict[str, Any]
def __init__(self, pipelines=None, stages=None) -> None:
self.pipelines = pipelines or {}
self.stages = stages or {}
PluginModule = Callable[[ServerContext], PluginExports]
def load_plugins(server: ServerContext) -> PluginExports:
combined_exports = PluginExports()
for plugin in server.plugins:
logger.info("loading plugin module: %s", plugin)
try:
module: PluginModule = import_module(plugin)
exports = module(server)
for name, pipeline in exports.pipelines.items():
if name in combined_exports.pipelines:
logger.warning(
"multiple plugins exported a pipeline named %s", name
)
else:
combined_exports.pipelines[name] = pipeline
for name, stage in exports.stages.items():
if name in combined_exports.stages:
logger.warning("multiple plugins exported a stage named %s", name)
else:
combined_exports.stages[name] = stage
except Exception:
logger.exception("error importing plugin")
return combined_exports
def register_plugins(exports: PluginExports) -> bool:
success = True
for name, pipeline in exports.pipelines.items():
success = success and add_pipeline(name, pipeline)
for name, stage in exports.stages.items():
success = success and add_stage(name, stage)
return success
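A minimal sketch of what a plugin might look like. The loader above imports each plugin module and invokes it with the ServerContext, so the exact entry-point shape is an assumption here; the class names and the `PluginExports` import path are hypothetical:

```python
from onnx_web.server.context import ServerContext
from onnx_web.server.plugin import PluginExports  # assumed module path


class ExamplePipeline:  # hypothetical pipeline implementation
    pass


class ExampleStage:  # hypothetical chain stage implementation
    pass


def plugin(server: ServerContext) -> PluginExports:
    # Return the pipelines and stages this plugin contributes; the loader
    # merges them and warns when two plugins export the same name.
    return PluginExports(
        pipelines={"example-pipeline": ExamplePipeline},
        stages={"example-stage": ExampleStage},
    )
```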


@ -18,6 +18,11 @@ logger = getLogger(__name__)
SAFE_CHARS = "._-" SAFE_CHARS = "._-"
def split_list(val: str) -> List[str]:
parts = [part.strip() for part in val.split(",")]
return [part for part in parts if len(part) > 0]
def base_join(base: str, tail: str) -> str: def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
return path.join(base, tail_path) return path.join(base, tail_path)
@ -28,7 +33,16 @@ def is_debug() -> bool:
def get_boolean(args: Any, key: str, default_value: bool) -> bool: def get_boolean(args: Any, key: str, default_value: bool) -> bool:
return args.get(key, str(default_value)).lower() in ("1", "t", "true", "y", "yes") val = args.get(key, str(default_value))
if isinstance(val, bool):
return val
return val.lower() in ("1", "t", "true", "y", "yes")
def get_list(args: Any, key: str, default="") -> List[str]:
return split_list(args.get(key, default))
def get_and_clamp_float( def get_and_clamp_float(
@ -61,13 +75,13 @@ def get_from_list(
def get_from_map( def get_from_map(
args: Any, key: str, values: Dict[str, TElem], default: TElem args: Any, key: str, values: Dict[str, TElem], default_key: str
) -> TElem: ) -> TElem:
selected = args.get(key, default) selected = args.get(key, default_key)
if selected in values: if selected in values:
return values[selected] return values[selected]
else: else:
return values[default] return values[default_key]
def get_not_empty(args: Any, key: str, default: TElem) -> TElem: def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
@ -195,6 +209,8 @@ def load_config(file: str) -> Dict:
return load_yaml(file) return load_yaml(file)
elif ext in [".json"]: elif ext in [".json"]:
return load_json(file) return load_json(file)
else:
raise ValueError("unknown config file extension")
def load_config_str(raw: str) -> Dict: def load_config_str(raw: str) -> Dict:
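A short sketch of the helper behaviour after these changes, using plain dicts in place of request args; keys and values are illustrative:

```python
# get_boolean now tolerates values that are already booleans, get_list splits
# comma-separated strings, and get_from_map falls back to a default key.
args = {"highres": True, "plugins": "my_plugin, other_plugin", "tileOrder": "grid"}

get_boolean(args, "highres", False)                  # True: bool values pass through
get_boolean({"highres": "yes"}, "highres", False)    # True: string forms still parse
get_list(args, "plugins")                            # ["my_plugin", "other_plugin"]
get_from_map(args, "tileOrder", {"grid": 1, "spiral": 2}, "spiral")  # 1
```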


@ -25,6 +25,7 @@ class WorkerContext:
idle: "Value[bool]" idle: "Value[bool]"
timeout: float timeout: float
retries: int retries: int
initial_retries: int
def __init__( def __init__(
self, self,
@ -36,6 +37,8 @@ class WorkerContext:
progress: "Queue[ProgressCommand]", progress: "Queue[ProgressCommand]",
active_pid: "Value[int]", active_pid: "Value[int]",
idle: "Value[bool]", idle: "Value[bool]",
retries: int,
timeout: float,
): ):
self.job = None self.job = None
self.name = name self.name = name
@ -47,12 +50,13 @@ class WorkerContext:
self.active_pid = active_pid self.active_pid = active_pid
self.last_progress = None self.last_progress = None
self.idle = idle self.idle = idle
self.timeout = 1.0 self.initial_retries = retries
self.retries = 3 # TODO: get from env self.retries = retries
self.timeout = timeout
def start(self, job: str) -> None: def start(self, job: str) -> None:
self.job = job self.job = job
self.retries = 3 self.retries = self.initial_retries
self.set_cancel(cancel=False) self.set_cancel(cancel=False)
self.set_idle(idle=False) self.set_idle(idle=False)
@ -82,7 +86,7 @@ class WorkerContext:
return 0 return 0
def get_progress_callback(self) -> ProgressCallback: def get_progress_callback(self) -> ProgressCallback:
from ..chain.base import ChainProgress from ..chain.pipeline import ChainProgress
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step on_progress.step = step


@ -86,15 +86,15 @@ class DevicePoolExecutor:
self.logs = Queue(self.max_pending_per_worker) self.logs = Queue(self.max_pending_per_worker)
self.rlock = Lock() self.rlock = Lock()
def start(self) -> None: def start(self, *args) -> None:
self.create_health_worker() self.create_health_worker()
self.create_logger_worker() self.create_logger_worker()
self.create_progress_worker() self.create_progress_worker()
for device in self.devices: for device in self.devices:
self.create_device_worker(device) self.create_device_worker(device, *args)
def create_device_worker(self, device: DeviceParams) -> None: def create_device_worker(self, device: DeviceParams, *args) -> None:
name = device.device name = device.device
# always recreate queues # always recreate queues
@ -124,15 +124,17 @@ class DevicePoolExecutor:
pending=self.pending[name], pending=self.pending[name],
active_pid=current, active_pid=current,
idle=self.worker_idle[name], idle=self.worker_idle[name],
retries=self.server.worker_retries,
timeout=self.progress_interval,
) )
self.context[name] = context self.context[name] = context
worker = Process( worker = Process(
name=f"onnx-web worker: {name}", name=f"onnx-web worker: {name}",
target=worker_main, target=worker_main,
args=(context, self.server), args=(context, self.server, *args),
daemon=True,
) )
worker.daemon = True
self.workers[name] = worker self.workers[name] = worker
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)


@ -27,10 +27,14 @@ MEMORY_ERRORS = [
] ]
def worker_main(worker: WorkerContext, server: ServerContext): def worker_main(
apply_patches(server) worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
):
setproctitle("onnx-web worker: %s" % (worker.device.device)) setproctitle("onnx-web worker: %s" % (worker.device.device))
if patch:
apply_patches(server)
logger.trace( logger.trace(
"checking in from worker with providers: %s", get_available_providers() "checking in from worker with providers: %s", get_available_providers()
) )
@ -46,7 +50,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
getpid(), getpid(),
worker.get_active(), worker.get_active(),
) )
exit(EXIT_REPLACED) return exit(EXIT_REPLACED)
# wait briefly for the next job # wait briefly for the next job
job = worker.pending.get(timeout=worker.timeout) job = worker.pending.get(timeout=worker.timeout)
@ -69,15 +73,15 @@ def worker_main(worker: WorkerContext, server: ServerContext):
except KeyboardInterrupt: except KeyboardInterrupt:
logger.debug("worker got keyboard interrupt") logger.debug("worker got keyboard interrupt")
worker.fail() worker.fail()
exit(EXIT_INTERRUPT) return exit(EXIT_INTERRUPT)
except RetryException: except RetryException:
logger.exception("retry error in worker, exiting") logger.exception("retry error in worker, exiting")
worker.fail() worker.fail()
exit(EXIT_ERROR) return exit(EXIT_ERROR)
except ValueError: except ValueError:
logger.exception("value error in worker, exiting") logger.exception("value error in worker, exiting")
worker.fail() worker.fail()
exit(EXIT_ERROR) return exit(EXIT_ERROR)
except Exception as e: except Exception as e:
e_str = str(e) e_str = str(e)
# restart the worker on memory errors # restart the worker on memory errors
@ -85,7 +89,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
if e_mem in e_str: if e_mem in e_str:
logger.error("detected out-of-memory error, exiting: %s", e) logger.error("detected out-of-memory error, exiting: %s", e)
worker.fail() worker.fail()
exit(EXIT_MEMORY) return exit(EXIT_MEMORY)
# carry on for other errors # carry on for other errors
logger.exception( logger.exception(
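Because `worker_main` now returns the result of an injectable `exit` callable and can skip patching, it is easier to drive outside a real worker process. A sketch, assuming `worker` and `server` have been prepared elsewhere (see the WorkerContext construction in the stage tests later in this diff):

```python
exit_code = worker_main(
    worker,
    server,
    exit=lambda code: code,  # capture the exit code instead of terminating the process
    patch=False,             # skip apply_patches() for a faster, isolated test run
)
```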


@ -98,7 +98,7 @@
"highresSteps": { "highresSteps": {
"default": 0, "default": 0,
"min": 1, "min": 1,
"max": 200, "max": 500,
"step": 1 "step": 1
}, },
"highresStrength": { "highresStrength": {
@ -141,12 +141,6 @@
"max": 4, "max": 4,
"step": 1 "step": 1
}, },
"overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"pipeline": { "pipeline": {
"default": "", "default": "",
"keys": [ "keys": [
@ -188,7 +182,7 @@
"steps": { "steps": {
"default": 25, "default": 25,
"min": 1, "min": 1,
"max": 200, "max": 300,
"step": 1 "step": 1
}, },
"strength": { "strength": {
@ -197,21 +191,9 @@
"max": 1, "max": 1,
"step": 0.01 "step": 0.01
}, },
"stride": { "tiled_vae": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": {
"default": false "default": false
}, },
"tiles": {
"default": 512,
"min": 128,
"max": 2048,
"step": 128
},
"tileOrder": { "tileOrder": {
"default": "spiral", "default": "spiral",
"keys": [ "keys": [
@ -225,6 +207,18 @@
"max": 1024, "max": 1024,
"step": 8 "step": 8
}, },
"unet_overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"unet_tile": {
"default": 512,
"min": 128,
"max": 2048,
"step": 128
},
"upscaleOrder": { "upscaleOrder": {
"default": "correction-first", "default": "correction-first",
"keys": [ "keys": [
@ -237,6 +231,18 @@
"default": "", "default": "",
"keys": [] "keys": []
}, },
"vae_overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"vae_tile": {
"default": 512,
"min": 256,
"max": 1024,
"step": 128
},
"width": { "width": {
"default": 512, "default": 512,
"min": 128, "min": 128,


@ -9,12 +9,14 @@ skip_glob = ["*/lpw.py"]
[tool.mypy] [tool.mypy]
# ignore_missing_imports = true # ignore_missing_imports = true
exclude = [ exclude = [
"onnx_web.diffusers.lpw_stable_diffusion_onnx" "onnx_web.diffusers.pipelines.controlnet",
"onnx_web.diffusers.pipelines.lpw",
"onnx_web.diffusers.pipelines.pix2pix"
] ]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = [ module = [
"arpeggio", "arpeggio",
"basicsr.archs.rrdbnet_arch", "basicsr.archs.rrdbnet_arch",
"basicsr.utils.download_util", "basicsr.utils.download_util",
"basicsr.utils", "basicsr.utils",
@ -27,8 +29,10 @@ module = [
"compel", "compel",
"controlnet_aux", "controlnet_aux",
"cv2", "cv2",
"debugpy",
"diffusers", "diffusers",
"diffusers.configuration_utils", "diffusers.configuration_utils",
"diffusers.image_processor",
"diffusers.loaders", "diffusers.loaders",
"diffusers.models.attention_processor", "diffusers.models.attention_processor",
"diffusers.models.autoencoder_kl", "diffusers.models.autoencoder_kl",
@ -41,9 +45,10 @@ module = [
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion", "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
"diffusers.pipelines.onnx_utils", "diffusers.pipelines.onnx_utils",
"diffusers.pipelines.paint_by_example", "diffusers.pipelines.paint_by_example",
"diffusers.pipelines.pipeline_utils",
"diffusers.pipelines.stable_diffusion", "diffusers.pipelines.stable_diffusion",
"diffusers.pipelines.stable_diffusion.convert_from_ckpt", "diffusers.pipelines.stable_diffusion.convert_from_ckpt",
"diffusers.pipeline_utils", "diffusers.pipelines.stable_diffusion_xl",
"diffusers.schedulers", "diffusers.schedulers",
"diffusers.utils.logging", "diffusers.utils.logging",
"facexlib.utils", "facexlib.utils",
@ -56,11 +61,17 @@ module = [
"mediapipe", "mediapipe",
"onnxruntime", "onnxruntime",
"onnxruntime.transformers.float16", "onnxruntime.transformers.float16",
"optimum.exporters.onnx",
"optimum.onnxruntime",
"optimum.onnxruntime.modeling_diffusion",
"optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img",
"optimum.pipelines.diffusers.pipeline_utils",
"piexif", "piexif",
"piexif.helper", "piexif.helper",
"realesrgan", "realesrgan",
"realesrgan.archs.srvgg_arch", "realesrgan.archs.srvgg_arch",
"safetensors", "safetensors",
"scipy",
"timm.models.layers", "timm.models.layers",
"transformers", "transformers",
"win10toast" "win10toast"


@ -46,17 +46,31 @@ $defs:
patternProperties: patternProperties:
"^[-_A-Za-z]+$": "^[-_A-Za-z]+$":
oneOf: oneOf:
- type: boolean
- type: number - type: number
- type: string - type: string
- type: "null"
request_chain: request_chain:
type: array type: array
items: items:
$ref: "#/$defs/request_stage" $ref: "#/$defs/request_stage"
request_defaults:
type: object
properties:
txt2img:
$ref: "#/$defs/image_params"
img2img:
$ref: "#/$defs/image_params"
type: object type: object
additionalProperties: False additionalProperties: False
required: [stages] required: [stages]
properties: properties:
defaults:
$ref: "#/$defs/request_defaults"
platform:
type: string
stages: stages:
$ref: "#/$defs/request_chain" $ref: "#/$defs/request_chain"


@ -10,34 +10,53 @@ $defs:
- type: number - type: number
- type: string - type: string
lora_network: tensor_format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
embedding_network:
type: object type: object
required: [name, source] required: [name, source]
properties: properties:
name: format:
type: string $ref: "#/$defs/tensor_format"
source:
type: string
label: label:
type: string type: string
weight: model:
type: number
textual_inversion_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
type: string type: string
enum: [concept, embeddings] enum: [concept, embeddings]
label: name:
type: string
source:
type: string type: string
token: token:
type: string type: string
type:
type: string
const: inversion # TODO: add embedding
weight:
type: number
lora_network:
type: object
required: [name, source, type]
properties:
label:
type: string
model:
type: string
enum: [cloneofsimo, sd-scripts]
name:
type: string
source:
type: string
tokens:
type: array
items:
type: string
type:
type: string
const: lora
weight: weight:
type: number type: number
@ -46,8 +65,7 @@ $defs:
required: [name, source] required: [name, source]
properties: properties:
format: format:
type: string $ref: "#/$defs/tensor_format"
enum: [bin, ckpt, onnx, pt, pth, safetensors]
half: half:
type: boolean type: boolean
label: label:
@ -85,7 +103,7 @@ $defs:
inversions: inversions:
type: array type: array
items: items:
$ref: "#/$defs/textual_inversion_network" $ref: "#/$defs/embedding_network"
loras: loras:
type: array type: array
items: items:
@ -100,6 +118,7 @@ $defs:
panorama, panorama,
pix2pix, pix2pix,
txt2img, txt2img,
txt2img-sdxl,
upscale, upscale,
] ]
vae: vae:
@ -141,31 +160,6 @@ $defs:
source: source:
type: string type: string
source_network:
type: object
required: [name, source, type]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
model:
type: string
enum: [
# inversion
concept,
embeddings,
# lora
cloneofsimo,
sd-scripts
]
name:
type: string
source:
type: string
type:
type: string
enum: [inversion, lora]
translation: translation:
type: object type: object
additionalProperties: False additionalProperties: False
@ -193,7 +187,9 @@ properties:
networks: networks:
type: array type: array
items: items:
$ref: "#/$defs/source_network" oneOf:
- $ref: "#/$defs/lora_network"
- $ref: "#/$defs/embedding_network"
sources: sources:
type: array type: array
items: items:
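With the single `source_network` definition split into `lora_network` and `embedding_network`, extras entries can declare trigger tokens and model formats per type. Illustrative entries, with names, sources, and weights made up:

```python
networks = [
    {
        "name": "detail-tweaker",
        "source": "https://example.com/detail-tweaker.safetensors",
        "type": "lora",
        "model": "sd-scripts",
        "tokens": ["add_detail"],
        "weight": 1.0,
    },
    {
        "name": "negative-hands",
        "source": "https://example.com/negative-hands.pt",
        "type": "inversion",
        "model": "embeddings",
        "format": "pt",
        "token": "negative-hands",
        "weight": 1.0,
    },
]
```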

api/scripts/onnx-lora.py (new file)

@ -0,0 +1,74 @@
from argparse import ArgumentParser
from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from os import path
from onnx.checker import check_model
from onnx.external_data_helper import (
convert_model_to_external_data,
write_external_data_tensors,
)
from onnxruntime import InferenceSession, SessionOptions
from logging import getLogger
from onnx_web.convert.utils import ConversionContext
logger = getLogger(__name__)
if __name__ == "__main__":
context = ConversionContext.from_environ()
parser = ArgumentParser()
parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args()
logger.info(
"merging %s with %s with weights: %s",
args.lora_models,
args.base,
args.lora_weights,
)
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras(
context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
)
if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(
bare_model.SerializeToString(),
sess_options=opts,
providers=["CPUExecutionProvider"],
)
logger.info(
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
)
else:
convert_model_to_external_data(
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
)
bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")


@ -30,6 +30,10 @@ FAST_TEST = 10
SLOW_TEST = 25 SLOW_TEST = 25
VERY_SLOW_TEST = 75 VERY_SLOW_TEST = 75
STRICT_TEST = 1e-4
LOOSE_TEST = 1e-2
VERY_LOOSE_TEST = 0.025
def test_path(relpath: str) -> str: def test_path(relpath: str) -> str:
return path.join(path.dirname(__file__), relpath) return path.join(path.dirname(__file__), relpath)
@ -41,7 +45,7 @@ class TestCase:
name: str, name: str,
query: str, query: str,
max_attempts: int = FAST_TEST, max_attempts: int = FAST_TEST,
mse_threshold: float = 1e-4, mse_threshold: float = STRICT_TEST,
source: Union[Image.Image, List[Image.Image]] = None, source: Union[Image.Image, List[Image.Image]] = None,
mask: Image.Image = None, mask: Image.Image = None,
) -> None: ) -> None:
@ -65,6 +69,7 @@ TEST_DATA = [
TestCase( TestCase(
"txt2img-sd-v1-5-512-muffin-deis", "txt2img-sd-v1-5-512-muffin-deis",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
mse_threshold=LOOSE_TEST,
), ),
TestCase( TestCase(
"txt2img-sd-v1-5-512-muffin-dpm", "txt2img-sd-v1-5-512-muffin-dpm",
@ -73,10 +78,12 @@ TEST_DATA = [
TestCase( TestCase(
"txt2img-sd-v1-5-512-muffin-heun", "txt2img-sd-v1-5-512-muffin-heun",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
mse_threshold=LOOSE_TEST,
), ),
TestCase( TestCase(
"txt2img-sd-v1-5-512-muffin-unipc", "txt2img-sd-v1-5-512-muffin-unipc",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi",
mse_threshold=LOOSE_TEST,
), ),
TestCase( TestCase(
"txt2img-sd-v2-1-512-muffin", "txt2img-sd-v2-1-512-muffin",
@ -84,7 +91,7 @@ TEST_DATA = [
), ),
TestCase( TestCase(
"txt2img-sd-v2-1-768-muffin", "txt2img-sd-v2-1-768-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768",
max_attempts=SLOW_TEST, max_attempts=SLOW_TEST,
), ),
TestCase( TestCase(
@ -106,7 +113,7 @@ TEST_DATA = [
), ),
TestCase( TestCase(
"img2img-sd-v1-5-256-pumpkin", "img2img-sd-v1-5-256-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256",
source="txt2img-sd-v1-5-256-muffin-0", source="txt2img-sd-v1-5-256-muffin-0",
), ),
TestCase( TestCase(
@ -130,7 +137,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST, max_attempts=SLOW_TEST,
mse_threshold=0.025, mse_threshold=VERY_LOOSE_TEST,
), ),
TestCase( TestCase(
"outpaint-vertical-512", "outpaint-vertical-512",
@ -141,7 +148,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST, max_attempts=SLOW_TEST,
mse_threshold=0.010, mse_threshold=LOOSE_TEST,
), ),
TestCase( TestCase(
"outpaint-horizontal-512", "outpaint-horizontal-512",
@ -152,7 +159,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=SLOW_TEST, max_attempts=SLOW_TEST,
mse_threshold=0.010, mse_threshold=LOOSE_TEST,
), ),
TestCase( TestCase(
"upscale-resrgan-x2-1024-muffin", "upscale-resrgan-x2-1024-muffin",
@ -229,7 +236,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=VERY_SLOW_TEST, max_attempts=VERY_SLOW_TEST,
mse_threshold=0.025, mse_threshold=VERY_LOOSE_TEST,
), ),
TestCase( TestCase(
"outpaint-panorama-vertical-512", "outpaint-panorama-vertical-512",
@ -240,7 +247,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=VERY_SLOW_TEST, max_attempts=VERY_SLOW_TEST,
mse_threshold=0.025, mse_threshold=VERY_LOOSE_TEST,
), ),
TestCase( TestCase(
"outpaint-panorama-horizontal-512", "outpaint-panorama-horizontal-512",
@ -251,7 +258,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black", mask="mask-black",
max_attempts=VERY_SLOW_TEST, max_attempts=VERY_SLOW_TEST,
mse_threshold=0.025, mse_threshold=VERY_LOOSE_TEST,
), ),
TestCase( TestCase(
"upscale-resrgan-x4-codeformer-2048-muffin", "upscale-resrgan-x4-codeformer-2048-muffin",
@ -260,6 +267,7 @@ TEST_DATA = [
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0" "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
), ),
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
), ),
TestCase( TestCase(
"upscale-resrgan-x4-gfpgan-2048-muffin", "upscale-resrgan-x4-gfpgan-2048-muffin",
@ -268,6 +276,7 @@ TEST_DATA = [
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0" "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
), ),
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
), ),
TestCase( TestCase(
"upscale-swinir-x4-codeformer-2048-muffin", "upscale-swinir-x4-codeformer-2048-muffin",
@ -276,6 +285,7 @@ TEST_DATA = [
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0" "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
), ),
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
), ),
TestCase( TestCase(
"upscale-swinir-x4-gfpgan-2048-muffin", "upscale-swinir-x4-gfpgan-2048-muffin",
@ -284,6 +294,7 @@ TEST_DATA = [
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0" "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
), ),
source="txt2img-sd-v1-5-512-muffin-0", source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
), ),
TestCase( TestCase(
"upscale-sd-x4-codeformer-2048-muffin", "upscale-sd-x4-codeformer-2048-muffin",
@ -305,18 +316,18 @@ TEST_DATA = [
), ),
TestCase( TestCase(
"txt2img-panorama-1024x768-muffin", "txt2img-panorama-1024x768-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true",
max_attempts=VERY_SLOW_TEST, max_attempts=VERY_SLOW_TEST,
), ),
TestCase( TestCase(
"img2img-panorama-1024x768-pumpkin", "img2img-panorama-1024x768-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true",
source="txt2img-panorama-1024x768-muffin-0", source="txt2img-panorama-1024x768-muffin-0",
max_attempts=VERY_SLOW_TEST, max_attempts=VERY_SLOW_TEST,
), ),
TestCase( TestCase(
"txt2img-sd-v1-5-tall-muffin", "txt2img-sd-v1-5-tall-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768", "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768",
), ),
TestCase( TestCase(
"upscale-resrgan-x4-tall-muffin", "upscale-resrgan-x4-tall-muffin",
@ -325,6 +336,7 @@ TEST_DATA = [
"&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0" "&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0"
), ),
source="txt2img-sd-v1-5-tall-muffin-0", source="txt2img-sd-v1-5-tall-muffin-0",
max_attempts=SLOW_TEST,
), ),
# TODO: non-square controlnet # TODO: non-square controlnet
] ]
@ -335,6 +347,39 @@ class TestError(Exception):
return super().__str__() return super().__str__()
class TestResult:
error: Optional[str]
mse: Optional[float]
name: str
passed: bool
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
self.error = error
self.mse = mse
self.name = name
self.passed = passed
def __repr__(self) -> str:
if self.passed:
if self.mse is not None:
return f"{self.name} ({self.mse})"
else:
return self.name
else:
if self.mse is not None:
return f"{self.name}: {self.error} ({self.mse})"
else:
return f"{self.name}: {self.error}"
@classmethod
def passed(self, name: str, mse = None):
return TestResult(name, mse=mse)
@classmethod
def failed(self, name: str, error: str, mse = None):
return TestResult(name, error=error, mse=mse, passed=False)
def parse_args(args: List[str]): def parse_args(args: List[str]):
parser = ArgumentParser( parser = ArgumentParser(
prog="onnx-web release tests", prog="onnx-web release tests",
@ -441,14 +486,14 @@ def run_test(
host: str, host: str,
test: TestCase, test: TestCase,
mse_mult: float = 1.0, mse_mult: float = 1.0,
) -> bool: ) -> TestResult:
""" """
Generate an image, wait for it to be ready, and calculate the MSE from the reference. Generate an image, wait for it to be ready, and calculate the MSE from the reference.
""" """
keys = generate_images(host, test) keys = generate_images(host, test)
if keys is None: if keys is None:
raise ValueError("could not generate image") return TestResult.failed(test.name, "could not generate image")
ready = False ready = False
for attempt in tqdm(range(test.max_attempts)): for attempt in tqdm(range(test.max_attempts)):
@ -461,13 +506,13 @@ def run_test(
sleep(6) sleep(6)
if not ready: if not ready:
raise ValueError("image was not ready in time") return TestResult.failed(test.name, "image was not ready in time")
results = download_images(host, keys) results = download_images(host, keys)
if results is None: if results is None or len(results) == 0:
raise ValueError("could not download image") return TestResult.failed(test.name, "could not download image")
passed = True passed = False
for i in range(len(results)): for i in range(len(results)):
result = results[i] result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
@ -476,14 +521,19 @@ def run_test(
ref = Image.open(ref_name) if path.exists(ref_name) else None ref = Image.open(ref_name) if path.exists(ref_name) else None
mse = find_mse(result, ref) mse = find_mse(result, ref)
threshold = test.mse_threshold * mse_mult
if mse < (test.mse_threshold * mse_mult): if mse < threshold:
logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold) logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
passed = True
else: else:
logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold) logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
passed = False return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
return passed if passed:
return TestResult.passed(test.name)
else:
return TestResult.failed(test.name, "no images tested")
def main(): def main():
@ -504,24 +554,26 @@ def main():
passed = [] passed = []
failed = [] failed = []
for test in tests: for test in tests:
test_passed = False result = None
for _i in range(3): for _i in range(3):
try: try:
logger.info("starting test: %s", test.name) logger.info("starting test: %s", test.name)
if run_test(args.host, test, mse_mult=args.mse): result = run_test(args.host, test, mse_mult=args.mse)
if result.passed:
logger.info("test passed: %s", test.name) logger.info("test passed: %s", test.name)
test_passed = True
break break
else: else:
logger.warning("test failed: %s", test.name) logger.warning("test failed: %s", test.name)
except Exception: except Exception:
logger.exception("error running test for %s", test.name) logger.exception("error running test for %s", test.name)
result = TestResult.failed(test.name, "TODO: exception message")
if test_passed: if result is not None:
passed.append(test.name) if result.passed:
else: passed.append(result)
failed.append(test.name) else:
failed.append(result)
logger.info("%s of %s tests passed", len(passed), len(tests)) logger.info("%s of %s tests passed", len(passed), len(tests))
failed = list(set(failed)) failed = list(set(failed))


@ -0,0 +1,26 @@
import unittest
from onnx_web.chain.pipeline import ChainProgress
class ChainProgressTests(unittest.TestCase):
def test_accumulate_with_reset(self):
def parent(step, timestep, latents):
pass
progress = ChainProgress(parent)
progress(5, 1, None)
progress(0, 1, None)
progress(5, 1, None)
self.assertEqual(progress.get_total(), 10)
def test_start_value(self):
def parent(step, timestep, latents):
pass
progress = ChainProgress(parent, 5)
self.assertEqual(progress.get_total(), 5)
progress(10, 1, None)
self.assertEqual(progress.get_total(), 10)


@ -0,0 +1,23 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.result import StageResult
class BlendGridStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendGridStage()
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
]
)
result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))


@ -0,0 +1,47 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.chain.result import StageResult
from onnx_web.params import ImageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
class BlendImg2ImgStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_stage(self):
stage = BlendImg2ImgStage()
params = ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"euler-a",
"an astronaut eating a hamburger",
3.0,
1,
1,
)
server = ServerContext(model_path="../models", output_path="../outputs")
worker = WorkerContext(
"test",
test_device(),
None,
None,
None,
None,
None,
None,
0,
0.1,
)
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
)
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))


@ -0,0 +1,23 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage
from onnx_web.chain.result import StageResult
class BlendLinearStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendLinearStage()
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
)
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
)
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))


@ -0,0 +1,25 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_mask import BlendMaskStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
class BlendMaskStageTests(unittest.TestCase):
def test_empty(self):
stage = BlendMaskStage()
sources = StageResult.empty()
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
stage_mask=Image.new("RGBA", (64, 64)),
)
self.assertEqual(len(result), 0)


@ -0,0 +1,46 @@
import unittest
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
from tests.helpers import (
TEST_MODEL_CORRECTION_CODEFORMER,
test_device,
test_needs_models,
)
class CorrectCodeformerStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER])
def test_empty(self):
server = ServerContext(model_path="../models", output_path="../outputs")
apply_patches(server)
worker = WorkerContext(
"test",
test_device(),
None,
None,
None,
None,
None,
None,
0,
0.1,
)
stage = CorrectCodeformerStage()
sources = StageResult.empty()
result = stage.run(
worker,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
self.assertEqual(len(result), 0)


@ -0,0 +1,44 @@
import unittest
from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models
TEST_MODEL = "../models/correction-gfpgan-v1-3"
class CorrectGFPGANStageTests(unittest.TestCase):
@test_needs_onnx_models([TEST_MODEL])
def test_empty(self):
server = ServerContext(model_path="../models", output_path="../outputs")
apply_patches(server)
worker = WorkerContext(
"test",
test_device(),
None,
None,
None,
None,
None,
None,
0,
0.1,
)
stage = CorrectGFPGANStage()
sources = StageResult.empty()
result = stage.run(
worker,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL),
)
self.assertEqual(len(result), 0)


@ -0,0 +1,24 @@
import unittest
from onnx_web.chain.reduce_crop import ReduceCropStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceCropStageTests(unittest.TestCase):
def test_empty(self):
stage = ReduceCropStage()
sources = StageResult.empty()
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
)
self.assertEqual(len(result), 0)


@ -0,0 +1,28 @@
import unittest
from PIL import Image
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceThumbnailStageTests(unittest.TestCase):
def test_empty(self):
stage_source = Image.new("RGB", (64, 64))
stage = ReduceThumbnailStage()
sources = StageResult.empty()
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
stage_source=stage_source,
)
self.assertEqual(len(result), 0)

Some files were not shown because too many files have changed in this diff.