
Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards

BZLibby 2023-12-06 16:29:45 -06:00
commit 17f28aba62
182 changed files with 8803 additions and 2788 deletions

View File

@ -17,12 +17,14 @@ with a CPU fallback capable of running on laptop-class machines.
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)
## Features
This is an incomplete list of new and interesting features, with links to the user guide:
- SDXL support
- LCM support
- hardware acceleration on both AMD and Nvidia
- tested on CUDA, DirectML, and ROCm
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
@ -37,6 +39,7 @@ This is an incomplete list of new and interesting features, with links to the user guide:
- [txt2img](docs/user-guide.md#txt2img-tab)
- [img2img](docs/user-guide.md#img2img-tab)
- [inpainting](docs/user-guide.md#inpaint-tab), with mask drawing and upload
- [panorama](docs/user-guide.md#panorama-pipeline), for both SD v1.5 and SDXL
- [upscaling](docs/user-guide.md#upscale-tab), with ONNX acceleration
- [add and use your own models](docs/user-guide.md#adding-your-own-models)
- [convert models from diffusers and SD checkpoints](docs/converting-models.md)
@ -45,20 +48,24 @@ This is an incomplete list of new and interesting features, with links to the user guide:
- [permanent and prompt-based blending](docs/user-guide.md#permanently-blending-additional-networks)
- [supports LoRA and LyCORIS weights](docs/user-guide.md#lora-tokens)
- [supports Textual Inversion concepts and embeddings](docs/user-guide.md#textual-inversion-tokens)
- each layer of the embeddings can be controlled and used individually
- ControlNet
- image filters for edge detection and other methods
- with ONNX acceleration
- highres mode
- runs img2img on the results of the other pipelines
- multiple iterations can produce 8k images and larger
- [multi-stage](docs/user-guide.md#prompt-stages) and [region prompts](docs/user-guide.md#region-tokens)
- seamlessly combine multiple prompts in the same image
- provide prompts for different areas in the image and blend them together
- change the prompt for highres mode and refine details without recursion
- infinite prompt length
- [with long prompt weighting](docs/user-guide.md#long-prompt-weighting)
- expand and control Textual Inversions per-layer
- [image blending mode](docs/user-guide.md#blend-tab)
- combine images from history
- upscaling and face correction
- upscaling with Real ESRGAN or Stable Diffusion
- face correction with CodeFormer or GFPGAN
- upscaling and correction
- upscaling with Real ESRGAN, SwinIR, and Stable Diffusion
- face correction with CodeFormer and GFPGAN
- [API server can be run remotely](docs/server-admin.md)
- REST API can be served over HTTPS or HTTP
- background processing for all image pipelines
@ -66,7 +73,7 @@ This is an incomplete list of new and interesting features, with links to the user guide:
- OCI containers provided
- for all supported hardware accelerators
- includes both the API and GUI bundle in a single container
- runs well on [RunPod](https://www.runpod.io/) and other GPU container hosting services
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
## Contents

api/.vscode/settings.json vendored Normal file (11 lines added)
View File

@ -0,0 +1,11 @@
{
"python.testing.unittestArgs": [
"-v",
"-s",
"./tests",
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
}

View File

@ -1,4 +1,4 @@
.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload
.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload style
onnx_env: ## create virtual env
python -m venv onnx_env
@ -18,9 +18,10 @@ pip-dev: check-venv
test:
python -m coverage erase
python -m coverage run -m unittest discover -s tests/
python -m coverage run -m unittest discover -v -s tests/
python -m coverage html -i
python -m coverage xml -i
python -m coverage report -i
package: package-dist package-upload
@ -32,13 +33,21 @@ package-upload:
lint-check:
black --check onnx_web/
isort --check-only --skip __init__.py --filter-files onnx_web
black --check tests/
flake8 onnx_web
flake8 tests
isort --check-only --skip __init__.py --filter-files onnx_web
isort --check-only --skip __init__.py --filter-files tests
lint-fix:
black onnx_web/
isort --skip __init__.py --filter-files onnx_web
black tests/
flake8 onnx_web
flake8 tests
isort --skip __init__.py --filter-files onnx_web
isort --skip __init__.py --filter-files tests
style: lint-fix
typecheck:
mypy onnx_web

View File

@ -1,45 +1,2 @@
from .base import ChainPipeline, PipelineStage, StageParams
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = {
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
from .pipeline import ChainPipeline, PipelineStage, StageParams
from .stages import * # NOQA

View File

@ -1,240 +1,39 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from typing import Optional
from PIL import Image
from ..errors import RetryException
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
from .result import StageResult
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
"""
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages or [])
def append(self, stage: Optional[PipelineStage]):
"""
Append an additional stage to this pipeline.
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
assemble the stage from loose arguments.
"""
if stage is not None:
self.stages.append(stage)
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback],
**kwargs
) -> List[Image.Image]:
return self(
worker, server, params, sources=sources, callback=callback, **kwargs
)
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
_sources: StageResult,
*,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:
raise NotImplementedError() # noqa
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def __call__(
def steps(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: List[Image.Image],
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> List[Image.Image]:
"""
DEPRECATED: use `run` instead
"""
if callback is not None:
callback = ChainProgress.from_progress(callback)
_params: ImageParams,
_size: Size,
) -> int:
return 1 # noqa
start = monotonic()
if len(sources) > 0:
logger.info(
"running pipeline on %s source images",
len(sources),
)
else:
sources = [None]
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.debug(
"running stage %s with %s source images, parameters: %s",
name,
len(stage_sources) - stage_sources.count(None),
kwargs.keys(),
)
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=source,
)
for source in stage_sources
]
)
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
if must_tile:
stage_outputs = []
for source in stage_sources:
logger.info(
"image larger than tile size of %s, tiling stage",
tile,
)
def stage_tile(
source_tile: Image.Image,
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> Image.Image:
for i in range(worker.retries):
try:
output_tile = stage_pipe.run(
worker,
server,
stage_params,
params,
[source_tile],
tile_mask=tile_mask,
callback=callback,
dims=dims,
**kwargs,
)[0]
if is_debug():
save_image(server, "last-tile.png", output_tile)
return output_tile
except Exception:
logger.exception(
"error while running stage pipeline for tile, retry %s of 3",
i,
)
server.cache.clear()
run_gc([worker.get_device()])
worker.retries = worker.retries - (i + 1)
raise RetryException("exhausted retries on tile")
output = process_tile_order(
stage_params.tile_order,
source,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_outputs.append(output)
stage_sources = stage_outputs
else:
logger.debug("image within tile size of %s, running stage", tile)
for i in range(worker.retries):
try:
stage_outputs = stage_pipe.run(
worker,
server,
stage_params,
params,
stage_sources,
callback=callback,
**kwargs,
)
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws
stage_sources = stage_outputs
break
except Exception:
logger.exception(
"error while running stage pipeline, retry %s of 3", i
)
server.cache.clear()
run_gc([worker.get_device()])
worker.retries = worker.retries - (i + 1)
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(
"finished stage %s with %s results",
name,
len(stage_sources),
)
if is_debug():
save_image(server, "last-stage.png", stage_sources[0])
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with %s results",
duration,
len(stage_sources),
)
return stage_sources
def outputs(
self,
_params: ImageParams,
sources: int,
) -> int:
return sources

View File

@ -0,0 +1,40 @@
from logging import getLogger
from typing import Optional
import cv2
from PIL import Image
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class BlendDenoiseStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
*,
strength: int = 3,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> StageResult:
logger.info("denoising source images")
results = []
for source in sources.as_numpy():
data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
return StageResult(arrays=results)
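
For reference, the new BlendDenoiseStage reduces to one OpenCV call per source image. A minimal standalone sketch of the same round trip, assuming a local input.png (the filename and the default strength of 3 are illustrative, not part of the diff):

# standalone sketch of the denoise step used by BlendDenoiseStage (assumed filenames)
import cv2
import numpy as np
from PIL import Image

source = np.array(Image.open("input.png").convert("RGB"))
bgr = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)        # OpenCV expects BGR order
strength = 3                                         # matches the stage's default
bgr = cv2.fastNlMeansDenoisingColored(bgr, None, strength, strength)
Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)).save("denoised.png")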

View File

@ -0,0 +1,62 @@
from logging import getLogger
from typing import List, Optional
from PIL import Image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class BlendGridStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
*,
height: int,
width: int,
# rows: Optional[List[str]] = None,
# columns: Optional[List[str]] = None,
# title: Optional[str] = None,
order: Optional[List[int]] = None,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> StageResult:
logger.info("combining source images using grid layout")
images = sources.as_image()
ref_image = images[0]
size = Size(*ref_image.size)
output = Image.new(ref_image.mode, (size.width * width, size.height * height))
# TODO: labels
if order is None:
order = range(len(images))
for i in range(len(order)):
x = i % width
y = i // width
n = order[i]
output.paste(images[n], (x * size.width, y * size.height))
return StageResult(images=[*images, output])
def outputs(
self,
_params: ImageParams,
sources: int,
) -> int:
return sources + 1
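
The grid stage above assumes every source is the same size and pastes them left to right, top to bottom. A small sketch of the same paste arithmetic with plain PIL, using hypothetical cell images:

# hypothetical 2x2 grid assembled the same way BlendGridStage lays out sources
from PIL import Image

images = [Image.open(f"cell-{i}.png") for i in range(4)]
width, height = 2, 2                       # grid columns and rows
cell_w, cell_h = images[0].size
grid = Image.new(images[0].mode, (cell_w * width, cell_h * height))
for i, img in enumerate(images):
    x, y = i % width, i // width           # fill rows first, as the stage does
    grid.paste(img, (x * cell_w, y * cell_h))
grid.save("grid.png")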

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
import numpy as np
import torch
@ -10,13 +10,14 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class BlendImg2ImgStage(BaseStage):
max_tile = SizeChart.unlimited
max_tile = SizeChart.max
def run(
self,
@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage):
server: ServerContext,
_stage: StageParams,
params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
prompt_index: Optional[int] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
params = params.with_args(**kwargs)
# multi-stage prompting
@ -52,7 +53,7 @@ class BlendImg2ImgStage(BaseStage):
params,
pipe_type,
worker.get_device(),
inversions=inversions,
embeddings=inversions,
loras=loras,
)
@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
pipe_params["strength"] = strength
outputs = []
for source in sources:
for source in sources.as_image():
if params.is_lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
@ -81,11 +82,10 @@ class BlendImg2ImgStage(BaseStage):
)
else:
# encode and record alternative prompts outside of LPW
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
if not params.is_xl():
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
@ -102,4 +102,18 @@ class BlendImg2ImgStage(BaseStage):
outputs.extend(result.images)
return outputs
return StageResult(images=outputs)
def steps(
self,
params: ImageParams,
*args,
) -> int:
return params.steps # TODO: multiply by strength
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -18,13 +19,18 @@ class BlendLinearStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
alpha: float,
stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
logger.info("blending source images using linear interpolation")
return [Image.blend(source, stage_source, alpha) for source in sources]
return StageResult(
images=[
Image.blend(source, stage_source, alpha)
for source in sources.as_image()
]
)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
@ -8,7 +8,8 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -20,16 +21,17 @@ class BlendMaskStage(BaseStage):
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
# TODO: does this need an alpha channel?
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L")
@ -37,4 +39,9 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
return [Image.composite(stage_source, source, mult_mask) for source in sources]
return StageResult(
images=[
Image.composite(stage_source, source, mult_mask)
for source in sources.as_image()
]
)

View File

@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer
@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
device = worker.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return [pipe(source) for source in sources]
return StageResult(images=[pipe(source) for source in sources.as_image()])

View File

@ -1,15 +1,15 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -57,12 +57,12 @@ class CorrectGFPGANStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
upscale = upscale.with_args(**kwargs)
if upscale.correction_model is None:
@ -73,16 +73,15 @@ class CorrectGFPGANStage(BaseStage):
device = worker.get_device()
gfpgan = self.load(server, stage, upscale, device)
outputs = []
for source in sources:
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
outputs = [
gfpgan.enhance(
source,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
outputs.append(Image.fromarray(output, "RGB"))
for source in sources.as_numpy()
]
return outputs
return StageResult(images=outputs)

View File

@ -1,11 +1,11 @@
from logging import getLogger
from typing import Optional
from ..chain.base import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from .pipeline import ChainPipeline
logger = getLogger(__name__)
@ -43,7 +43,7 @@ def stage_highres(
outscale=highres.scale,
),
chain=chain,
overlap=params.overlap,
overlap=params.vae_overlap,
)
else:
logger.debug("using simple upscaling for highres")
@ -51,14 +51,14 @@ def stage_highres(
UpscaleSimpleStage(),
stage,
method=highres.method,
overlap=params.overlap,
overlap=params.vae_overlap,
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
)
chain.stage(
BlendImg2ImgStage(),
stage,
overlap=params.overlap,
stage.with_args(outscale=1),
overlap=params.vae_overlap,
prompt_index=prompt_index + i,
strength=highres.strength,
)

View File

@ -1,33 +1,38 @@
from logging import getLogger
from typing import List
from typing import List, Optional
from PIL import Image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class PersistDiskStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
_worker: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
output: str,
stage_source: Image.Image,
output: List[str],
size: Optional[Size] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
for source in sources:
# TODO: append index to output name
dest = save_image(server, output, source, params=params)
) -> StageResult:
logger.info("persisting %s images to disk: %s", len(sources), output)
for source, name in zip(sources.as_image(), output):
dest = save_image(server, name, source, params=params, size=size)
logger.info("saved image to %s", dest)
return sources

View File

@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -20,26 +21,26 @@ class PersistS3Stage(BaseStage):
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
output: str,
output: List[str],
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
for source in sources:
for source, name in zip(sources.as_image(), output):
data = BytesIO()
source.save(data, format=server.image_format)
data.seek(0)
try:
s3.upload_fileobj(data, bucket, output)
logger.info("saved image to s3://%s/%s", bucket, output)
s3.upload_fileobj(data, bucket, name)
logger.info("saved image to s3://%s/%s", bucket, name)
except Exception:
logger.exception("error saving image to S3")

View File

@ -0,0 +1,281 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from PIL import Image
from ..errors import CancelledException, RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.total = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step:
# accumulate on resets
self.total += self.step
self.step = step
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
"""
def __init__(
self,
stages: Optional[List[PipelineStage]] = None,
):
"""
Create a new pipeline that will run the given stages.
"""
self.stages = list(stages or [])
def append(self, stage: Optional[PipelineStage]):
"""
Append an additional stage to this pipeline.
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
assemble the stage from loose arguments.
"""
if stage is not None:
self.stages.append(stage)
def run(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: StageResult,
callback: Optional[ProgressCallback],
**kwargs,
) -> List[Image.Image]:
result = self(
worker, server, params, sources=sources, callback=callback, **kwargs
)
return result.as_image()
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self
def steps(self, params: ImageParams, size: Size) -> int:
steps = 0
for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size)
return steps
def outputs(self, params: ImageParams, sources: int) -> int:
outputs = sources
for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs
def __call__(
self,
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
sources: StageResult,
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs,
) -> StageResult:
"""
DEPRECATED: use `.run()` instead
"""
if callback is None:
callback = worker.get_progress_callback()
else:
callback = ChainProgress.from_progress(callback)
start = monotonic()
if len(sources) > 0:
logger.info(
"running pipeline on %s source images",
len(sources),
)
else:
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
logger.debug(
"running stage %s with %s source images, parameters: %s",
name,
len(stage_sources),
kwargs.keys(),
)
per_stage_params = params
if "params" in kwargs:
per_stage_params = kwargs["params"]
kwargs.pop("params")
# the stage must be split and tiled if any image is larger than the selected/max tile size
must_tile = has_mask(stage_kwargs) or any(
[
needs_tile(
stage_pipe.max_tile,
stage_params.tile_size,
size=kwargs.get("size", None),
source=source,
)
for source in stage_sources.as_image()
]
)
tile = stage_params.tile_size
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
if must_tile:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
tile,
)
def stage_tile(
source_tile: List[Image.Image],
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> List[Image.Image]:
for _i in range(worker.retries):
try:
tile_result = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
StageResult(images=source_tile),
tile_mask=tile_mask,
callback=callback,
dims=dims,
**kwargs,
)
if is_debug():
for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
return tile_result
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled while tiling")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline for tile, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
raise RetryException("exhausted retries on tile")
stage_results = process_tile_order(
stage_params.tile_order,
stage_sources,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_sources = StageResult(images=stage_results)
else:
logger.debug(
"image does not contain sources and is within tile size of %s, running stage",
tile,
)
for _i in range(worker.retries):
try:
stage_result = stage_pipe.run(
worker,
server,
stage_params,
per_stage_params,
stage_sources,
callback=callback,
dims=(0, 0, tile),
**kwargs,
)
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
# does not like, so it throws
stage_sources = stage_result
break
except CancelledException as err:
worker.retries = 0
logger.exception("job was cancelled during stage")
raise err
except Exception:
worker.retries = worker.retries - 1
logger.exception(
"error while running stage pipeline, %s retries left",
worker.retries,
)
server.cache.clear()
run_gc([worker.get_device()])
if worker.retries <= 0:
raise RetryException("exhausted retries on stage")
logger.debug(
"finished stage %s with %s results",
name,
len(stage_sources),
)
if is_debug():
for j, image in enumerate(stage_sources.as_image()):
save_image(server, f"last-stage-{j}.png", image)
end = monotonic()
duration = timedelta(seconds=(end - start))
logger.info(
"finished pipeline in %s with %s results",
duration,
len(stage_sources),
)
return stage_sources
MASK_KEYS = ["mask", "stage_mask", "tile_mask"]
def has_mask(args: List[str]) -> bool:
return any([key in args for key in MASK_KEYS])
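
The reorganized pipeline keeps the fluent interface from the old base module: stages are queued with stage() and executed with run(), while the new steps() and outputs() helpers let callers estimate progress and result counts up front. A rough sketch of the intended flow; the worker, server, and parameter objects are created by the onnx-web server at request time, so they appear here only in comments, and the import paths are assumptions based on this diff:

# rough usage sketch, not part of the diff
from onnx_web.chain.pipeline import ChainPipeline
from onnx_web.chain.stages import BlendImg2ImgStage, SourceTxt2ImgStage

chain = ChainPipeline()
# chain.stage(SourceTxt2ImgStage(), stage_params, size=size)
# chain.stage(BlendImg2ImgStage(), stage_params, strength=0.5)
# total_steps = chain.steps(params, size)             # progress estimate
# expected = chain.outputs(params, sources=0)         # number of result images
# images = chain.run(worker, server, params, StageResult.empty(), callback)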

View File

@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
origin: Size,
size: Size,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
outputs = []
for source in sources:
for source in sources.as_image():
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
outputs.append(image)
return outputs
return StageResult(images=outputs)

View File

@ -1,12 +1,12 @@
from logging import getLogger
from typing import List
from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -18,15 +18,15 @@ class ReduceThumbnailStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
size: Size,
stage_source: Image.Image,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
outputs = []
for source in sources:
for source in sources.as_image():
image = source.copy()
# Image.thumbnail resizes in place and returns None, so do not reassign the result
image.thumbnail((size.width, size.height))
@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):
outputs.append(image)
return outputs
return StageResult(images=outputs)

View File

@ -0,0 +1,73 @@
from typing import List, Optional
import numpy as np
from PIL import Image
class StageResult:
"""
Chain pipeline stage result.
Can contain PIL images or numpy arrays, with helpers to convert between them.
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
they expect.
"""
arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]]
@staticmethod
def empty():
return StageResult(images=[])
@staticmethod
def from_arrays(arrays: List[np.ndarray]):
return StageResult(arrays=arrays)
@staticmethod
def from_images(images: List[Image.Image]):
return StageResult(images=images)
def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result")
elif arrays is None and images is None:
raise ValueError("stages must return results")
self.arrays = arrays
self.images = images
def __len__(self) -> int:
if self.arrays is not None:
return len(self.arrays)
elif self.images is not None:
return len(self.images)
else:
return 0
def as_numpy(self) -> List[np.ndarray]:
if self.arrays is not None:
return self.arrays
elif self.images is not None:
return [np.array(i) for i in self.images]
else:
return []
def as_image(self) -> List[Image.Image]:
if self.images is not None:
return self.images
elif self.arrays is not None:
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
else:
return []
def shape_mode(arr: np.ndarray) -> str:
if len(arr.shape) != 3:
raise ValueError("unknown array format")
if arr.shape[-1] == 3:
return "RGB"
elif arr.shape[-1] == 4:
return "RGBA"
raise ValueError("unknown image format")

View File

@ -1,12 +1,13 @@
from logging import getLogger
from typing import Callable, List
from typing import Callable, Optional
from PIL import Image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -18,25 +19,34 @@ class SourceNoiseStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
size: Size,
noise_source: Callable,
stage_source: Image.Image,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
logger.info("generating image from noise source")
if len(sources) > 0:
logger.warning(
"source images were passed to a noise stage and will be discarded"
logger.info(
"source images were passed to a source stage, new images will be appended"
)
outputs = []
for source in sources:
# TODO: looping over sources and ignoring params does not make much sense for a source stage
for source in sources.as_image():
output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return outputs
return StageResult(images=outputs)
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -20,18 +21,23 @@ class SourceS3Stage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
_sources: List[Image.Image],
sources: StageResult,
*,
source_keys: List[str],
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
outputs = []
if len(sources) > 0:
logger.info(
"source images were passed to a source stage, new images will be appended"
)
outputs = sources.as_image()
for key in source_keys:
try:
logger.info("loading image from s3://%s/%s", bucket, key)
@ -43,4 +49,11 @@ class SourceS3Stage(BaseStage):
except Exception:
logger.exception("error loading image from S3")
return outputs
return StageResult(outputs)
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1 # TODO: len(source_keys)

View File

@ -3,26 +3,28 @@ from typing import Optional, Tuple
import numpy as np
import torch
from PIL import Image
from ..constants import LATENT_FACTOR
from ..diffusers.load import load_pipeline
from ..diffusers.utils import (
encode_prompt,
get_latents_from_seed,
get_tile_latents,
parse_prompt,
parse_reseed,
slice_prompt,
)
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class SourceTxt2ImgStage(BaseStage):
max_tile = SizeChart.unlimited
max_tile = SizeChart.max
def run(
self,
@ -30,15 +32,15 @@ class SourceTxt2ImgStage(BaseStage):
server: ServerContext,
stage: StageParams,
params: ImageParams,
_source: Image.Image,
sources: StageResult,
*,
dims: Tuple[int, int, int],
dims: Tuple[int, int, int] = None,
size: Size,
callback: Optional[ProgressCallback] = None,
latents: Optional[np.ndarray] = None,
prompt_index: Optional[int] = None,
**kwargs,
) -> Image.Image:
) -> StageResult:
params = params.with_args(**kwargs)
size = size.with_args(**kwargs)
@ -47,31 +49,58 @@ class SourceTxt2ImgStage(BaseStage):
params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))
logger.info(
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
"generating image using txt2img, %s steps of %s: %s",
params.steps,
params.model,
params.prompt,
)
if "stage_source" in kwargs:
logger.warning(
"a source image was passed to a txt2img stage, and will be discarded"
if len(sources):
logger.info(
"source images were passed to a source stage, new images will be appended"
)
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params
)
if params.is_xl():
tile_size = max(stage.tile_size, params.tiles)
if params.is_panorama() or params.is_xl():
tile_size = max(stage.tile_size, params.unet_tile)
else:
tile_size = params.tiles
tile_size = params.unet_tile
# this works for panorama as well, because tile_size is already max(tile_size, *size)
latent_size = size.min(tile_size, tile_size)
# generate new latents or slice existing
if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
latents = get_latents_from_seed(int(params.seed), latent_size, params.batch)
else:
latents = get_tile_latents(latents, params.seed, latent_size, dims)
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
# reseed latents as needed
reseed_rng = np.random.RandomState(params.seed)
prompt, reseed = parse_reseed(prompt)
for top, left, bottom, right, region_seed in reseed:
if region_seed == -1:
region_seed = reseed_rng.random_integers(2**32 - 1)
logger.debug(
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
top,
left,
bottom,
right,
region_seed,
)
latents[
:,
:,
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
left // LATENT_FACTOR : right // LATENT_FACTOR,
] = get_latents_from_seed(
region_seed, Size(right - left, bottom - top), params.batch
)
pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
@ -79,7 +108,7 @@ class SourceTxt2ImgStage(BaseStage):
params,
pipe_type,
worker.get_device(),
inversions=inversions,
embeddings=inversions,
loras=loras,
)
@ -101,11 +130,14 @@ class SourceTxt2ImgStage(BaseStage):
)
else:
# encode and record alternative prompts outside of LPW
if params.is_panorama() or params.is_xl():
logger.debug(
"prompt alternatives are not supported for panorama or SDXL"
)
else:
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
if not params.is_xl():
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
@ -123,4 +155,21 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback,
)
return result.images
outputs = sources.as_image()
outputs.extend(result.images)
logger.debug("produced %s outputs", len(outputs))
return StageResult(images=outputs)
def steps(
self,
params: ImageParams,
size: Size,
) -> int:
return params.steps
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1
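
The reseed block above replaces a rectangular region of the latent tensor, dividing pixel coordinates by LATENT_FACTOR to map them into latent space. A small numpy sketch of that arithmetic, assuming LATENT_FACTOR is the usual 8x VAE downscale factor (an assumption, not confirmed by this diff):

# latent-region reseed arithmetic, illustrative only
import numpy as np

LATENT_FACTOR = 8                                     # assumed VAE downscale factor
latents = np.random.RandomState(42).standard_normal(
    (1, 4, 512 // LATENT_FACTOR, 512 // LATENT_FACTOR)
)

# reseed the 128x128 pixel region with its top-left corner at (256, 256)
top, left, bottom, right = 256, 256, 384, 384
region = np.random.RandomState(7).standard_normal(
    (1, 4, (bottom - top) // LATENT_FACTOR, (right - left) // LATENT_FACTOR)
)
latents[
    :, :, top // LATENT_FACTOR : bottom // LATENT_FACTOR, left // LATENT_FACTOR : right // LATENT_FACTOR
] = region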

View File

@ -1,6 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import List
from typing import List, Optional
import requests
from PIL import Image
@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -20,20 +21,20 @@ class SourceURLStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
source_urls: List[str],
stage_source: Image.Image,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
logger.info("loading image from URL source")
if len(sources) > 0:
logger.warning(
"a source image was passed to a source stage, and will be discarded"
logger.info(
"source images were passed to a source stage, new images will be appended"
)
outputs = []
outputs = sources.as_image()
for url in source_urls:
response = requests.get(url)
output = Image.open(BytesIO(response.content))
@ -41,4 +42,11 @@ class SourceURLStage(BaseStage):
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return outputs
return StageResult(images=outputs)
def outputs(
self,
params: ImageParams,
sources: int,
) -> int:
return sources + 1

View File

@ -1,31 +0,0 @@
from typing import List, Optional
from PIL import Image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
raise NotImplementedError()
def steps(
self,
_params: ImageParams,
size: Size,
) -> int:
raise NotImplementedError()

View File

@ -0,0 +1,64 @@
from logging import getLogger
from .base import BaseStage
from .blend_denoise import BlendDenoiseStage
from .blend_grid import BlendGridStage
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
logger = getLogger(__name__)
CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-simple": UpscaleSimpleStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}
def add_stage(name: str, stage: BaseStage) -> bool:
global CHAIN_STAGES
if name in CHAIN_STAGES:
logger.warning("cannot replace stage: %s", name)
return False
else:
CHAIN_STAGES[name] = stage
return True
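
The new add_stage() helper makes the CHAIN_STAGES registry extensible without editing this module. A hypothetical registration of a third-party stage (the InvertStage class and its name are illustrative only):

# registering a custom stage through add_stage(), illustrative only
from PIL import ImageOps

from onnx_web.chain.base import BaseStage
from onnx_web.chain.result import StageResult
from onnx_web.chain.stages import add_stage

class InvertStage(BaseStage):
    def run(self, _worker, _server, _stage, _params, sources, **kwargs) -> StageResult:
        return StageResult(images=[ImageOps.invert(i) for i in sources.as_image()])

assert add_stage("blend-invert", InvertStage)         # True on first registration
assert not add_stage("blend-invert", InvertStage)     # duplicate names are rejected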

View File

@ -2,13 +2,14 @@ import itertools
from enum import Enum
from logging import getLogger
from math import ceil
from typing import List, Optional, Protocol, Tuple
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
import numpy as np
from PIL import Image
from ..image.noise_source import noise_source_histogram
from ..params import Size, TileOrder
from .result import StageResult
# from skimage.exposure import match_histograms
@ -16,12 +17,15 @@ from ..params import Size, TileOrder
logger = getLogger(__name__)
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
class TileCallback(Protocol):
"""
Definition for a tile job function.
"""
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
"""
Run this stage against a single tile.
"""
@ -32,6 +36,9 @@ def complete_tile(
source: Image.Image,
tile: int,
) -> Image.Image:
"""
TODO: clean up
"""
if source is None:
return source
@ -50,6 +57,12 @@ def needs_tile(
source: Optional[Image.Image] = None,
) -> bool:
tile = min(max_tile, stage_tile)
logger.trace(
"checking image tile dimensions: %s, %s, %s",
tile,
source.width > tile or source.height > tile if source is not None else False,
size.width > tile or size.height > tile if size is not None else False,
)
if source is not None:
return source.width > tile or source.height > tile
@ -60,7 +73,7 @@ def needs_tile(
return False
def get_tile_grads(
def make_tile_grads(
left: int,
top: int,
tile: int,
@ -85,6 +98,60 @@ def get_tile_grads(
return (grad_x, grad_y)
def make_tile_mask(
shape: Any,
tile: Tuple[int, int],
overlap: float,
edges: Tuple[bool, bool, bool, bool],
) -> np.ndarray:
mask = np.ones(shape)
tile_h, tile_w = tile
adj_tile_h = int(float(tile_h) * (1.0 - overlap))
adj_tile_w = int(float(tile_w) * (1.0 - overlap))
# sort gradient points
p1_h = adj_tile_h - 1
p2_h = tile_h - adj_tile_h
points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]
p1_w = adj_tile_w - 1
p2_w = tile_w - adj_tile_w
points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]
# build gradients
edge_t, edge_l, edge_b, edge_r = edges
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
int(not edge_t),
1,
1,
int(not edge_b),
]
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)]
mask = ((mask * mult_x).T * mult_y).T
return mask
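
make_tile_mask builds a per-tile weight map: sides flagged True in the edges tuple ramp down toward zero, presumably so the tile blends into its neighbor, while unflagged sides keep full weight. A tiny example, assuming this module is importable as onnx_web.chain.tile:

# 4x4 mask with 50% overlap; top and left edges flagged for blending (illustrative)
import numpy as np
from onnx_web.chain.tile import make_tile_mask

mask = make_tile_mask(
    (4, 4),                         # shape of the mask
    (4, 4),                         # (tile_h, tile_w)
    0.5,                            # overlap fraction
    (True, True, False, False),     # edges: top, left, bottom, right
)
print(np.round(mask, 2))            # flagged sides fade to 0, the others stay at 1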
def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
if isinstance(image, np.ndarray):
return image.shape[-1]
if image.mode == "RGBA":
return 4
elif image.mode == "RGB":
return 3
elif image.mode == "L":
return 1
raise ValueError("unknown image format")
def blend_tiles(
tiles: List[Tuple[int, int, Image.Image]],
scale: int,
@ -98,23 +165,24 @@ def blend_tiles(
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
)
scaled_size = (height * scale, width * scale, 3)
channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
scaled_size = (height * scale, width * scale, channels)
count = np.zeros(scaled_size)
value = np.zeros(scaled_size)
for left, top, tile_image in tiles:
# histogram equalization
equalized = np.array(tile_image).astype(np.float32)
mask = np.ones_like(equalized[:, :, 0])
if adj_tile < tile:
# sort gradient points
p1 = adj_tile * scale
p2 = (tile - adj_tile) * scale
points = [0, min(p1, p2), max(p1, p2), tile * scale]
p1 = (adj_tile * scale) - 1
p2 = (tile - adj_tile - 1) * scale
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
# gradient blending
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
@ -169,7 +237,7 @@ def blend_tiles(
margin_left : equalized.shape[1] + margin_right,
np.newaxis,
],
3,
channels,
axis=2,
)
@ -178,60 +246,18 @@ def blend_tiles(
return Image.fromarray(np.uint8(pixels))
def process_tile_grid(
source: Image.Image,
tile: int,
scale: int,
filters: List[TileCallback],
overlap: float = 0.0,
**kwargs,
) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None)
adj_tile = int(float(tile) * (1.0 - overlap))
tiles_x = ceil(width / adj_tile)
tiles_y = ceil(height / adj_tile)
total = tiles_x * tiles_y
logger.debug(
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
total,
tiles_x,
tiles_y,
adj_tile,
overlap,
)
tiles: List[Tuple[int, int, Image.Image]] = []
for y in range(tiles_y):
for x in range(tiles_x):
idx = (y * tiles_x) + x
left = x * adj_tile
top = y * adj_tile
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
tile_image = (
source.crop((left, top, left + tile, top + tile)) if source else None
)
tile_image = complete_tile(tile_image, tile)
for filter in filters:
tile_image = filter(tile_image, (left, top, tile))
tiles.append((left, top, tile_image))
return blend_tiles(tiles, scale, width, height, tile, overlap)
def process_tile_spiral(
source: Image.Image,
def process_tile_stack(
stack: StageResult,
tile: int,
scale: int,
filters: List[TileCallback],
tile_generator: TileGenerator,
overlap: float = 0.5,
**kwargs,
) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None)
) -> List[Image.Image]:
sources = stack.as_image()
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
mask = kwargs.get("mask", None)
noise_source = kwargs.get("noise_source", noise_source_histogram)
fill_color = kwargs.get("fill_color", None)
@ -239,18 +265,10 @@ def process_tile_spiral(
tile_mask = None
tiles: List[Tuple[int, int, Image.Image]] = []
tile_coords = tile_generator(width, height, tile, overlap)
single_tile = len(tile_coords) == 1
# tile tuples is source, multiply by scale for dest
counter = 0
tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
if len(tile_coords) == 1:
single_tile = True
else:
single_tile = False
for left, top in tile_coords:
counter += 1
for counter, (left, top) in enumerate(tile_coords):
logger.info(
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
)
@ -274,37 +292,33 @@ def process_tile_spiral(
needs_margin = True
bottom_margin = height - bottom
# if no source given, we don't have a source image
if not source:
tile_image = None
elif needs_margin:
# in the special case where the image is smaller than the specified tile size, just use the image
if single_tile:
logger.debug("creating and processing single-tile subtile")
tile_image = source
logger.debug("using single tile")
tile_stack = sources
if mask:
tile_mask = mask
# otherwise add histogram noise outside of the image border
else:
elif needs_margin:
logger.debug(
"tiling and adding margins: %s, %s, %s, %s",
"tiling with added margins: %s, %s, %s, %s",
left_margin,
top_margin,
right_margin,
bottom_margin,
)
base_image = source.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
tile_stack = add_margin(
stack.as_image(),
left,
top,
right,
bottom,
left_margin,
top_margin,
right_margin,
bottom_margin,
tile,
noise_source,
fill_color,
)
)
tile_image = noise_source(
base_image, (tile, tile), (0, 0), fill=fill_color
)
tile_image.paste(base_image, (left_margin, top_margin))
if mask:
base_mask = mask.crop(
@ -320,38 +334,55 @@ def process_tile_spiral(
else:
logger.debug("tiling normally")
tile_image = source.crop((left, top, right, bottom))
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
if mask:
tile_mask = mask.crop((left, top, right, bottom))
for image_filter in filters:
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
tiles.append((left, top, tile_image))
if isinstance(tile_stack, list):
tile_stack = StageResult.from_images(tile_stack)
if single_tile:
return tile_image
else:
return blend_tiles(tiles, scale, width, height, tile, overlap)
tiles.append((left, top, tile_stack.as_image()))
lefts, tops, stacks = list(zip(*tiles))
coords = list(zip(lefts, tops))
stacks = list(zip(*stacks))
result = []
for stack in stacks:
stack_tiles = zip(coords, stack)
stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
return result
def process_tile_order(
order: TileOrder,
source: Image.Image,
stack: StageResult,
tile: int,
scale: int,
filters: List[TileCallback],
**kwargs,
) -> Image.Image:
"""
TODO: needs to handle more than one image
"""
if order == TileOrder.grid:
logger.debug("using grid tile order with tile size: %s", tile)
return process_tile_grid(source, tile, scale, filters, **kwargs)
return process_tile_stack(
stack, tile, scale, filters, generate_tile_grid, **kwargs
)
elif order == TileOrder.kernel:
logger.debug("using kernel tile order with tile size: %s", tile)
raise NotImplementedError()
elif order == TileOrder.spiral:
logger.debug("using spiral tile order with tile size: %s", tile)
return process_tile_spiral(source, tile, scale, filters, **kwargs)
return process_tile_stack(
stack, tile, scale, filters, generate_tile_spiral, **kwargs
)
else:
logger.warning("unknown tile order: %s", order)
raise ValueError()
@ -445,3 +476,77 @@ def generate_tile_spiral(
height_tile_target -= abs(state.value[1])
return tile_coords
def generate_tile_grid(
width: int,
height: int,
tile: int,
overlap: float = 0.0,
) -> List[Tuple[int, int]]:
adj_tile = int(float(tile) * (1.0 - overlap))
tiles_x = ceil(width / adj_tile)
tiles_y = ceil(height / adj_tile)
total = tiles_x * tiles_y
logger.debug(
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
total,
tiles_x,
tiles_y,
adj_tile,
overlap,
)
tiles: List[Tuple[int, int, Image.Image]] = []
for y in range(tiles_y):
for x in range(tiles_x):
left = x * adj_tile
top = y * adj_tile
tiles.append((int(left), int(top)))
return tiles
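
generate_tile_grid returns only the top-left coordinates; the stride between tiles shrinks with the overlap fraction. For example (module path assumed):

# tile origins for a 1024x768 canvas, 512px tiles, 25% overlap (illustrative)
from onnx_web.chain.tile import generate_tile_grid

coords = generate_tile_grid(1024, 768, 512, overlap=0.25)
# adjusted stride is 384px, so the grid is 3 columns x 2 rows
print(len(coords))     # 6
print(coords[:3])      # [(0, 0), (384, 0), (768, 0)]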
def get_result_tile(
result: StageResult,
origin: Tuple[int, int],
tile: Size,
) -> List[Image.Image]:
top, left = origin
return [
layer.crop((top, left, top + tile.height, left + tile.width))
for layer in result.as_image()
]
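# Note (hedged reading of the code above): callers pass origin as (left, top) while
# get_result_tile unpacks it as (top, left); the crop box still covers the intended
# region for the square Size(tile, tile) windows used in this module, because the
# width and height terms are interchangeable there.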
def add_margin(
stack: List[Image.Image],
left: int,
top: int,
right: int,
bottom: int,
left_margin: int,
top_margin: int,
right_margin: int,
bottom_margin: int,
tile: int,
noise_source,
fill_color,
) -> List[Image.Image]:
results = []
for source in stack:
base_image = source.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
)
)
tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
tile_image.paste(base_image, (left_margin, top_margin))
results.append(tile_image)
return results
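# Hedged usage sketch (not part of the diff): calling add_margin for a single-layer
# stack whose 512px tile runs past the right and bottom edges of a 768x640 source.
# flat_noise is a hypothetical stand-in for the project's noise_source callables,
# matching the (image, size, origin, fill=...) signature used above.
from PIL import Image

def flat_noise(image, size, origin, fill="white"):
    # stand-in: ignore the input image and return a solid tile of the requested size
    return Image.new("RGB", size, fill)

source = Image.new("RGB", (768, 640), "black")
padded = add_margin(
    [source],              # stack with one layer
    512, 384, 1024, 896,   # tile bounds: left, top, right, bottom
    0, 0, -256, -256,      # margins pull the crop back inside the source
    512,                   # tile size
    flat_noise,
    "white",
)
# padded[0] is a 512x512 tile: the in-bounds crop pasted over the generated fill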

View File

@ -1,22 +1,30 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..params import (
DeviceParams,
ImageParams,
Size,
SizeChart,
StageParams,
UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class UpscaleBSRGANStage(BaseStage):
max_tile = 64
max_tile = SizeChart.micro
def load(
self,
@ -54,12 +62,12 @@ class UpscaleBSRGANStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None:
@ -71,40 +79,38 @@ class UpscaleBSRGANStage(BaseStage):
bsrgan = self.load(server, stage, upscale, device)
outputs = []
for source in sources:
image = np.array(source) / 255.0
for source in sources.as_numpy():
image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
logger.trace(
"BSRGAN output shape: %s",
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
),
)
)
logger.trace("BSRGAN output shape: %s", dest.shape)
dest = bsrgan(image)
output = bsrgan(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
output = np.clip(np.squeeze(output, axis=0), 0, 1)
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
logger.debug("output image shape: %s", output.shape)
outputs.append(output)
return outputs
return StageResult(arrays=outputs)
def steps(
self,
params: ImageParams,
size: Size,
) -> int:
tile = min(params.tiles, self.max_tile)
tile = min(params.unet_tile, self.max_tile)
return size.width // tile * size.height // tile
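# Illustrative check (assumed values): if SizeChart.micro is 64, matching the literal
# it replaces above, a 512x512 source with unet_tile=256 reports
# (512 // 64) * (512 // 64) = 64 steps for progress reporting.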

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
@ -8,7 +8,8 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage):
server: ServerContext,
stage: StageParams,
params: ImageParams,
sources: List[Image.Image],
*args,
sources: StageResult,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
if highres.scale <= 1:
return sources
chain = stage_highres(stage, params, highres, upscale)
return [
outputs = [
chain(
worker,
server,
@ -41,5 +42,7 @@ class UpscaleHighresStage(BaseStage):
source,
callback=callback,
)
for source in sources
for source in sources.as_image()
]
return StageResult(images=outputs)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Callable, List, Optional, Tuple
from typing import Callable, Optional, Tuple
import numpy as np
import torch
@ -18,13 +18,14 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class UpscaleOutpaintStage(BaseStage):
max_tile = SizeChart.unlimited
max_tile = SizeChart.max
def run(
self,
@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage):
server: ServerContext,
stage: StageParams,
params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
border: Border,
dims: Tuple[int, int, int],
@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage):
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
params
)
@ -56,12 +57,12 @@ class UpscaleOutpaintStage(BaseStage):
params,
pipe_type,
worker.get_device(),
inversions=inversions,
embeddings=inversions,
loras=loras,
)
outputs = []
for source in sources:
for source in sources.as_image():
if is_debug():
save_image(server, "tile-source.png", source)
save_image(server, "tile-mask.png", tile_mask)
@ -71,7 +72,7 @@ class UpscaleOutpaintStage(BaseStage):
outputs.append(source)
continue
tile_size = params.tiles
tile_size = params.unet_tile
size = Size(*source.size)
latent_size = size.min(tile_size, tile_size)
@ -99,6 +100,7 @@ class UpscaleOutpaintStage(BaseStage):
)
else:
# encode and record alternative prompts outside of LPW
if not params.is_xl():
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
@ -121,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
outputs.extend(result.images)
return outputs
return StageResult(images=outputs)

View File

@ -1,8 +1,7 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
from ..onnx import OnnxRRDBNet
@ -10,7 +9,8 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -77,25 +77,22 @@ class UpscaleRealESRGANStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
outputs = []
for source in sources:
output = np.array(source)
upsampler = self.load(
server, upscale, worker.get_device(), tile=stage.tile_size
)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
outputs = []
for source in sources.as_numpy():
output, _ = upsampler.enhance(source, outscale=upscale.outscale)
logger.info("final output image size: %s", output.shape)
outputs.append(output)
return outputs
return StageResult(arrays=outputs)

View File

@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
method: str,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
if upscale.scale <= 1:
logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale
@ -32,18 +33,20 @@ class UpscaleSimpleStage(BaseStage):
return sources
outputs = []
for source in sources:
for source in sources.as_image():
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
outputs.append(
source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
outputs.append(
source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
)
else:
logger.warning("unknown upscaling method: %s", method)
outputs.append(source)
return outputs
return StageResult(images=outputs)

View File

@ -1,8 +1,8 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import torch
import numpy as np
from PIL import Image
from ..diffusers.load import load_pipeline
@ -10,7 +10,8 @@ from ..diffusers.utils import encode_prompt, parse_prompt
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage):
server: ServerContext,
_stage: StageParams,
params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs)
logger.info(
@ -46,8 +47,9 @@ class UpscaleStableDiffusionStage(BaseStage):
worker.get_device(),
model=path.join(server.model_path, upscale.upscale_model),
)
generator = torch.manual_seed(params.seed)
rng = np.random.RandomState(params.seed)
if not params.is_xl():
prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
@ -57,11 +59,11 @@ class UpscaleStableDiffusionStage(BaseStage):
pipeline.unet.set_prompts(prompt_embeds)
outputs = []
for source in sources:
for source in sources.as_image():
result = pipeline(
prompt,
source,
generator=generator,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=negative_prompt,
num_inference_steps=params.steps,
@ -71,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
)
outputs.extend(result.images)
return outputs
return StageResult(images=outputs)

View File

@ -1,22 +1,23 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class UpscaleSwinIRStage(BaseStage):
max_tile = 64
max_tile = SizeChart.micro
def load(
self,
@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
sources: StageResult,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
) -> StageResult:
upscale = upscale.with_args(**kwargs)
if upscale.upscale_model is None:
@ -71,31 +72,30 @@ class UpscaleSwinIRStage(BaseStage):
swinir = self.load(server, stage, upscale, device)
outputs = []
for source in sources:
for source in sources.as_numpy():
# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = source / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
logger.trace(
"SwinIR output shape: %s",
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
),
)
)
logger.trace("SwinIR output shape: %s", dest.shape)
dest = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = swinir(image)
output = np.clip(np.squeeze(output, axis=0), 0, 1)
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
logger.info("output image size: %s", output.shape)
outputs.append(output)
return outputs
return StageResult(images=outputs)

View File

@ -1,2 +1,5 @@
ONNX_MODEL = "model.onnx"
ONNX_WEIGHTS = "weights.pb"
LATENT_FACTOR = 8
LATENT_CHANNELS = 4
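# Illustrative note (assumption based on how these constants are used later in this
# commit): a 512x512 pixel image corresponds to a latent tensor of shape
# (1, LATENT_CHANNELS, 512 // LATENT_FACTOR, 512 // LATENT_FACTOR) = (1, 4, 64, 64).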

View File

@ -15,7 +15,8 @@ from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..utils import load_config
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.diffusion import convert_diffusion_diffusers
from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
from .diffusion.lora import blend_loras
from .diffusion.textual_inversion import blend_textual_inversions
from .upscaling.bsrgan import convert_upscaling_bsrgan
@ -357,6 +358,16 @@ def convert_models(conversion: ConversionContext, args, models: Models):
conversion, name, model["source"], format=model_format
)
pipeline = model.get("pipeline", "txt2img")
if pipeline.endswith("-sdxl"):
converted, dest = convert_diffusion_diffusers_xl(
conversion,
model,
source,
model_format,
hf=hf,
)
else:
converted, dest = convert_diffusion_diffusers(
conversion,
model,
@ -588,7 +599,7 @@ def main(args=None) -> int:
logger.info("CLI arguments: %s", args)
server = ConversionContext.from_environ()
server.half = args.half or "onnx-fp16" in server.optimizations
server.half = args.half or server.has_optimization("onnx-fp16")
server.opset = args.opset
server.token = args.token
logger.info(

View File

@ -20,7 +20,6 @@ from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
@ -32,17 +31,25 @@ from onnx import load_model, save_model
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
from ..utils import (
RESOLVE_FORMATS,
ConversionContext,
check_ext,
is_torch_2_0,
load_tensor,
onnx_export,
)
from .checkpoint import convert_extract_checkpoint
logger = getLogger(__name__)
available_pipelines = {
"controlnet": StableDiffusionControlNetPipeline,
CONVERT_PIPELINES = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline,
@ -96,7 +103,6 @@ def get_model_version(
opts["prediction_type"] = "epsilon"
except Exception:
logger.debug("unable to load tensor for version check")
pass
return (v2, opts)
@ -314,7 +320,7 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
pipe_class = available_pipelines.get(pipe_type)
pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
)
@ -360,7 +366,6 @@ def convert_diffusion_diffusers(
source,
original_config_file=config_path,
pipeline_class=pipe_class,
vae_path=replace_vae,
**pipe_args,
).to(device, torch_dtype=dtype)
elif hf:
@ -374,6 +379,17 @@ def convert_diffusion_diffusers(
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if check_ext(replace_vae, RESOLVE_FORMATS):
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.vae.set_attn_processor(AttnProcessor())
optimize_pipeline(conversion, pipeline)
output_path = Path(dest_path)
@ -424,9 +440,6 @@ def convert_diffusion_diffusers(
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / ONNX_MODEL
@ -526,19 +539,6 @@ def convert_diffusion_diffusers(
del unet
run_gc()
# VAE
if replace_vae is not None:
if replace_vae.startswith("."):
logger.debug(
"custom VAE appears to be a local path, making it relative to the model path"
)
replace_vae = path.join(conversion.model_path, replace_vae)
logger.info("loading custom VAE: %s", replace_vae)
vae = AutoencoderKL.from_pretrained(replace_vae)
pipeline.vae = vae
run_gc()
if single_vae:
logger.debug("VAE config: %s", pipeline.vae.config)

View File

@ -0,0 +1,116 @@
from logging import getLogger
from os import path
from typing import Dict, Optional, Tuple
import onnx
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from optimum.exporters.onnx import main_export
from ...constants import ONNX_MODEL
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
logger = getLogger(__name__)
@torch.no_grad()
def convert_diffusion_diffusers_xl(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
replace_vae = model.get("vae", None)
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
model_hash = path.join(dest_path, "hash.txt")
# diffusers models go into a directory rather than a single .onnx file
logger.info(
"converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
)
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")
if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])
return (False, dest_path)
# safetensors -> diffusers directory with torch models
temp_path = path.join(conversion.cache_path, f"{name}-torch")
if format == "safetensors":
pipeline = StableDiffusionXLPipeline.from_single_file(
source, use_safetensors=True
)
else:
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if check_ext(replace_vae, RESOLVE_FORMATS):
logger.debug("loading VAE from single tensor file: %s", vae_path)
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
logger.debug("loading pretrained VAE from path: %s", vae_path)
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
if path.exists(temp_path):
logger.debug("torch model already exists for %s: %s", source, temp_path)
else:
logger.debug("exporting torch model for %s: %s", source, temp_path)
pipeline.save_pretrained(temp_path)
# directory -> onnx using optimum exporters
main_export(
temp_path,
output=dest_path,
task="stable-diffusion-xl",
device=device,
fp16=conversion.has_optimization(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
)
if "hash" in model:
logger.debug("adding hash file to ONNX model")
with open(model_hash, "w") as f:
f.write(model["hash"])
if conversion.half:
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
infer_shapes_path(unet_path)
unet = onnx.load(unet_path)
opt_model = convert_float_to_float16(
unet,
disable_shape_infer=True,
force_fp16_initializers=True,
keep_io_types=True,
op_block_list=["Attention", "MultiHeadAttention"],
)
onnx.save_model(
opt_model,
unet_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
)
return False, dest_path
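# Hedged sketch (values are examples, not from the diff): a model entry whose
# pipeline ends with "-sdxl" is routed to this converter by convert_models; the
# field names follow the keys read in this function, and "vae" and "hash" are
# optional.
example_model = {
    "name": "sdxl-example",
    "source": "stabilityai/stable-diffusion-xl-base-1.0",
    "pipeline": "txt2img-sdxl",
}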

View File

@ -1,22 +1,15 @@
from argparse import ArgumentParser
from logging import getLogger
from os import path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from onnx import ModelProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
convert_model_to_external_data,
set_external_data,
write_external_data_tensors,
)
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import OrtValue
from scipy import interpolate
from ...server.context import ServerContext
from ..utils import ConversionContext, load_tensor
from ..utils import load_tensor
logger = getLogger(__name__)
@ -39,7 +32,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
lr = a
if kernel == (1, 1):
lr = np.expand_dims(lr, axis=(2, 3))
lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis
return hr + lr
@ -78,13 +71,15 @@ def fix_node_name(key: str):
return fixed_name
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
fixed = {}
names = [fix_node_name(node.name) for node in nodes]
for key, value in keys.items():
root, *rest = key.split(".")
logger.debug("fixing XL node name: %s -> %s", key, root) # TODO: move to trace
root, *_rest = key.split(".")
logger.trace("fixing XL node name: %s -> %s", key, root)
simple = False
if root.startswith("input"):
block = "down_blocks"
elif root.startswith("middle"):
@ -93,6 +88,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
block = "up_blocks"
elif root.startswith("text_model"):
block = "text_model"
elif root.startswith("down_blocks"):
block = "down_blocks"
simple = True
elif root.startswith("mid_block"):
block = "mid_block"
simple = True
elif root.startswith("up_blocks"):
block = "up_blocks"
simple = True
else:
logger.warning("unknown XL key name: %s", key)
fixed[key] = value
@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
suffix = None
for s in [
"conv",
"conv_shortcut",
"conv1",
"conv2",
"fc1",
"fc2",
"ff_net_0_proj",
@ -119,18 +127,21 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
logger.warning("new XL key type: %s", root)
continue
logger.debug("searching for XL node: /%s/*/%s", block, suffix)
match = None
if block == "text_model":
match = next(
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
)
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
match: Optional[str] = None
if "conv" in suffix:
match = next(node for node in names if node == f"{root}_Conv")
elif "time_emb_proj" in root:
match = next(node for node in names if node == f"{root}_Gemm")
elif block == "text_model" or simple:
match = next(node for node in names if node == f"{root}_MatMul")
else:
# search in order. one side has sparse indices, so they will not match.
match = next(
node
for node in nodes
if node.name.startswith(f"/{block}")
and fix_node_name(node.name).endswith(
for node in names
if node.startswith(block)
and node.endswith(
f"{suffix}_MatMul"
) # needs to be fixed because some places use to_out.0
)
@ -138,18 +149,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
if match is None:
logger.warning("no matches for XL key: %s", root)
continue
else:
logger.trace("matched key: %s -> %s", key, match)
name: str = match.name
name = fix_node_name(name.rstrip("/MatMul"))
name = match
if name.endswith("_MatMul"):
name = name[:-7]
elif name.endswith("_Gemm"):
name = name[:-5]
elif name.endswith("_Conv"):
name = name[:-5]
if name.endswith("proj_o"):
# some exported node names are truncated to proj_o; restore the proj_out suffix
name = f"{name}ut"
logger.debug("matching XL key with node: %s -> %s", key, match.name)
logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)
fixed[name] = value
nodes.remove(match)
names.remove(match)
logger.debug(
"SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
len(fixed.keys()),
len(keys.keys()),
len(names),
)
return fixed
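# Illustrative trace (names follow the pattern handled above, not copied from a real
# model): a LoRA key such as
#   down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q
# matches the fixed node name
#   down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q_MatMul,
# the _MatMul suffix is stripped, and the blended weights are attached to the
# initializer of that name.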
@ -161,39 +182,9 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
)
def blend_loras(
_conversion: ServerContext,
base_name: Union[str, ModelProto],
loras: List[Tuple[str, float]],
model_type: Literal["text_encoder", "unet"],
model_index: Optional[int] = None,
xl: Optional[bool] = False,
):
# always load to CPU for blending
device = torch.device("cpu")
dtype = torch.float32
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
if model_type == "text_encoder":
if model_index is None:
lora_prefix = "lora_te_"
else:
lora_prefix = f"lora_te{model_index}_"
else:
lora_prefix = f"lora_{model_type}_"
blended: Dict[str, np.ndarray] = {}
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
if lora_model is None:
logger.warning("unable to load tensor for LoRA")
continue
for key in lora_model.keys():
if ".hada_w1_a" in key and lora_prefix in key:
# LoHA
def blend_weights_loha(
key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
t1_key = key.replace("hada_w1_a", "hada_t1")
@ -260,26 +251,18 @@ def blend_loras(
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
np_weights = weights.numpy() * (alpha / dim)
np_weights *= lora_weight
if base_key in blended:
logger.trace(
"summing LoHA weights: %s + %s",
blended[base_key].shape,
np_weights.shape,
)
blended[base_key] += sum_weights(blended[base_key], np_weights)
else:
blended[base_key] = np_weights
elif ".lora_down" in key and lora_prefix in key:
# LoRA or LoCON
return base_key, np_weights
def blend_weights_lora(
key: str, lora_prefix: str, lora_model: Dict, dtype
) -> Tuple[str, np.ndarray]:
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
mid_key = key.replace("lora_down", "lora_mid")
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.trace(
"blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
)
logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
down_weight = lora_model[key].to(dtype=dtype)
up_weight = lora_model[up_key].to(dtype=dtype)
@ -320,10 +303,7 @@ def blend_loras(
alpha,
)
weights = (
(
up_weight.squeeze(3).squeeze(2)
@ down_weight.squeeze(3).squeeze(2)
)
(up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
.unsqueeze(2)
.unsqueeze(3)
)
@ -341,15 +321,12 @@ def blend_loras(
mid_weight.shape,
alpha,
)
weights = torch.zeros(
(up_weight.shape[0], down_weight.shape[1], *kernel)
)
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
for w in range(kernel[0]):
for h in range(kernel[1]):
weights[:, :, w, h] = (
up_weight.squeeze(3).squeeze(2)
@ mid_weight[:, :, w, h]
up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
) @ down_weight.squeeze(3).squeeze(2)
np_weights = weights.numpy() * (alpha / dim)
@ -361,9 +338,7 @@ def blend_loras(
up_weight.shape,
alpha,
)
weights = torch.zeros(
(up_weight.shape[0], down_weight.shape[1], *kernel)
)
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
for w in range(kernel[0]):
for h in range(kernel[1]):
@ -371,8 +346,7 @@ def blend_loras(
up_w, up_h = kernel_slice(w, h, up_weight.shape)
weights[:, :, w, h] = (
up_weight[:, :, up_w, up_h]
@ down_weight[:, :, down_w, down_h]
up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
)
np_weights = weights.numpy() * (alpha / dim)
@ -382,48 +356,165 @@ def blend_loras(
base_key,
up_weight.shape[-2:],
)
# TODO: should this be None?
np_weights = np.zeros((1, 1, 1, 1))
return base_key, np_weights
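# Worked example (illustrative shapes): for a rank-4 LoRA applied to a 320x768
# linear layer, up_weight is (320, 4) and down_weight is (4, 768), so
#   weights = up_weight @ down_weight              # shape (320, 768)
#   np_weights = weights.numpy() * (alpha / dim)   # e.g. alpha=4, dim=4 -> scale 1.0
# and the caller then multiplies by the per-LoRA lora_weight before summing layers.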
def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
# blending
onnx_weights = numpy_helper.to_array(weight_node)
logger.trace(
"found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape,
)
if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3))
else:
if onnx_weights.shape != weights.shape:
logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
# TODO: test if this can be replaced with interpolation; simply reshaping is questionable
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else:
blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
onnx_weights = numpy_helper.to_array(matmul_node)
logger.trace(
"found blended weights for matmul: %s, %s",
weights.shape,
onnx_weights.shape,
)
t_weights = weights.transpose()
if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
logger.warning(
"weight shapes do not match for %s: %s vs %s",
matmul_key,
weights.shape,
onnx_weights.shape,
)
t_weights = interp_to_match(weights, onnx_weights).transpose()
blended = onnx_weights + t_weights
logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
# replace the original initializer
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
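# Hedged note: ONNX MatMul initializers are typically stored transposed relative to
# the torch Linear weights the LoRA deltas were computed against, so the delta is
# transposed before addition; interp_to_match is only a fallback for the case where
# neither orientation matches the initializer shape.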
def blend_loras(
_conversion: ServerContext,
base_name: Union[str, ModelProto],
loras: List[Tuple[str, float]],
model_type: Literal["text_encoder", "unet"],
model_index: Optional[int] = None,
xl: Optional[bool] = False,
):
# always load to CPU for blending
device = torch.device("cpu")
dtype = torch.float32
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
if model_type == "text_encoder":
if model_index is None:
lora_prefix = "lora_te_"
else:
lora_prefix = f"lora_te{model_index}_"
else:
lora_prefix = f"lora_{model_type}_"
layers = []
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
if lora_model is None:
logger.warning("unable to load tensor for LoRA")
continue
np_weights *= lora_weight
if base_key in blended:
blended: Dict[str, np.ndarray] = {}
layers.append(blended)
for key in lora_model.keys():
if ".hada_w1_a" in key and lora_prefix in key:
# LoHA
base_key, np_weights = blend_weights_loha(
key, lora_prefix, lora_model, dtype
)
np_weights = np_weights * lora_weight
logger.trace(
"summing weights: %s + %s",
blended[base_key].shape,
"adding LoHA weights: %s",
np_weights.shape,
)
blended[base_key] = np_weights
elif ".lora_down" in key and lora_prefix in key:
# LoRA or LoCON
base_key, np_weights = blend_weights_lora(
key, lora_prefix, lora_model, dtype
)
np_weights = np_weights * lora_weight
logger.trace(
"adding LoRA weights: %s",
np_weights.shape,
)
blended[base_key] = sum_weights(blended[base_key], np_weights)
else:
blended[base_key] = np_weights
# rewrite node names for XL
# rewrite node names for XL and flatten layers
weights: Dict[str, np.ndarray] = {}
for blended in layers:
if xl:
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
logger.trace(
"updating %s of %s initializers",
len(blended.keys()),
len(base_model.graph.initializer),
)
for key, value in blended.items():
if key in weights:
weights[key] = sum_weights(weights[key], value)
else:
weights[key] = value
# fix node names once
fixed_initializer_names = [
fix_initializer_name(node.name) for node in base_model.graph.initializer
]
logger.trace("fixed initializer names: %s", fixed_initializer_names)
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
logger.trace("fixed node names: %s", fixed_node_names)
logger.debug(
"updating %s of %s initializers",
len(weights.keys()),
len(base_model.graph.initializer),
)
unmatched_keys = []
for base_key, weights in blended.items():
for base_key, weights in weights.items():
conv_key = base_key + "_Conv"
gemm_key = base_key + "_Gemm"
matmul_key = base_key + "_MatMul"
logger.trace(
"key %s has conv: %s, matmul: %s",
"key %s has conv: %s, gemm: %s, matmul: %s",
base_key,
conv_key in fixed_node_names,
gemm_key in fixed_node_names,
matmul_key in fixed_node_names,
)
@ -449,38 +540,9 @@ def blend_loras(
weight_node = base_model.graph.initializer[weight_idx]
logger.trace("found weight initializer: %s", weight_node.name)
# blending
onnx_weights = numpy_helper.to_array(weight_node)
logger.trace(
"found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape,
)
# replace the previous node
updated_node = blend_node_conv_gemm(weight_node, weights)
if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3))
else:
if onnx_weights.shape != weights.shape:
logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else:
blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer
updated_node = numpy_helper.from_array(
blended.astype(onnx_weights.dtype), weight_node.name
)
del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, updated_node)
elif matmul_key in fixed_node_names:
@ -497,42 +559,15 @@ def blend_loras(
matmul_node = base_model.graph.initializer[matmul_idx]
logger.trace("found matmul initializer: %s", matmul_node.name)
# blending
onnx_weights = numpy_helper.to_array(matmul_node)
logger.trace(
"found blended weights for matmul: %s, %s",
weights.shape,
onnx_weights.shape,
)
# replace the previous node
updated_node = blend_node_matmul(matmul_node, weights, matmul_key)
t_weights = weights.transpose()
if (
weights.shape != onnx_weights.shape
and t_weights.shape != onnx_weights.shape
):
logger.warning(
"weight shapes do not match for %s: %s vs %s",
matmul_key,
weights.shape,
onnx_weights.shape,
)
t_weights = interp_to_match(weights, onnx_weights).transpose()
blended = onnx_weights + t_weights
logger.debug(
"blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
)
# replace the original initializer
updated_node = numpy_helper.from_array(
blended.astype(onnx_weights.dtype), matmul_node.name
)
del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node)
else:
unmatched_keys.append(base_key)
logger.debug(
logger.trace(
"node counts: %s -> %s, %s -> %s",
len(fixed_initializer_names),
len(base_model.graph.initializer),
@ -541,10 +576,7 @@ def blend_loras(
)
if len(unmatched_keys) > 0:
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
# if model_type == "unet":
# save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)
return base_model
@ -568,63 +600,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
logger.debug("weights after interpolation: %s", output.shape)
return output
if __name__ == "__main__":
context = ConversionContext.from_environ()
parser = ArgumentParser()
parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args()
logger.info(
"merging %s with %s with weights: %s",
args.lora_models,
args.base,
args.lora_weights,
)
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras(
context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
)
if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(
bare_model.SerializeToString(),
sess_options=opts,
providers=["CPUExecutionProvider"],
)
logger.info(
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
)
else:
convert_model_to_external_data(
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
)
bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")

View File

@ -14,55 +14,23 @@ from ..utils import ConversionContext, load_tensor
logger = getLogger(__name__)
@torch.no_grad()
def blend_textual_inversions(
server: ServerContext,
text_encoder: ModelProto,
tokenizer: CLIPTokenizer,
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = np.float32
embeds = {}
for name, weight, base_token, inversion_format in inversions:
if base_token is None:
logger.debug("no base token provided, using name: %s", name)
base_token = name
logger.info(
"blending Textual Inversion %s with weight of %s for token %s",
name,
weight,
base_token,
)
loaded_embeds = load_tensor(name, map_location=device)
if loaded_embeds is None:
logger.warning("unable to load tensor")
continue
if inversion_format is None:
def detect_embedding_format(loaded_embeds) -> str:
keys: List[str] = list(loaded_embeds.keys())
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
logger.debug("detected Textual Inversion concept: %s", keys)
inversion_format = "concept"
return "concept"
elif "emb_params" in keys:
logger.debug(
"detected Textual Inversion parameter embeddings: %s", keys
)
inversion_format = "parameters"
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
return "parameters"
elif "string_to_token" in keys and "string_to_param" in keys:
logger.debug("detected Textual Inversion token embeddings: %s", keys)
inversion_format = "embeddings"
return "embeddings"
else:
logger.error(
"unknown Textual Inversion format, no recognized keys: %s", keys
)
continue
logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
return None
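# Illustrative examples of the three layouts distinguished above (key names follow
# the checks; tensor shapes are stand-ins):
#   concept:    {"<my-style>": tensor of shape (768,)}
#   parameters: {"emb_params": tensor of shape (num_layers, 768)}
#   embeddings: {"string_to_token": {...}, "string_to_param": {"*": tensor (num_layers, 768)}}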
if inversion_format == "concept":
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
# separate token and the embeds
token = list(loaded_embeds.keys())[0]
@ -78,11 +46,13 @@ def blend_textual_inversions(
embeds[token] += layer
else:
embeds[token] = layer
elif inversion_format == "parameters":
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
emb_params = loaded_embeds["emb_params"]
num_tokens = emb_params.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
sum_layer = np.zeros(emb_params[0, :].shape)
@ -108,7 +78,9 @@ def blend_textual_inversions(
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
elif inversion_format == "embeddings":
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]
@ -117,7 +89,7 @@ def blend_textual_inversions(
trained_embeds = string_to_param[token]
num_tokens = trained_embeds.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
sum_layer = np.zeros(trained_embeds[0, :].shape)
@ -143,23 +115,9 @@ def blend_textual_inversions(
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
else:
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
# add the tokens to the tokenizer
logger.debug(
"found embeddings for %s tokens: %s",
len(embeds.keys()),
list(embeds.keys()),
)
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)
logger.trace("added %s tokens", num_added_tokens)
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
# resize the token embeddings
# text_encoder.resize_token_embeddings(len(tokenizer))
embedding_node = [
@ -191,6 +149,59 @@ def blend_textual_inversions(
del text_encoder.graph.initializer[i]
text_encoder.graph.initializer.insert(i, new_initializer)
@torch.no_grad()
def blend_textual_inversions(
server: ServerContext,
text_encoder: ModelProto,
tokenizer: CLIPTokenizer,
embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = np.float32
embeds = {}
for name, weight, base_token, format in embeddings:
if base_token is None:
logger.debug("no base token provided, using name: %s", name)
base_token = name
logger.info(
"blending Textual Inversion %s with weight of %s for token %s",
name,
weight,
base_token,
)
loaded_embeds = load_tensor(name, map_location=device)
if loaded_embeds is None:
logger.warning("unable to load tensor")
continue
if format is None:
format = detect_embedding_format(loaded_embeds)
if format == "concept":
blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
elif format == "parameters":
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
elif format == "embeddings":
blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
else:
raise ValueError(f"unknown Textual Inversion format: {format}")
# add the tokens to the tokenizer
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
"The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
)
logger.trace("added %s tokens", num_added_tokens)
blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
return (text_encoder, tokenizer)

View File

@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
class ConversionContext(ServerContext):
def __init__(
self,
model_path: Optional[str] = None,
model_path: str = ".",
cache_path: Optional[str] = None,
device: Optional[str] = None,
half: bool = False,
@ -69,7 +69,7 @@ class ConversionContext(ServerContext):
def from_environ(cls):
context = super().from_environ()
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", True)
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False)
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
name, source, *_rest = model
return {
"name": name,
@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None
scale = rest.pop(0) if len(rest) > 0 else 1
half = rest.pop(0) if len(rest) > 0 else False
opset = rest.pop(0) if len(rest) > 0 else None
return {
"name": name,
@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
single_vae = rest[0] if len(rest) > 0 else False
half = rest[0] if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None
single_vae = rest.pop(0) if len(rest) > 0 else False
half = rest.pop(0) if len(rest) > 0 else False
opset = rest.pop(0) if len(rest) > 0 else None
return {
"name": name,
@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None
scale = rest.pop(0) if len(rest) > 0 else 1
half = rest.pop(0) if len(rest) > 0 else False
opset = rest.pop(0) if len(rest) > 0 else None
return {
"name": name,
@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
_name, ext = path.splitext(name)
ext = ext.strip(".")
return (ext in exts, ext)
def source_format(model: Dict) -> Optional[str]:
@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]:
return model["format"]
if "source" in model:
_name, ext = path.splitext(model["source"])
if ext in MODEL_FORMATS:
valid, ext = check_ext(model["source"], MODEL_FORMATS)
if valid:
return ext
return None
@ -298,6 +305,7 @@ def onnx_export(
half=False,
external_data=False,
v2=False,
op_block_list=None,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@ -316,8 +324,7 @@ def onnx_export(
opset_version=opset,
)
op_block_list = None
if v2:
if v2 and op_block_list is None:
op_block_list = ["Attention", "MultiHeadAttention"]
if half:

View File

@ -1,16 +1,15 @@
from logging import getLogger
from os import path
from typing import Any, List, Optional, Tuple
from typing import Any, List, Literal, Optional, Tuple
from onnx import load_model
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline,
)
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL
from ..constants import LATENT_FACTOR, ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
@ -24,6 +23,7 @@ from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from .version_safe_diffusers import (
DDIMScheduler,
@ -38,6 +38,7 @@ from .version_safe_diffusers import (
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
LMSDiscreteScheduler,
OnnxRuntimeModel,
OnnxStableDiffusionImg2ImgPipeline,
@ -58,6 +59,7 @@ available_pipelines = {
# "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
"panorama": OnnxStableDiffusionPanoramaPipeline,
"panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline,
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
"txt2img-sdxl": ORTStableDiffusionXLPipeline,
"txt2img": OnnxStableDiffusionPipeline,
@ -77,12 +79,25 @@ pipeline_schedulers = {
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
"k-dpm-2": KDPM2DiscreteScheduler,
"karras-ve": KarrasVeScheduler,
"lcm": LCMScheduler,
"lms-discrete": LMSDiscreteScheduler,
"pndm": PNDMScheduler,
"unipc-multi": UniPCMultistepScheduler,
}
def add_pipeline(name: str, pipeline: Any) -> bool:
global available_pipelines
if name in available_pipelines:
# TODO: decide if this should be allowed or not
logger.warning("cannot replace existing pipeline: %s", name)
return False
else:
available_pipelines[name] = pipeline
return True
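# Hedged usage sketch: an extension could register an extra pipeline class before
# load_pipeline resolves names; MyCustomPipeline is a placeholder name, not a class
# that exists in this commit.
#   add_pipeline("txt2img-custom", MyCustomPipeline)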
def get_available_pipelines() -> List[str]:
return list(available_pipelines.keys())
@ -99,16 +114,19 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
return None
VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]
def load_pipeline(
server: ServerContext,
params: ImageParams,
pipeline: str,
device: DeviceParams,
inversions: Optional[List[Tuple[str, float]]] = None,
embeddings: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
model: Optional[str] = None,
):
inversions = inversions or []
embeddings = embeddings or []
loras = loras or []
model = model or params.model
@ -122,7 +140,7 @@ def load_pipeline(
device.device,
device.provider,
control_key,
inversions,
embeddings,
loras,
)
scheduler_key = (params.scheduler, model)
@ -159,26 +177,128 @@ def load_pipeline(
run_gc([device])
logger.debug("loading new diffusion pipeline from %s", model)
components = {
"scheduler": scheduler_type.from_pretrained(
scheduler = scheduler_type.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="scheduler",
torch_dtype=torch_dtype,
)
components = {
"scheduler": scheduler,
}
# shared components
text_encoder = None
unet_type = "unet"
# ControlNet component
if params.is_control() and params.control is not None:
cnet_path = path.join(
server.model_path, "control", f"{params.control.name}.onnx"
logger.debug("loading ControlNet components")
control_components = load_controlnet(server, device, params)
components.update(control_components)
unet_type = "cnet"
# load various pipeline components
encoder_components = load_text_encoders(
server, device, model, embeddings, loras, torch_dtype, params
)
components.update(encoder_components)
unet_components = load_unet(server, device, model, loras, unet_type, params)
components.update(unet_components)
vae_components = load_vae(server, device, model, params)
components.update(vae_components)
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
if params.is_xl():
logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class(
components["vae_decoder_session"],
components["text_encoder_session"],
components["unet_session"],
{
"force_zeros_for_empty_prompt": True,
"requires_aesthetics_score": False,
},
components["tokenizer"],
scheduler,
vae_encoder_session=components.get("vae_encoder_session", None),
text_encoder_2_session=components.get("text_encoder_2_session", None),
tokenizer_2=components.get("tokenizer_2", None),
)
else:
if "vae" in components:
# upscale uses a single VAE
logger.debug(
"assembling SD pipeline for %s with single VAE",
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae"],
components["text_encoder"],
components["tokenizer"],
components["unet"],
scheduler,
scheduler,
)
else:
logger.debug(
"assembling SD pipeline for %s with VAE codec",
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae_encoder"],
components["vae_decoder"],
components["text_encoder"],
components["tokenizer"],
components["unet"],
scheduler,
None,
None,
requires_safety_checker=False,
)
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)
optimize_pipeline(server, pipe)
patch_pipeline(server, pipe, pipeline_class, params)
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
for vae in VAE_COMPONENTS:
if hasattr(pipe, vae):
vae_model = getattr(pipe, vae)
if isinstance(vae_model, VAEWrapper):
vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(
params.vae_tile // LATENT_FACTOR, params.vae_overlap
)
# update panorama params
if params.is_panorama():
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
logger.debug(
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
params.unet_tile,
unet_stride,
params.vae_tile,
params.vae_overlap,
)
pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)
run_gc([device])
return pipe
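# Worked example (assumed values): with unet_tile=512, unet_overlap=0.25 and
# LATENT_FACTOR=8, the panorama UNet window is 512 // 8 = 64 latent positions wide
# with a stride of int(512 * 0.75) // 8 = 48, leaving 16 latent positions of overlap
# between adjacent windows.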
def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
logger.debug("loading ControlNet weights from %s", cnet_path)
components = {}
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
cnet_path,
@ -186,63 +306,89 @@ def load_pipeline(
sess_options=device.sess_options(),
)
)
return components
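# Hedged note on layout: this resolves converted ControlNet weights at
#   <model_path>/control/<name>.onnx
# e.g. models/control/canny.onnx when params.control.name is "canny" (the name is an
# example, not taken from the diff).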
unet_type = "cnet"
# Textual Inversion blending
if inversions is not None and len(inversions) > 0:
logger.debug("blending Textual Inversions from %s", inversions)
inversion_names, inversion_weights = zip(*inversions)
inversion_models = [
path.join(server.model_path, "inversion", name)
for name in inversion_names
]
def load_text_encoders(
server: ServerContext,
device: DeviceParams,
model: str,
embeddings: Optional[List[Tuple[str, float]]],
loras: Optional[List[Tuple[str, float]]],
torch_dtype,
params: ImageParams,
):
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
tokenizer = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer",
torch_dtype=torch_dtype,
)
components = {
"tokenizer": tokenizer,
}
if params.is_xl():
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
tokenizer_2 = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer_2",
torch_dtype=torch_dtype,
)
components["tokenizer_2"] = tokenizer_2
# blend embeddings, if any
if embeddings is not None and len(embeddings) > 0:
embedding_names, embedding_weights = zip(*embeddings)
embedding_models = [
path.join(server.model_path, "inversion", name) for name in embedding_names
]
logger.debug(
"blending base model %s with embeddings from %s", model, embedding_models
)
# TODO: blend text_encoder_2 as well
text_encoder, tokenizer = blend_textual_inversions(
server,
text_encoder,
tokenizer,
list(
zip(
inversion_models,
inversion_weights,
inversion_names,
[None] * len(inversion_models),
embedding_models,
embedding_weights,
embedding_names,
[None] * len(embedding_models),
)
),
)
components["tokenizer"] = tokenizer
if params.is_xl():
text_encoder_2, tokenizer_2 = blend_textual_inversions(
server,
text_encoder_2,
tokenizer_2,
list(
zip(
embedding_models,
embedding_weights,
embedding_names,
[None] * len(embedding_models),
)
),
)
components["tokenizer_2"] = tokenizer_2
# blend LoRAs, if any
if loras is not None and len(loras) > 0:
lora_names, lora_weights = zip(*loras)
lora_models = [
path.join(server.model_path, "lora", name) for name in lora_names
]
logger.info("blending base model %s with LoRAs from %s", model, lora_models)
# blend and load text encoder
text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
text_encoder = blend_loras(
server,
text_encoder,
@ -251,34 +397,8 @@ def load_pipeline(
1 if params.is_xl() else None,
params.is_xl(),
)
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
text_encoder
)
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = device.sess_options(cache=False)
text_encoder_opts.add_external_initializers(
list(text_encoder_names), list(text_encoder_values)
)
if params.is_xl():
text_encoder_session = InferenceSession(
text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts,
)
text_encoder_session._model_path = path.join(model, "text_encoder")
components["text_encoder_session"] = text_encoder_session
else:
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"),
sess_options=text_encoder_opts,
)
)
if params.is_xl():
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder_2 = blend_loras(
server,
text_encoder_2,
@ -287,6 +407,17 @@ def load_pipeline(
2,
params.is_xl(),
)
# prepare external data for sessions
(text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = device.sess_options(cache=False)
text_encoder_opts.add_external_initializers(
list(text_encoder_names), list(text_encoder_values)
)
if params.is_xl():
# encoder 2 only exists in XL
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
text_encoder_2
)
@ -296,6 +427,16 @@ def load_pipeline(
list(text_encoder_2_names), list(text_encoder_2_values)
)
# session for te1
text_encoder_session = InferenceSession(
text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts,
)
text_encoder_session._model_path = path.join(model, "text_encoder")
components["text_encoder_session"] = text_encoder_session
# session for te2
text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
@ -303,17 +444,48 @@ def load_pipeline(
)
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2_session"] = text_encoder_2_session
else:
# session for te
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"),
sess_options=text_encoder_opts,
)
)
return components
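# Rough usage sketch for load_text_encoders above (the model path, embedding
# and LoRA names are hypothetical; server, device and params come from the
# surrounding request handling):
#
#   components = load_text_encoders(
#       server,
#       device,
#       path.join(server.model_path, "stable-diffusion-onnx-v1-5"),
#       embeddings=[("bad-hands", 1.0)],
#       loras=[("detail-tweaker", 0.6)],
#       torch_dtype=torch.float32,
#       params=params,
#   )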
def load_unet(
server: ServerContext,
device: DeviceParams,
model: str,
loras: List[Tuple[str, float]],
unet_type: Literal["cnet", "unet"],
params: ImageParams,
):
components = {}
unet = load_model(path.join(model, unet_type, ONNX_MODEL))
# LoRA blending
if loras is not None and len(loras) > 0:
lora_names, lora_weights = zip(*loras)
lora_models = [
path.join(server.model_path, "lora", name) for name in lora_names
]
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
# blend and load unet
unet = blend_loras(
server,
unet,
list(zip(lora_models, lora_weights)),
"unet",
xl=params.is_xl(),
)
(unet_model, unet_data) = buffer_external_data_tensors(unet)
unet_names, unet_values = zip(*unet_data)
unet_opts = device.sess_options(cache=False)
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
@ -335,23 +507,18 @@ def load_pipeline(
)
)
# make sure a UNet has been loaded
if not params.is_xl() and "unet" not in components:
unet = path.join(model, unet_type, ONNX_MODEL)
logger.debug("loading UNet (%s) from %s", unet_type, unet)
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
unet,
provider=device.ort_provider("unet"),
sess_options=device.sess_options(),
)
)
return components
def load_vae(
_server: ServerContext, device: DeviceParams, model: str, params: ImageParams
):
# one or more VAE models need to be loaded
vae = path.join(model, "vae", ONNX_MODEL)
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
components = {}
if not params.is_xl() and path.exists(vae):
logger.debug("loading VAE from %s", vae)
components["vae"] = OnnxRuntimeModel(
@ -361,9 +528,25 @@ def load_pipeline(
sess_options=device.sess_options(),
)
)
elif path.exists(vae_decoder) and path.exists(vae_encoder):
if params.is_xl():
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
vae_decoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_decoder_session"]._model_path = vae_decoder
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
vae_encoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
components["vae_encoder_session"]._model_path = vae_encoder
else:
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
@ -382,119 +565,44 @@ def load_pipeline(
)
)
# additional options for panorama pipeline
if params.is_panorama():
components["window"] = params.tiles // 8
components["stride"] = params.stride // 8
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
safety_checker=None,
torch_dtype=torch_dtype,
**components,
)
# make sure XL models are actually being used
# TODO: why is this needed?
if "text_encoder_session" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder.session == components["text_encoder_session"],
type(pipe.text_encoder),
)
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
if "text_encoder_2_session" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2.session == components["text_encoder_2_session"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = ORTModelTextEncoder(
text_encoder_2_session, text_encoder_2
)
if "unet_session" in components:
logger.info(
"unet matches: %s, %s",
pipe.unet.session == components["unet_session"],
type(pipe.unet),
)
pipe.unet = ORTModelUnet(unet_session, unet_model)
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)
optimize_pipeline(server, pipe)
if not params.is_xl():
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
if not params.is_xl() and hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
if not params.is_xl() and hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
# update panorama params
if params.is_panorama():
latent_window = params.tiles // 8
latent_stride = params.stride // 8
pipe.set_window_size(latent_window, latent_stride)
if hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_window_size(latent_window, params.overlap)
if hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_window_size(latent_window, params.overlap)
run_gc([device])
return pipe
return components
def optimize_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
) -> None:
if server.has_optimization(
"diffusers-attention-slicing"
) or server.has_optimization("diffusers-attention-slicing-auto"):
logger.debug("enabling auto attention slicing on SD pipeline")
try:
pipe.enable_attention_slicing(slice_size="auto")
except Exception as e:
logger.warning("error while enabling auto attention slicing: %s", e)
if "diffusers-attention-slicing-max" in server.optimizations:
if server.has_optimization("diffusers-attention-slicing-max"):
logger.debug("enabling max attention slicing on SD pipeline")
try:
pipe.enable_attention_slicing(slice_size="max")
except Exception as e:
logger.warning("error while enabling max attention slicing: %s", e)
if "diffusers-vae-slicing" in server.optimizations:
if server.has_optimization("diffusers-vae-slicing"):
logger.debug("enabling VAE slicing on SD pipeline")
try:
pipe.enable_vae_slicing()
except Exception as e:
logger.warning("error while enabling VAE slicing: %s", e)
if "diffusers-cpu-offload-sequential" in server.optimizations:
if server.has_optimization("diffusers-cpu-offload-sequential"):
logger.debug("enabling sequential CPU offload on SD pipeline")
try:
pipe.enable_sequential_cpu_offload()
except Exception as e:
logger.warning("error while enabling sequential CPU offload: %s", e)
elif "diffusers-cpu-offload-model" in server.optimizations:
elif server.has_optimization("diffusers-cpu-offload-model"):
# TODO: check for accelerate
logger.debug("enabling model CPU offload on SD pipeline")
try:
@ -502,7 +610,7 @@ def optimize_pipeline(
except Exception as e:
logger.warning("error while enabling model CPU offload: %s", e)
if "diffusers-memory-efficient-attention" in server.optimizations:
if server.has_optimization("diffusers-memory-efficient-attention"):
# TODO: check for xformers
logger.debug("enabling memory efficient attention for SD pipeline")
try:
@ -514,17 +622,17 @@ def optimize_pipeline(
def patch_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
pipeline: Any,
params: ImageParams,
) -> None:
logger.debug("patching SD pipeline")
if pipe_type != "lpw":
if not params.is_lpw() and not params.is_xl():
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
original_unet = pipe.unet
pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
logger.debug("patched UNet with wrapper")
if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder
@ -532,18 +640,21 @@ def patch_pipeline(
server,
original_decoder,
decoder=True,
window=params.unet_tile,
overlap=params.vae_overlap,
)
logger.debug("patched VAE decoder with wrapper")
if hasattr(pipe, "vae_encoder"):
original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper(
server,
original_encoder,
decoder=False,
window=params.unet_tile,
overlap=params.vae_overlap,
)
elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE
else:
logger.debug("no VAE found to patch")
logger.debug("patched VAE encoder with wrapper")
if hasattr(pipe, "vae"):
logger.warning("not patching single VAE, tiled VAE may not work")

View File

@ -14,20 +14,23 @@ class UNetWrapper(object):
prompt_index: int = 0
server: ServerContext
wrapped: OnnxRuntimeModel
xl: bool
def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
xl: bool,
):
self.server = server
self.wrapped = wrapped
self.xl = xl
def __call__(
self,
sample: Optional[np.ndarray] = None,
timestep: Optional[np.ndarray] = None,
encoder_hidden_states: Optional[np.ndarray] = None,
**kwargs,
):
logger.trace(
@ -43,6 +46,14 @@ class UNetWrapper(object):
encoder_hidden_states = self.prompt_embeds[step_index]
self.prompt_index += 1
if self.xl:
if sample.dtype != encoder_hidden_states.dtype:
logger.trace(
"converting UNet sample to hidden state dtype for XL: %s",
encoder_hidden_states.dtype,
)
sample = sample.astype(encoder_hidden_states.dtype)
else:
if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)

View File

@ -12,8 +12,6 @@ from ...server import ServerContext
logger = getLogger(__name__)
LATENT_CHANNELS = 4
class VAEWrapper(object):
def __init__(
@ -39,11 +37,17 @@ class VAEWrapper(object):
self.tile_overlap_factor = overlap
def __call__(self, latent_sample=None, sample=None, **kwargs):
model = (
self.wrapped.model
if hasattr(self.wrapped, "model")
else self.wrapped.session
)
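# the wrapped VAE may expose its ONNX session as either .model (diffusers
# OnnxRuntimeModel) or .session, depending on how it was loaded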
# set timestep dtype to input type
sample_dtype = next(
(
input.type
for input in model.get_inputs()
if input.name == "sample" or input.name == "latent_sample"
),
"tensor(float)",

View File

@ -13,8 +13,8 @@ import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging

View File

@ -13,25 +13,36 @@
# limitations under the License.
import inspect
from math import ceil
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer
from ...chain.tile import make_tile_mask
from ...constants import LATENT_CHANNELS, LATENT_FACTOR
from ...params import Size
from ..utils import (
expand_latents,
parse_regions,
random_seed,
repair_nan,
resize_latent_shape,
)
logger = logging.get_logger(__name__)
# inpaint constants
NUM_UNET_INPUT_CHANNELS = 9
NUM_LATENT_CHANNELS = 4
DEFAULT_WINDOW = 32
DEFAULT_STRIDE = 8
@ -346,13 +357,16 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
def get_views(
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
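# e.g. a 1024x512 image becomes a 128x64 latent grid; with the default window
# of 32 and stride of 8 that is ceil((64-32)/8)+1 = 5 rows by
# ceil((128-32)/8)+1 = 13 columns, or 65 overlapping views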
total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug(
"panorama generated %s views, %s by %s blocks",
@ -369,7 +383,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end))
return (views, (h_end * 8, w_end * 8))
@torch.no_grad()
def text2img(
@ -479,6 +493,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
prompt, regions = parse_regions(prompt)
prompt_embeds = self._encode_prompt(
prompt,
num_images_per_prompt,
@ -488,9 +504,30 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
# 3.b. Encode region prompts
region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
region_prompt_embeds = self._encode_prompt(
region_prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
region_embeds.append(region_prompt_embeds)
# get the initial random noise unless the user supplied it
latents_dtype = prompt_embeds.dtype
latents_shape = (
batch_size * num_images_per_prompt,
LATENT_CHANNELS,
height // LATENT_FACTOR,
width // LATENT_FACTOR,
)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
elif latents.shape != latents_shape:
@ -525,11 +562,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
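# the view grid can extend past the requested size, so pad the latents with
# fresh noise out to the resized grid; the extra margin is cropped again after
# the denoising loop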
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
last = i == (len(self.scheduler.timesteps) - 1)
count.fill(0)
value.fill(0)
@ -576,13 +624,115 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
if not last:
for r, region in enumerate(regions):
top, left, bottom, right, weight, feather, prompt = region
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = (
region_noise_pred_uncond
+ guidance_scale
* (region_noise_pred_text - region_noise_pred_uncond)
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
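# a region weight of 100 or more replaces the accumulated panorama prediction
# outright; smaller weights are blended into the running weighted average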
if weight >= 100.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
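# each latent cell becomes the weight-normalized average of every view and
# region prediction that covers it; cells covered by nothing (count == 0)
# keep the raw accumulator value, which is zero at this point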
latents = np.where(count > 0, value / count, value)
latents = repair_nan(latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
latents = np.clip(latents, -4, +4)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1
@ -828,9 +978,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
for i, t in enumerate(self.progress_bar(timesteps)):
count.fill(0)
@ -886,6 +1046,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1
@ -1053,12 +1218,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
num_channels_latents = LATENT_CHANNELS
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,
height // LATENT_FACTOR,
width // LATENT_FACTOR,
)
latents_dtype = prompt_embeds.dtype
if latents is None:
@ -1136,9 +1301,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
count.fill(0)
@ -1201,6 +1376,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1

View File

@ -0,0 +1,955 @@
import inspect
import logging
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
StableDiffusionXLImg2ImgPipelineMixin,
)
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from ...chain.tile import make_tile_mask
from ...constants import LATENT_FACTOR
from ...params import Size
from ..utils import (
expand_latents,
parse_regions,
random_seed,
repair_nan,
resize_latent_shape,
)
logger = logging.getLogger(__name__)
DEFAULT_WINDOW = 64
DEFAULT_STRIDE = 16
class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
def __init__(
self,
*args,
window: int = DEFAULT_WINDOW,
stride: int = DEFAULT_STRIDE,
**kwargs,
):
super().__init__(self, *args, **kwargs)
self.window = window
self.stride = stride
def set_window_size(self, window: int, stride: int):
self.window = window
self.stride = stride
def get_views(
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
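# e.g. a 1024x1024 SDXL image is a 128x128 latent grid; with the default
# window of 64 and stride of 16 that is ceil((128-64)/16)+1 = 5 blocks per
# side, or 25 overlapping views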
total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug(
"panorama generated %s views, %s by %s blocks",
total_num_blocks,
num_blocks_height,
num_blocks_width,
)
views = []
for i in range(total_num_blocks):
h_start = int((i // num_blocks_width) * stride)
h_end = h_start + window_size
w_start = int((i % num_blocks_width) * stride)
w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end))
return (views, (h_end * 8, w_end * 8))
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_img2img(
self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None
):
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae_encoder(sample=image)[
0
] * self.vae_decoder.config.get("scaling_factor", 0.18215)
if (
batch_size > init_latents.shape[0]
and batch_size % init_latents.shape[0] == 0
):
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = np.concatenate(
[init_latents] * additional_image_per_prompt, axis=0
)
elif (
batch_size > init_latents.shape[0]
and batch_size % init_latents.shape[0] != 0
):
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = np.concatenate([init_latents], axis=0)
# add noise to latents using the timesteps
noise = generator.randn(*init_latents.shape).astype(dtype)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents),
torch.from_numpy(noise),
torch.from_numpy(timestep),
)
return init_latents.numpy()
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents_text2img(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = generator.randn(*shape).astype(dtype)
elif latents.shape != shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * np.float64(self.scheduler.init_noise_sigma)
return latents
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
extra_step_kwargs = {}
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_eta:
extra_step_kwargs["eta"] = eta
return extra_step_kwargs
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
def text2img(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
height (`Optional[int]`, defaults to None):
The height in pixels of the generated image.
width (`Optional[int]`, defaults to None):
The width in pixels of the generated image.
num_inference_steps (`int`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to 5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`Optional[Union[str, list]]`):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, defaults to 1):
The number of images to generate per prompt.
eta (`float`, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`Optional[np.random.RandomState]`, defaults to `None`):
A np.random.RandomState to make generation deterministic.
latents (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
plain tuple.
callback (Optional[Callable], defaults to `None`):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
guidance_rescale (`float`, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
1.0,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if generator is None:
generator = np.random
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
prompt, regions = parse_regions(prompt)
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self._encode_prompt(
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
# 3.b. Encode region prompts
region_embeds: List[np.ndarray] = []
add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
(
region_prompt_embeds,
region_negative_prompt_embeds,
region_pooled_prompt_embeds,
region_negative_pooled_prompt_embeds,
) = self._encode_prompt(
region_prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
if do_classifier_free_guidance:
region_prompt_embeds = np.concatenate(
(region_negative_prompt_embeds, region_prompt_embeds), axis=0
)
add_region_embeds.append(
np.concatenate(
(
region_negative_pooled_prompt_embeds,
region_pooled_prompt_embeds,
),
axis=0,
)
)
region_embeds.append(region_prompt_embeds)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
latents = self.prepare_latents_text2img(
batch_size * num_images_per_prompt,
self.unet.config.get("in_channels", 4),
height,
width,
prompt_embeds.dtype,
generator,
latents,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = (original_size + crops_coords_top_left + target_size,)
add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)
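# SDXL micro-conditioning: the original size, crop offsets and target size are
# packed into a single time-ids vector and passed to the UNet alongside the
# pooled text embeddings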
if do_classifier_free_guidance:
prompt_embeds = np.concatenate(
(negative_prompt_embeds, prompt_embeds), axis=0
)
add_text_embeds = np.concatenate(
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
)
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
add_time_ids = np.repeat(
add_time_ids, batch_size * num_images_per_prompt, axis=0
)
# Adapted from diffusers to extend it for other runtimes than ORT
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
# 8. Panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
# adjust latents
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
last = i == (len(timesteps) - 1)
count.fill(0)
value.fill(0)
for h_start, h_end, w_start, w_end in views:
# get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents_for_view] * 2)
if do_classifier_free_guidance
else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_embeds=add_text_embeds,
time_ids=add_time_ids,
)
noise_pred = noise_pred[0]
# perform guidance
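# classifier-free guidance: the doubled batch holds the unconditional and
# text-conditioned predictions, recombined as uncond + scale * (text - uncond)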
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents_for_view),
**extra_step_kwargs,
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
if not last:
for r, region in enumerate(regions):
top, left, bottom, right, weight, feather, prompt = region
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[r],
time_ids=add_time_ids,
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = (
region_noise_pred_uncond
+ guidance_scale
* (region_noise_pred_text - region_noise_pred_uncond)
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
region_noise_pred = rescale_noise_cfg(
region_noise_pred,
region_noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 100.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)
latents = repair_nan(latents)
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
if output_type == "latent":
image = latents
else:
latents = np.clip(latents, -4, +4)
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1
image = np.concatenate(
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = self.watermark.apply_watermark(image)
# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
def img2img(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
strength: float = 0.3,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
pooled_prompt_embeds: Optional[np.ndarray] = None,
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
output_type: str = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
image (`Union[np.ndarray, PIL.Image.Image]`):
`Image`, or tensor representing an image batch, used as the starting point for generation.
strength (`float`, defaults to 0.3):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
num_inference_steps (`int`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to 5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`Optional[Union[str, list]]`):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, defaults to 1):
The number of images to generate per prompt.
eta (`float`, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`Optional[np.random.RandomState]`, defaults to `None`):
A np.random.RandomState to make generation deterministic.
latents (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
plain tuple.
callback (Optional[Callable], defaults to `None`):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
guidance_rescale (`float`, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
When returning a tuple, the first element is a list with the generated images.
"""
# 0. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 1. Define call parameters
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if generator is None:
generator = np.random
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 2. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self._encode_prompt(
prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
# 3. Preprocess image
processor = VaeImageProcessor()
image = processor.preprocess(image).cpu().numpy()
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps, strength
)
latent_timestep = np.repeat(
timesteps[:1], batch_size * num_images_per_prompt, axis=0
)
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
latents_dtype = prompt_embeds.dtype
image = image.astype(latents_dtype)
# 5. Prepare latent variables
latents = self.prepare_latents_img2img(
image,
latent_timestep,
batch_size,
num_images_per_prompt,
latents_dtype,
generator,
)
# 6. Prepare extra step kwargs
extra_step_kwargs = {}
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_eta:
extra_step_kwargs["eta"] = eta
height, width = latents.shape[-2:]
height = height * self.vae_scale_factor
width = width * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 8. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
if do_classifier_free_guidance:
prompt_embeds = np.concatenate(
(negative_prompt_embeds, prompt_embeds), axis=0
)
add_text_embeds = np.concatenate(
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
)
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
add_time_ids = np.repeat(
add_time_ids, batch_size * num_images_per_prompt, axis=0
)
# 8. Panorama additions
views, resize = self.get_views(height, width, self.window, self.stride)
logger.trace("panorama resized latents to %s", resize)
count = np.zeros(resize_latent_shape(latents, resize))
value = np.zeros(resize_latent_shape(latents, resize))
latents = expand_latents(
latents,
random_seed(generator),
Size(resize[1], resize[0]),
sigma=self.scheduler.init_noise_sigma,
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
count.fill(0)
value.fill(0)
for h_start, h_end, w_start, w_end in views:
# get the latents corresponding to the current view coordinates
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents_for_view] * 2)
if do_classifier_free_guidance
else latents_for_view
)
latent_model_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_model_input), t
)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
text_embeds=add_text_embeds,
time_ids=add_time_ids,
)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred),
t,
torch.from_numpy(latents_for_view),
**extra_step_kwargs,
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# remove extra margins
latents = latents[
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
]
if output_type == "latent":
image = latents
else:
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1
image = np.concatenate(
[
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
for i in range(latents.shape[0])
]
)
image = self.watermark.apply_watermark(image)
# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
def __call__(
self,
*args,
**kwargs,
):
if "image" in kwargs or (
len(args) > 1
and (
isinstance(args[1], np.ndarray) or isinstance(args[1], PIL.Image.Image)
)
):
logger.debug("running img2img panorama XL pipeline")
return self.img2img(*args, **kwargs)
else:
logger.debug("running txt2img panorama XL pipeline")
return self.text2img(*args, **kwargs)
class ORTStableDiffusionXLPanoramaPipeline(
ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin
):
def __call__(self, *args, **kwargs):
return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
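# Rough usage sketch (the prompt and sizes are placeholders; `pipe` is an
# ORTStableDiffusionXLPanoramaPipeline assembled by the pipeline loader above):
#
#   pipe.set_window_size(window=64, stride=16)
#   output = pipe("a sweeping mountain panorama at sunset", width=2048, height=1024)
#   output.images[0].save("panorama.png")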

View File

@ -32,7 +32,7 @@ except ImportError:
}
from diffusers import OnnxRuntimeModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
DDIMScheduler,

View File

@ -1,63 +1,25 @@
###
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
# See also: https://github.com/huggingface/diffusers/pull/2158
###
from logging import getLogger
from typing import Any, Callable, List, Optional, Union
from typing import Any, List
import numpy as np
import PIL
import torch
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import (
OnnxStableDiffusionUpscalePipeline as BasePipeline,
)
from diffusers.schedulers import DDPMScheduler
logger = getLogger(__name__)
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7
ORT_TO_PT_TYPE = {
"float16": torch.float16,
"float32": torch.float32,
}
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 64, (w, h))  # round down to an integer multiple of 64
image = [np.array(i.resize((w, h)))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
class FakeConfig:
block_out_channels: List[int]
scaling_factor: float
def __init__(self) -> None:
self.block_out_channels = [128, 256, 512]
self.scaling_factor = 0.08333
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
class OnnxStableDiffusionUpscalePipeline(BasePipeline):
def __init__(
self,
vae: OnnxRuntimeModel,
@ -80,260 +42,3 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
scheduler,
max_noise_level=max_noise_level,
)
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
# device, device only needed for Torch pipelines
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
# 4. Preprocess image
image = preprocess(image)
image = image.cpu()
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(
image.shape, generator=generator, device=device, dtype=latents_dtype
)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
noise_level = np.concatenate([noise_level] * image.shape[0])
# 6. Prepare latent variables
height, width = image.shape[2:]
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
NUM_LATENT_CHANNELS,
height,
width,
latents_dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
raise ValueError(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
f" `num_channels_image`: {num_channels_image} "
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
timestep_dtype = next(
(
input.type
for input in self.unet.model.get_inputs()
if input.name == "timestep"
),
"tensor(float)",
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
np.concatenate([latents] * 2)
if do_classifier_free_guidance
else latents
)
# concat the latents and the low-res image in the channel dimension
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
# timestep to tensor
timestep = np.array([t], dtype=timestep_dtype)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
class_labels=noise_level.astype(np.int64),
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents.float())
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = self.vae(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
return image
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
# no positional arguments to text_encoder
text_embeddings = self.text_encoder(
input_ids=text_input_ids.int().to(device),
)
text_embeddings = text_embeddings[0]
bs_embed, seq_len, _ = text_embeddings.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
text_embeddings = text_embeddings.reshape(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.int().to(device),
)
uncond_embeddings = uncond_embeddings[0]
seq_len = uncond_embeddings.shape[1]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.reshape(
batch_size * num_images_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
return text_embeddings

View File

@ -4,15 +4,16 @@ from typing import Any, List, Optional
from PIL import Image, ImageOps
from onnx_web.chain.highres import stage_highres
from ..chain import (
BlendDenoiseStage,
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
SourceTxt2ImgStage,
UpscaleOutpaintStage,
)
from ..chain.highres import stage_highres
from ..chain.result import StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image
@ -33,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt
logger = getLogger(__name__)
def get_base_tile(params: ImageParams, size: Size) -> int:
if params.is_panorama():
tile = max(params.unet_tile, size.width, size.height)
logger.debug("adjusting tile size for panorama to %s", tile)
return tile
return params.unet_tile
def get_highres_tile(
server: ServerContext, params: ImageParams, highres: HighresParams, tile: int
) -> int:
if params.is_panorama() and server.has_feature("panorama-highres"):
return tile * highres.scale
return params.unet_tile
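# For example (values are illustrative): with unet_tile=512 and a 1024x768 panorama request,
# get_base_tile returns 1024; if the server enables the "panorama-highres" feature and
# highres.scale is 2, get_highres_tile returns 2048, otherwise it falls back to params.unet_tile.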
def run_txt2img_pipeline(
worker: WorkerContext,
server: ServerContext,
@ -43,10 +62,7 @@ def run_txt2img_pipeline(
highres: HighresParams,
) -> None:
# if using panorama, the pipeline will tile itself (views)
if params.is_panorama() or params.is_xl():
tile_size = max(params.tiles, size.width, size.height)
else:
tile_size = params.tiles
tile_size = get_base_tile(params, size)
# prepare the chain pipeline and first stage
chain = ChainPipeline()
@ -57,15 +73,21 @@ def run_txt2img_pipeline(
),
size=size,
prompt_index=0,
overlap=params.overlap,
overlap=params.vae_overlap,
)
# apply upscaling and correction, before highres
stage = StageParams(tile_size=params.tiles)
highres_size = get_highres_tile(server, params, highres, tile_size)
if params.is_panorama():
chain.stage(
BlendDenoiseStage(),
StageParams(tile_size=highres_size),
)
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
params,
chain=chain,
upscale=first_upscale,
@ -73,7 +95,7 @@ def run_txt2img_pipeline(
# apply highres
stage_highres(
stage,
StageParams(outscale=highres.scale, tile_size=highres_size),
params,
highres,
upscale,
@ -83,7 +105,7 @@ def run_txt2img_pipeline(
# apply upscaling and correction, after highres
stage_upscale_correction(
stage,
StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
params,
chain=chain,
upscale=after_upscale,
@ -92,11 +114,14 @@ def run_txt2img_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
images = chain.run(
worker, server, params, StageResult.empty(), callback=progress, latents=latents
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
logger.trace("saving output image %s: %s", output, image.size)
dest = save_image(
server,
output,
@ -136,23 +161,26 @@ def run_img2img_pipeline(
source = f(server, source)
# prepare the chain pipeline and first stage
tile_size = get_base_tile(params, Size(*source.size))
chain = ChainPipeline()
stage = StageParams(
tile_size=params.tiles,
)
chain.stage(
BlendImg2ImgStage(),
stage,
StageParams(
tile_size=tile_size,
),
prompt_index=0,
strength=strength,
overlap=params.overlap,
overlap=params.vae_overlap,
)
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
StageParams(
outscale=first_upscale.outscale,
tile_size=tile_size,
),
params,
upscale=first_upscale,
chain=chain,
@ -162,13 +190,16 @@ def run_img2img_pipeline(
for _i in range(params.loopback):
chain.stage(
BlendImg2ImgStage(),
stage,
StageParams(
tile_size=tile_size,
),
strength=strength,
)
# highres, if selected
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres(
stage,
StageParams(tile_size=highres_size, outscale=highres.scale),
params,
highres,
upscale,
@ -178,7 +209,7 @@ def run_img2img_pipeline(
# apply upscaling and correction, after highres
stage_upscale_correction(
stage,
StageParams(tile_size=tile_size, outscale=after_upscale.scale),
params,
upscale=after_upscale,
chain=chain,
@ -186,7 +217,9 @@ def run_img2img_pipeline(
# run and append the filtered source
progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress)
images = chain.run(
worker, server, params, StageResult(images=[source]), callback=progress
)
if source_filter is not None and source_filter != "none":
images.append(source)
@ -235,7 +268,7 @@ def run_inpaint_pipeline(
full_res_inpaint_padding: float,
) -> None:
logger.debug("building inpaint pipeline")
tile_size = params.tiles
tile_size = get_base_tile(params, size)
if mask is None:
# if no mask was provided, keep the full source image
@ -264,8 +297,12 @@ def run_inpaint_pipeline(
logger.debug("border zero: %s", border.isZero())
full_res_inpaint = full_res_inpaint and border.isZero()
if full_res_inpaint:
mask_left, mask_top, mask_right, mask_bottom = mask.getbbox()
logger.debug("mask bbox: %s", mask.getbbox())
bbox = mask.getbbox()
if bbox is None:
bbox = (0, 0, source.width, source.height)
logger.debug("mask bounding box: %s", bbox)
mask_left, mask_top, mask_right, mask_bottom = bbox
mask_width = mask_right - mask_left
mask_height = mask_bottom - mask_top
# ensure we have some padding around the mask when we do the inpaint (and that the region size is even)
@ -322,16 +359,15 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order, tile_size=tile_size)
chain.stage(
UpscaleOutpaintStage(),
stage,
StageParams(tile_order=tile_order, tile_size=tile_size),
border=border,
mask=mask,
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
overlap=params.overlap,
overlap=params.vae_overlap,
prompt_index=0,
)
@ -339,15 +375,16 @@ def run_inpaint_pipeline(
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
params,
upscale=first_upscale,
chain=chain,
)
# apply highres
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres(
stage,
StageParams(outscale=highres.scale, tile_size=highres_size),
params,
highres,
upscale,
@ -357,7 +394,7 @@ def run_inpaint_pipeline(
# apply upscaling and correction
stage_upscale_correction(
stage,
StageParams(outscale=after_upscale.outscale),
params,
upscale=after_upscale,
chain=chain,
@ -366,7 +403,14 @@ def run_inpaint_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress, latents=latents)
images = chain.run(
worker,
server,
params,
StageResult(images=[source]),
callback=progress,
latents=latents,
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
@ -409,21 +453,22 @@ def run_upscale_pipeline(
) -> None:
# set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline()
stage = StageParams(tile_size=params.tiles)
tile_size = get_base_tile(params, size)
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
stage_upscale_correction(
stage,
StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
params,
upscale=first_upscale,
chain=chain,
)
# apply highres
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres(
stage,
StageParams(outscale=highres.scale, tile_size=highres_size),
params,
highres,
upscale,
@ -433,7 +478,7 @@ def run_upscale_pipeline(
# apply upscaling and correction, after highres
stage_upscale_correction(
stage,
StageParams(outscale=after_upscale.outscale, tile_size=tile_size),
params,
upscale=after_upscale,
chain=chain,
@ -441,7 +486,9 @@ def run_upscale_pipeline(
# run and save
progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress)
images = chain.run(
worker, server, params, StageResult(images=[source]), callback=progress
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
@ -478,12 +525,18 @@ def run_blend_pipeline(
) -> None:
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams()
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
tile_size = get_base_tile(params, size)
chain.stage(
BlendMaskStage(),
StageParams(tile_size=tile_size),
stage_source=sources[1],
stage_mask=mask,
)
# apply upscaling and correction
stage_upscale_correction(
stage,
StageParams(outscale=upscale.outscale),
params,
upscale=upscale,
chain=chain,
@ -491,7 +544,9 @@ def run_blend_pipeline(
# run and save
progress = worker.get_progress_callback()
images = chain(worker, server, params, sources, callback=progress)
images = chain.run(
worker, server, params, StageResult(images=sources), callback=progress
)
for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale)

View File

@ -3,23 +3,27 @@ from copy import deepcopy
from logging import getLogger
from math import ceil
from re import Pattern, compile
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline
from ..constants import LATENT_CHANNELS, LATENT_FACTOR
from ..params import ImageParams, Size
logger = getLogger(__name__)
LATENT_CHANNELS = 4
LATENT_FACTOR = 8
MAX_TOKENS_PER_GROUP = 77
ANY_TOKEN = compile(r"\<([^\>]*)\>")
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
REGION_TOKEN = compile(
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
)
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
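# Illustrative strings matched by the patterns above (names and values are examples only):
#   <clip:skip:2>                                   CLIP_TOKEN
#   <inversion:cubex:1.0>                           INVERSION_TOKEN
#   <lora:detail-tweaker:0.8>                       LORA_TOKEN
#   <region:0:0:256:256:0.75:10.0_TL:a stormy sky>  REGION_TOKEN
#   <reseed:0:0:256:256:12345>                      RESEED_TOKEN
#   __animals/birds__                               WILDCARD_TOKEN
#   step-{10,30,5}                                  INTERVAL_RANGE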
@ -84,8 +88,8 @@ def expand_prompt(
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
skip_clip_states: Optional[int] = 0,
) -> "np.NDArray":
skip_clip_states: int = 0,
) -> np.ndarray:
# self provides:
# tokenizer: CLIPTokenizer
# encoder: OnnxRuntimeModel
@ -140,6 +144,7 @@ def expand_prompt(
last_state, _pooled_output, *hidden_states = text_result
if skip_clip_states > 0:
# TODO: why is this normalized?
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
norm_state = layer_norm(
torch.from_numpy(
@ -219,20 +224,25 @@ def expand_prompt(
return prompt_embeds
def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
name, weight = group
return (name, float(weight))
def get_tokens_from_prompt(
prompt: str, pattern: Pattern
prompt: str,
pattern: Pattern,
parser=parse_float_group,
) -> Tuple[str, List[Tuple[str, float]]]:
"""
TODO: replace with Arpeggio
"""
remaining_prompt = prompt
tokens = []
next_match = pattern.search(remaining_prompt)
while next_match is not None:
logger.debug("found token in prompt: %s", next_match)
name, weight = next_match.groups()
tokens.append((name, float(weight)))
group = next_match.groups()
tokens.append(parser(group))
# remove this match and look for another
remaining_prompt = (
remaining_prompt[: next_match.start()]
@ -251,6 +261,13 @@ def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]
return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
def random_seed(generator=None) -> int:
if generator is None:
generator = np.random
return generator.randint(np.iinfo(np.int32).max)
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
"""
From https://www.travelneil.com/stable-diffusion-updates.html.
@ -266,6 +283,25 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
return image_latents
def expand_latents(
latents: np.ndarray,
seed: int,
size: Size,
sigma: float = 1.0,
) -> np.ndarray:
batch, _channels, height, width = latents.shape
extra_latents = get_latents_from_seed(seed, size, batch=batch)
extra_latents[:, :, 0:height, 0:width] = latents
return extra_latents * np.float64(sigma)
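# Sketch of how expand_latents pads a partial tile back to a full latent grid, assuming the
# usual 4-channel, 1/8-scale SD latent layout (shapes are illustrative):
#
#   tile = get_latents_from_seed(42, Size(256, 256))   # shape (1, 4, 32, 32)
#   full = expand_latents(tile, 42, Size(512, 512))    # shape (1, 4, 64, 64), tile copied into the top-left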
def resize_latent_shape(
latents: np.ndarray,
size: Tuple[int, int],
) -> Tuple[int, int, int, int]:
return (latents.shape[0], latents.shape[1], *size)
def get_tile_latents(
full_latents: np.ndarray,
seed: int,
@ -290,14 +326,8 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt]
if tile_latents.shape != full_latents.shape and (
tile_latents.shape[2] < t or tile_latents.shape[3] < t
):
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
extra_latents[
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
] = tile_latents
tile_latents = extra_latents
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
tile_latents = expand_latents(tile_latents, seed, size)
return tile_latents
@ -369,12 +399,15 @@ def encode_prompt(
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
"""
TODO: does not work with SDXL, fix or turn into a pipeline patch
"""
return [
pipe._encode_prompt(
prompt,
remove_tokens(prompt),
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=neg_prompt,
negative_prompt=remove_tokens(neg_prompt),
)
for prompt, neg_prompt in prompt_pairs
]
@ -444,3 +477,71 @@ def slice_prompt(prompt: str, slice: int) -> str:
return parts[min(slice, len(parts) - 1)]
else:
return prompt
Region = Tuple[
int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
]
def parse_region_group(group: Tuple[str, ...]) -> Region:
top, left, bottom, right, weight, feather, prompt = group
# break down the feather section
feather_radius, *feather_edges = feather.split("_")
if len(feather_edges) == 0:
feather_edges = "TLBR"
else:
feather_edges = "".join(feather_edges)
return (
int(top),
int(left),
int(bottom),
int(right),
float(weight),
(
float(feather_radius),
(
"T" in feather_edges,
"L" in feather_edges,
"B" in feather_edges,
"R" in feather_edges,
),
),
prompt,
)
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
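# Illustrative example (token values are made up): parsing
#   "a snowy valley <region:0:0:256:256:0.75:10.0_TL:a stormy sky>"
# returns the remaining prompt "a snowy valley " and one Region tuple:
#   (0, 0, 256, 256, 0.75, (10.0, (True, True, False, False)), "a stormy sky")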
Reseed = Tuple[int, int, int, int, int]
def parse_reseed_group(group) -> Reseed:
top, left, bottom, right, seed = group
return (
int(top),
int(left),
int(bottom),
int(right),
int(seed),
)
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
def skip_group(group) -> Any:
return group
def remove_tokens(prompt: Optional[str]) -> Optional[str]:
if prompt is None:
return prompt
remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN, parser=skip_group)
return remainder

View File

@ -12,6 +12,11 @@ try:
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
try:
from diffusers import LCMScheduler
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as LCMScheduler
try:
from diffusers import UniPCMultistepScheduler
except ImportError:

View File

@ -8,7 +8,7 @@ def mask_filter_none(
) -> Image.Image:
width, height = dims
noise = Image.new("RGB", (width, height), fill)
noise = Image.new(mask.mode, (width, height), fill)
noise.paste(mask, origin)
return noise

View File

@ -17,21 +17,21 @@ def noise_source_fill_edge(
"""
width, height = dims
noise = Image.new("RGB", (width, height), fill)
noise = Image.new(source.mode, (width, height), fill)
noise.paste(source, origin)
return noise
def noise_source_fill_mask(
_source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
) -> Image.Image:
"""
Fill the whole canvas, no source or noise.
"""
width, height = dims
noise = Image.new("RGB", (width, height), fill)
noise = Image.new(source.mode, (width, height), fill)
return noise
@ -52,7 +52,7 @@ def noise_source_gaussian(
def noise_source_uniform(
_source: Image.Image, dims: Point, _origin: Point, **kw
source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
@ -61,6 +61,7 @@ def noise_source_uniform(
noise_g = random.uniform(0, 256, size=size)
noise_b = random.uniform(0, 256, size=size)
# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height))
for x in range(width):
@ -68,11 +69,11 @@ def noise_source_uniform(
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
return noise
return noise.convert(source.mode)
def noise_source_normal(
_source: Image.Image, dims: Point, _origin: Point, **kw
source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
@ -81,6 +82,7 @@ def noise_source_normal(
noise_g = random.normal(128, 32, size=size)
noise_b = random.normal(128, 32, size=size)
# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height))
for x in range(width):
@ -88,13 +90,13 @@ def noise_source_normal(
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
return noise
return noise.convert(source.mode)
def noise_source_histogram(
source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
r, g, b = source.split()
r, g, b, *_a = source.split()
width, height = dims
size = width * height
@ -112,6 +114,7 @@ def noise_source_histogram(
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
)
# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height))
for x in range(width):
@ -119,4 +122,4 @@ def noise_source_histogram(
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
return noise
return noise.convert(source.mode)

View File

@ -47,7 +47,7 @@ def source_filter_noise(
source: Image.Image,
strength: float = 0.5,
):
noise = noise_source_histogram(source, source.size)
noise = noise_source_histogram(source, source.size, (0, 0))
return ImageChops.blend(source, noise, strength)

View File

@ -1,3 +1,5 @@
from typing import Tuple
from PIL import Image, ImageChops
from ..params import Border, Size
@ -13,12 +15,12 @@ def expand_image(
fill="white",
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
):
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
size = Size(*source.size).add_border(expand)
size = tuple(size)
origin = (expand.left, expand.top)
full_source = Image.new("RGB", size, fill)
full_source = Image.new(source.mode, size, fill)
full_source.paste(source, origin)
# new mask pixels need to be filled with white so they will be replaced

View File

@ -23,6 +23,7 @@ from .server.load import (
load_platforms,
load_wildcards,
)
from .server.plugin import load_plugins, register_plugins
from .server.static import register_static_routes
from .server.utils import check_paths
from .utils import is_debug
@ -43,15 +44,32 @@ def main():
server = ServerContext.from_environ()
apply_patches(server)
check_paths(server)
# debug options
if server.debug:
import debugpy
debugpy.listen(5678)
logger.warning("waiting for debugger")
debugpy.wait_for_client()
gc.set_debug(gc.DEBUG_STATS)
# register plugins
exports = load_plugins(server)
success = register_plugins(exports)
if success:
logger.info("all plugins loaded successfully")
else:
logger.warning("error loading plugins")
# load additional resources
load_extras(server)
load_models(server)
load_params(server)
load_platforms(server)
load_wildcards(server)
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
# misc server options
if not server.show_progress:
disable_progress_bar()
disable_progress_bars()

View File

@ -1,18 +1,21 @@
from typing import Literal
from typing import List, Literal
NetworkType = Literal["inversion", "lora"]
NetworkType = Literal["control", "inversion", "lora"]
class NetworkModel:
name: str
tokens: List[str]
type: NetworkType
def __init__(self, name: str, type: NetworkType) -> None:
def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
self.name = name
self.tokens = tokens or []
self.type = type
def tojson(self):
return {
"name": self.name,
"tokens": self.tokens,
"type": self.type,
}

View File

@ -57,7 +57,7 @@ def json_params(
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
parent: Dict = None,
parent: Optional[Dict] = None,
) -> Any:
json = {
"input_size": size.tojson(),
@ -158,6 +158,7 @@ def make_output_name(
size: Size,
extras: Optional[List[Optional[Param]]] = None,
count: Optional[int] = None,
offset: int = 0,
) -> List[str]:
count = count or params.batch
now = int(time())
@ -183,7 +184,7 @@ def make_output_name(
return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
for i in range(count)
for i in range(offset, count + offset)
]

View File

@ -14,7 +14,7 @@ Point = Tuple[int, int]
class SizeChart(IntEnum):
unlimited = 0
micro = 64
mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting
auto = 512 # auto tile size
@ -25,6 +25,7 @@ class SizeChart(IntEnum):
hd16k = 2**14
hd32k = 2**15
hd64k = 2**16
max = 2**32 # should be a reasonable upper limit for now
class TileOrder:
@ -140,7 +141,7 @@ class DeviceParams:
if self.options is None:
return self.provider
else:
return self.provider # (self.provider, self.options)
return (self.provider, self.options)
def sess_options(self, cache=True) -> SessionOptions:
if cache and self.sess_options_cache is not None:
@ -201,11 +202,14 @@ class ImageParams:
batch: int
control: Optional[NetworkModel]
input_prompt: str
input_negative_prompt: str
input_negative_prompt: Optional[str]
loopback: int
tiled_vae: bool
tiles: int
overlap: float
unet_tile: int
unet_overlap: float
vae_tile: int
vae_overlap: float
denoise: int
def __init__(
self,
@ -224,9 +228,11 @@ class ImageParams:
input_negative_prompt: Optional[str] = None,
loopback: int = 0,
tiled_vae: bool = False,
tiles: int = 512,
overlap: float = 0.25,
stride: int = 64,
unet_overlap: float = 0.25,
unet_tile: int = 512,
vae_overlap: float = 0.25,
vae_tile: int = 512,
denoise: int = 3,
) -> None:
self.model = model
self.pipeline = pipeline
@ -243,14 +249,16 @@ class ImageParams:
self.input_negative_prompt = input_negative_prompt or negative_prompt
self.loopback = loopback
self.tiled_vae = tiled_vae
self.tiles = tiles
self.overlap = overlap
self.stride = stride
self.unet_overlap = unet_overlap
self.unet_tile = unet_tile
self.vae_overlap = vae_overlap
self.vae_tile = vae_tile
self.denoise = denoise
def do_cfg(self):
return self.cfg > 1.0
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
pipeline = pipeline or self.pipeline
# if the correct pipeline was already requested, simply use that
@ -259,7 +267,14 @@ class ImageParams:
# otherwise, check for additional allowed pipelines
if group == "img2img":
if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "pix2pix"]:
if pipeline in [
"controlnet",
"img2img-sdxl",
"lpw",
"panorama",
"panorama-sdxl",
"pix2pix",
]:
return pipeline
elif pipeline == "txt2img-sdxl":
return "img2img-sdxl"
@ -267,7 +282,7 @@ class ImageParams:
if pipeline in ["controlnet", "lpw", "panorama"]:
return pipeline
elif group == "txt2img":
if pipeline in ["lpw", "panorama", "txt2img-sdxl"]:
if pipeline in ["lpw", "panorama", "panorama-sdxl", "txt2img-sdxl"]:
return pipeline
logger.debug("pipeline %s is not valid for %s", pipeline, group)
@ -280,7 +295,7 @@ class ImageParams:
return self.pipeline == "lpw"
def is_panorama(self):
return self.pipeline == "panorama"
return self.pipeline in ["panorama", "panorama-sdxl"]
def is_pix2pix(self):
return self.pipeline == "pix2pix"
@ -305,9 +320,11 @@ class ImageParams:
"input_negative_prompt": self.input_negative_prompt,
"loopback": self.loopback,
"tiled_vae": self.tiled_vae,
"tiles": self.tiles,
"overlap": self.overlap,
"stride": self.stride,
"unet_overlap": self.unet_overlap,
"unet_tile": self.unet_tile,
"vae_overlap": self.vae_overlap,
"vae_tile": self.vae_tile,
"denoise": self.denoise,
}
def with_args(self, **kwargs):
@ -327,9 +344,11 @@ class ImageParams:
kwargs.get("input_negative_prompt", self.input_negative_prompt),
kwargs.get("loopback", self.loopback),
kwargs.get("tiled_vae", self.tiled_vae),
kwargs.get("tiles", self.tiles),
kwargs.get("overlap", self.overlap),
kwargs.get("stride", self.stride),
kwargs.get("unet_overlap", self.unet_overlap),
kwargs.get("unet_tile", self.unet_tile),
kwargs.get("vae_overlap", self.vae_overlap),
kwargs.get("vae_tile", self.vae_tile),
kwargs.get("denoise", self.denoise),
)
@ -351,6 +370,17 @@ class StageParams:
self.tile_order = tile_order
self.tile_size = tile_size
def with_args(
self,
**kwargs,
):
return StageParams(
name=kwargs.get("name", self.name),
outscale=kwargs.get("outscale", self.outscale),
tile_order=kwargs.get("tile_order", self.tile_order),
tile_size=kwargs.get("tile_size", self.tile_size),
)
class UpscaleParams:
def __init__(
@ -459,10 +489,14 @@ class HighresParams:
self.method = method
self.iterations = iterations
def outscale(self) -> int:
return self.scale**self.iterations
def resize(self, size: Size) -> Size:
outscale = self.outscale()
return Size(
size.width * (self.scale**self.iterations),
size.height * (self.scale**self.iterations),
size.width * outscale,
size.height * outscale,
)
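# For example: with scale=2 and iterations=2, outscale() returns 4, so resize(Size(512, 512))
# yields Size(2048, 2048).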
def tojson(self):

View File

@ -1,12 +1,14 @@
from io import BytesIO
from logging import getLogger
from os import path
from typing import Any, Dict
from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image
from ..chain import CHAIN_STAGES, ChainPipeline
from ..chain.result import StageResult
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.run import (
run_blend_pipeline,
@ -17,7 +19,7 @@ from ..diffusers.run import (
)
from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
base_join,
@ -49,10 +51,11 @@ from .load import (
get_wildcard_data,
)
from .params import (
border_from_request,
highres_from_request,
build_border,
build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request,
upscale_from_request,
)
from .utils import wrap_route
@ -167,8 +170,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
size = Size(source.width, source.height)
device, params, _size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys())
)
@ -216,12 +219,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img")
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
replace_wildcards(params, get_wildcard_data())
output = make_output_name(server, "txt2img", params, size)
output = make_output_name(server, "txt2img", params, size, count=params.batch)
job_name = output[0]
pool.submit(
@ -250,7 +253,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
if mask_file is None:
return error_reply("mask image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
size = Size(source.width, source.height)
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
@ -270,9 +273,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
)
device, params, _size = pipeline_from_request(server, "inpaint")
expand = border_from_request()
upscale = upscale_from_request()
highres = highres_from_request()
expand = build_border()
upscale = build_upscale()
highres = build_highres()
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -340,8 +343,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
replace_wildcards(params, get_wildcard_data())
@ -366,47 +369,70 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
# keys that are specially parsed by params and should not show up in with_args
CHAIN_POP_KEYS = ["model", "control"]
def chain(server: ServerContext, pool: DevicePoolExecutor):
if request.is_json:
logger.debug("chain pipeline request with JSON body")
data = request.get_json()
else:
logger.debug(
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply("chain pipeline must have a body")
data = load_config_str(body)
schema = load_config("./schemas/chain.yaml")
schema = load_config("./schemas/chain.yaml")
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
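# Illustrative shape of a chain request body; the stage type key and model name below are
# hypothetical examples, not an exhaustive or authoritative schema:
#
#   {
#     "defaults": {"model": "stable-diffusion-onnx-v1-5", "steps": 25},
#     "stages": [
#       {"name": "generate", "type": "source-txt2img", "params": {"prompt": "a lighthouse at dusk"}}
#     ]
#   }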
# get defaults from the regular parameters
device, params, size = pipeline_from_request(server)
output = make_output_name(server, "chain", params, size)
job_name = output[0]
replace_wildcards(params, get_wildcard_data())
device, base_params, base_size = pipeline_from_json(
server, data=data.get("defaults")
)
# start building the pipeline
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs = stage_data.get("params", {})
kwargs: Dict[str, Any] = stage_data.get("params", {})
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
# TODO: combine base params with stage params
_device, params, size = pipeline_from_json(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())
# remove parsed keys, like model names (which become paths)
for pop_key in CHAIN_POP_KEYS:
if pop_key in kwargs:
kwargs.pop(pop_key)
if "seed" in kwargs and kwargs["seed"] == -1:
kwargs.pop("seed")
# replace kwargs with parsed versions
kwargs["params"] = params
kwargs["size"] = size
border = build_border(kwargs)
kwargs["border"] = border
upscale = build_upscale(kwargs)
kwargs["upscale"] = upscale
# prepare the stage metadata
stage = StageParams(
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")),
tile_size=get_size(kwargs.get("tiles")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
# load any images related to this stage
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
@ -436,20 +462,25 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
output = make_output_name(
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
)
job_name = output[0]
# build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
pool.submit(
job_name,
pipeline,
server,
params,
empty_source,
output=output[0],
size=size,
base_params,
StageResult.empty(),
output=output,
size=base_size,
needs_device=device,
)
return jsonify(json_params(output, params, size))
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
return jsonify(json_params(output, step_params, base_size))
def blend(server: ServerContext, pool: DevicePoolExecutor):
@ -471,7 +502,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
sources.append(source)
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
upscale = build_upscale()
output = make_output_name(server, "upscale", params, size)
job_name = output[0]

View File

@ -5,18 +5,44 @@ from typing import List, Optional
import torch
from ..utils import get_boolean
from ..utils import get_boolean, get_list
from .model_cache import ModelCache
logger = getLogger(__name__)
DEFAULT_ANY_PLATFORM = True
DEFAULT_CACHE_LIMIT = 5
DEFAULT_JOB_LIMIT = 10
DEFAULT_IMAGE_FORMAT = "png"
DEFAULT_SERVER_VERSION = "v0.10.0"
DEFAULT_SHOW_PROGRESS = True
DEFAULT_WORKER_RETRIES = 3
class ServerContext:
bundle_path: str
model_path: str
output_path: str
params_path: str
cors_origin: str
any_platform: bool
block_platforms: List[str]
default_platform: str
image_format: str
cache_limit: int
cache_path: str
show_progress: bool
optimizations: List[str]
extra_models: List[str]
job_limit: int
memory_limit: int
admin_token: str
server_version: str
worker_retries: int
feature_flags: List[str]
plugins: List[str]
debug: bool
def __init__(
self,
bundle_path: str = ".",
@ -24,19 +50,23 @@ class ServerContext:
output_path: str = ".",
params_path: str = ".",
cors_origin: str = "*",
any_platform: bool = True,
any_platform: bool = DEFAULT_ANY_PLATFORM,
block_platforms: Optional[List[str]] = None,
default_platform: Optional[str] = None,
image_format: str = DEFAULT_IMAGE_FORMAT,
cache_limit: int = DEFAULT_CACHE_LIMIT,
cache_path: Optional[str] = None,
show_progress: bool = True,
show_progress: bool = DEFAULT_SHOW_PROGRESS,
optimizations: Optional[List[str]] = None,
extra_models: Optional[List[str]] = None,
job_limit: int = DEFAULT_JOB_LIMIT,
memory_limit: Optional[int] = None,
admin_token: Optional[str] = None,
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
feature_flags: Optional[List[str]] = None,
plugins: Optional[List[str]] = None,
debug: bool = False,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@ -56,6 +86,10 @@ class ServerContext:
self.memory_limit = memory_limit
self.admin_token = admin_token or token_urlsafe()
self.server_version = server_version
self.worker_retries = worker_retries
self.feature_flags = feature_flags or []
self.plugins = plugins or []
self.debug = debug
self.cache = ModelCache(self.cache_limit)
@ -72,26 +106,41 @@ class ServerContext:
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
# others
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
any_platform=get_boolean(
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
),
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
show_progress=get_boolean(
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
),
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
memory_limit=memory_limit,
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
server_version=environ.get(
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
),
worker_retries=int(
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
),
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
)
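# Example environment configuration (values are illustrative), matching the variables read above:
#
#   ONNX_WEB_MODEL_PATH=/data/models
#   ONNX_WEB_FEATURE_FLAGS=panorama-highres
#   ONNX_WEB_OPTIMIZATIONS=torch-fp16
#   ONNX_WEB_DEBUG=true
#
# Feature flags are checked with has_feature() and optimizations with has_optimization(); the
# "panorama-highres" flag shown here is the one that enables larger highres tiles for panoramas.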
def has_feature(self, flag: str) -> bool:
return flag in self.feature_flags
def has_optimization(self, opt: str) -> bool:
return opt in self.optimizations
def torch_dtype(self):
if "torch-fp16" in self.optimizations:
if self.has_optimization("torch-fp16"):
return torch.float16
else:
return torch.float32

View File

@ -134,25 +134,44 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
def apply_patch_basicsr(server: ServerContext):
logger.debug("patching BasicSR module")
try:
import basicsr.utils.download_util
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
basicsr.utils.download_util.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError:
logger.info("unable to import basicsr utils for patching")
except AttributeError:
logger.warning("unable to patch basicsr utils")
def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
try:
import codeformer.facelib.utils.misc
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
codeformer.facelib.utils.misc.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError:
logger.info("unable to import codeformer utils for patching")
except AttributeError:
logger.warning("unable to patch codeformer utils")
def apply_patch_facexlib(server: ServerContext):
logger.debug("patching Facexlib module")
try:
import facexlib.utils
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
except ImportError:
logger.info("unable to import facexlib for patching")
except AttributeError:
logger.warning("unable to patch facexlib utils")
def apply_patches(server: ServerContext):

View File

@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {}
extra_tokens: Dict[str, List[str]] = {}
def get_config_params():
@ -160,9 +161,10 @@ def load_extras(server: ServerContext):
"""
global extra_hashes
global extra_strings
global extra_tokens
labels = {}
strings = {}
labels: Dict[str, str] = {}
strings: Dict[str, Any] = {}
extra_schema = load_config("./schemas/extras.yaml")
@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
else:
labels[model_name] = model["label"]
if "tokens" in model:
logger.debug(
"collecting tokens for model %s from %s",
model_name,
file,
)
extra_tokens[model_name] = model["tokens"]
if "inversions" in model:
for inversion in model["inversions"]:
if "label" in inversion:
@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
)
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
network_models.extend(
[NetworkModel(model, "inversion") for model in inversion_models]
[
NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
for model in inversion_models
]
)
lora_models = list_model_globs(
@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
base_path=path.join(server.model_path, "lora"),
)
logger.debug("loaded LoRA models from disk: %s", lora_models)
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
network_models.extend(
[
NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
for model in lora_models
]
)
def load_params(server: ServerContext) -> None:
@ -397,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
):
if potential == "cuda" or potential == "rocm":
for i in range(torch.cuda.device_count()):
options = {
options: Dict[str, Union[int, str]] = {
"device_id": i,
}

View File

@ -51,7 +51,7 @@ class ModelCache:
return
for i in range(len(cache)):
t, k, v = cache[i]
t, k, _v = cache[i]
if tag == t and key != k:
logger.debug("updating model cache: %s %s", tag, key)
cache[i] = (tag, key, value)

View File

@ -1,10 +1,10 @@
from logging import getLogger
from typing import Tuple
from typing import Dict, Optional, Tuple
import numpy as np
from flask import request
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.utils import random_seed
from ..params import (
Border,
DeviceParams,
@ -34,143 +34,122 @@ from .utils import get_model_path
logger = getLogger(__name__)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
def build_device(
_server: ServerContext,
data: Dict[str, str],
) -> Optional[DeviceParams]:
# platform stuff
device = None
device_name = request.args.get("platform")
device_name = data.get("platform")
if device_name is not None and device_name != "any":
for platform in get_available_platforms():
if platform.device == device_name:
device = platform
return device
def build_params(
server: ServerContext,
default_pipeline: str,
data: Dict[str, str],
) -> ImageParams:
# diffusion model
model = get_not_empty(request.args, "model", get_config_value("model"))
model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model)
control = None
control_name = request.args.get("control")
control_name = data.get("control")
for network in get_network_models():
if network.name == control_name:
control = network
# pipeline stuff
pipeline = get_from_list(
request.args, "pipeline", get_available_pipelines(), default_pipeline
data, "pipeline", get_available_pipelines(), default_pipeline
)
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
if scheduler is None:
scheduler = get_config_value("scheduler")
# prompt does not come from config
prompt = request.args.get("prompt", "")
negative_prompt = request.args.get("negativePrompt", None)
prompt = data.get("prompt", "")
negative_prompt = data.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
# image params
batch = get_and_clamp_int(
request.args,
data,
"batch",
get_config_value("batch"),
get_config_value("batch", "max"),
get_config_value("batch", "min"),
)
cfg = get_and_clamp_float(
request.args,
data,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
request.args,
data,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
loopback = get_and_clamp_int(
request.args,
data,
"loopback",
get_config_value("loopback"),
get_config_value("loopback", "max"),
get_config_value("loopback", "min"),
)
steps = get_and_clamp_int(
request.args,
data,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
request.args,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae"))
unet_overlap = get_and_clamp_float(
data,
"unet_overlap",
get_config_value("unet_overlap"),
get_config_value("unet_overlap", "max"),
get_config_value("unet_overlap", "min"),
)
width = get_and_clamp_int(
request.args,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
unet_tile = get_and_clamp_int(
data,
"unet_tile",
get_config_value("unet_tile"),
get_config_value("unet_tile", "max"),
get_config_value("unet_tile", "min"),
)
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int(
request.args,
"tiles",
get_config_value("tiles"),
get_config_value("tiles", "max"),
get_config_value("tiles", "min"),
vae_overlap = get_and_clamp_float(
data,
"vae_overlap",
get_config_value("vae_overlap"),
get_config_value("vae_overlap", "max"),
get_config_value("vae_overlap", "min"),
)
overlap = get_and_clamp_float(
request.args,
"overlap",
get_config_value("overlap"),
get_config_value("overlap", "max"),
get_config_value("overlap", "min"),
)
stride = get_and_clamp_int(
request.args,
"stride",
get_config_value("stride"),
get_config_value("stride", "max"),
get_config_value("stride", "min"),
vae_tile = get_and_clamp_int(
data,
"vae_tile",
get_config_value("vae_tile"),
get_config_value("vae_tile", "max"),
get_config_value("vae_tile", "min"),
)
if stride > tiles:
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
stride = tiles
seed = int(request.args.get("seed", -1))
seed = int(data.get("seed", -1))
if seed == -1:
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
seed = random_seed()
params = ImageParams(
model_path,
@ -186,38 +165,65 @@ def pipeline_from_request(
control=control,
loopback=loopback,
tiled_vae=tiled_vae,
tiles=tiles,
overlap=overlap,
stride=stride,
unet_overlap=unet_overlap,
unet_tile=unet_tile,
vae_overlap=vae_overlap,
vae_tile=vae_tile,
)
size = Size(width, height)
return (device, params, size)
return params
def border_from_request() -> Border:
def build_size(
_server: ServerContext,
data: Dict[str, str],
) -> Size:
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
return Size(width, height)
def build_border(
data: Dict[str, str] = None,
) -> Border:
if data is None:
data = request.args
left = get_and_clamp_int(
request.args,
data,
"left",
get_config_value("left"),
get_config_value("left", "max"),
get_config_value("left", "min"),
)
right = get_and_clamp_int(
request.args,
data,
"right",
get_config_value("right"),
get_config_value("right", "max"),
get_config_value("right", "min"),
)
top = get_and_clamp_int(
request.args,
data,
"top",
get_config_value("top"),
get_config_value("top", "max"),
get_config_value("top", "min"),
)
bottom = get_and_clamp_int(
request.args,
data,
"bottom",
get_config_value("bottom"),
get_config_value("bottom", "max"),
@ -227,46 +233,51 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
def build_upscale(
data: Dict[str, str] = None,
) -> UpscaleParams:
if data is None:
data = request.args
denoise = get_and_clamp_float(
request.args,
data,
"denoise",
get_config_value("denoise"),
get_config_value("denoise", "max"),
get_config_value("denoise", "min"),
)
scale = get_and_clamp_int(
request.args,
data,
"scale",
get_config_value("scale"),
get_config_value("scale", "max"),
get_config_value("scale", "min"),
)
outscale = get_and_clamp_int(
request.args,
data,
"outscale",
get_config_value("outscale"),
get_config_value("outscale", "max"),
get_config_value("outscale", "min"),
)
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
correction = get_from_list(request.args, "correction", get_correction_models())
faces = get_not_empty(request.args, "faces", "false") == "true"
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
correction = get_from_list(data, "correction", get_correction_models())
faces = get_not_empty(data, "faces", "false") == "true"
face_outscale = get_and_clamp_int(
request.args,
data,
"faceOutscale",
get_config_value("faceOutscale"),
get_config_value("faceOutscale", "max"),
get_config_value("faceOutscale", "min"),
)
face_strength = get_and_clamp_float(
request.args,
data,
"faceStrength",
get_config_value("faceStrength"),
get_config_value("faceStrength", "max"),
get_config_value("faceStrength", "min"),
)
upscale_order = request.args.get("upscaleOrder", "correction-first")
upscale_order = data.get("upscaleOrder", "correction-first")
return UpscaleParams(
upscaling,
@ -282,37 +293,43 @@ def upscale_from_request() -> UpscaleParams:
)
def highres_from_request() -> HighresParams:
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
def build_highres(
data: Dict[str, str] = None,
) -> HighresParams:
if data is None:
data = request.args
enabled = get_boolean(data, "highres", get_config_value("highres"))
iterations = get_and_clamp_int(
request.args,
data,
"highresIterations",
get_config_value("highresIterations"),
get_config_value("highresIterations", "max"),
get_config_value("highresIterations", "min"),
)
method = get_from_list(request.args, "highresMethod", get_highres_methods())
method = get_from_list(data, "highresMethod", get_highres_methods())
scale = get_and_clamp_int(
request.args,
data,
"highresScale",
get_config_value("highresScale"),
get_config_value("highresScale", "max"),
get_config_value("highresScale", "min"),
)
steps = get_and_clamp_int(
request.args,
data,
"highresSteps",
get_config_value("highresSteps"),
get_config_value("highresSteps", "max"),
get_config_value("highresSteps", "min"),
)
strength = get_and_clamp_float(
request.args,
data,
"highresStrength",
get_config_value("highresStrength"),
get_config_value("highresStrength", "max"),
get_config_value("highresStrength", "min"),
)
return HighresParams(
enabled,
scale,
@ -321,3 +338,50 @@ def highres_from_request() -> HighresParams:
method=method,
iterations=iterations,
)
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
def pipeline_from_json(
server: ServerContext,
data: Dict[str, str],
default_pipeline: str = "txt2img",
) -> PipelineParams:
"""
Like pipeline_from_request but expects a nested structure.
"""
device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data))
return (device, params, size)
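As a minimal sketch, a nested body accepted by pipeline_from_json might look like the following; the field names under "device" are assumptions (build_device is not shown here), while the keys under "params" follow the query parameters handled above, and the paths given to ServerContext are placeholders:
from onnx_web.server.context import ServerContext
server = ServerContext(model_path="../models", output_path="../outputs")
body = {
    "device": {"platform": "any"},  # device fields are assumptions
    "params": {
        "prompt": "a giant muffin",
        "steps": 25,
        "cfg": 6.0,
        "width": 512,
        "height": 512,
        "seed": 0,
    },
}
device, params, size = pipeline_from_json(server, body, default_pipeline="txt2img")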
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> PipelineParams:
user = request.remote_addr
device = build_device(server, request.args)
params = build_params(server, default_pipeline, request.args)
size = build_size(server, request.args)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
params.steps,
params.scheduler,
params.model,
params.pipeline,
device or "any device",
size.width,
size.height,
params.cfg,
params.seed,
params.prompt,
)
return (device, params, size)

View File

@ -0,0 +1,61 @@
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict
from onnx_web.chain.stages import add_stage
from onnx_web.diffusers.load import add_pipeline
from onnx_web.server.context import ServerContext
logger = getLogger(__name__)
class PluginExports:
pipelines: Dict[str, Any]
stages: Dict[str, Any]
def __init__(self, pipelines=None, stages=None) -> None:
self.pipelines = pipelines or {}
self.stages = stages or {}
PluginModule = Callable[[ServerContext], PluginExports]
def load_plugins(server: ServerContext) -> PluginExports:
combined_exports = PluginExports()
for plugin in server.plugins:
logger.info("loading plugin module: %s", plugin)
try:
module: PluginModule = import_module(plugin)
exports = module(server)
for name, pipeline in exports.pipelines.items():
if name in combined_exports.pipelines:
logger.warning(
"multiple plugins exported a pipeline named %s", name
)
else:
combined_exports.pipelines[name] = pipeline
for name, stage in exports.stages.items():
if name in combined_exports.stages:
logger.warning("multiple plugins exported a stage named %s", name)
else:
combined_exports.stages[name] = stage
except Exception:
logger.exception("error importing plugin")
return combined_exports
def register_plugins(exports: PluginExports) -> bool:
success = True
for name, pipeline in exports.pipelines.items():
success = success and add_pipeline(name, pipeline)
for name, stage in exports.stages.items():
success = success and add_stage(name, stage)
return success
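To make the contract concrete, a plugin module might look like the following sketch. The import path for PluginExports and the entry-point convention are assumptions: the loader above imports the module named in server.plugins and expects a PluginExports back, but the exact callable it invokes is not pinned down here.
# my_plugin.py - a hypothetical plugin (all names here are placeholders)
from onnx_web.server.plugin import PluginExports  # import path is an assumption
from onnx_web.server.context import ServerContext
class EchoStage:
    # placeholder stage that returns its sources unchanged; real stages follow
    # the chain stage interface, which is not reproduced here
    def run(self, _worker, _server, _stage, _params, sources, **kwargs):
        return sources
def make_exports(server: ServerContext) -> PluginExports:
    # contribute one hypothetical stage and no pipelines
    return PluginExports(pipelines={}, stages={"echo": EchoStage})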

View File

@ -18,6 +18,11 @@ logger = getLogger(__name__)
SAFE_CHARS = "._-"
def split_list(val: str) -> List[str]:
parts = [part.strip() for part in val.split(",")]
return [part for part in parts if len(part) > 0]
def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
return path.join(base, tail_path)
@ -28,7 +33,16 @@ def is_debug() -> bool:
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
return args.get(key, str(default_value)).lower() in ("1", "t", "true", "y", "yes")
val = args.get(key, str(default_value))
if isinstance(val, bool):
return val
return val.lower() in ("1", "t", "true", "y", "yes")
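The isinstance guard covers the case where the args mapping comes from a parsed JSON body rather than query-string values; a quick sketch of both call paths, assuming get_boolean lives in onnx_web.utils:
from onnx_web.utils import get_boolean  # import path is an assumption
get_boolean({"tiled_vae": True}, "tiled_vae", False)   # JSON body: value is already a bool -> True
get_boolean({"tiled_vae": "yes"}, "tiled_vae", False)  # query args: string form is parsed -> True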
def get_list(args: Any, key: str, default="") -> List[str]:
return split_list(args.get(key, default))
def get_and_clamp_float(
@ -61,13 +75,13 @@ def get_from_list(
def get_from_map(
args: Any, key: str, values: Dict[str, TElem], default: TElem
args: Any, key: str, values: Dict[str, TElem], default_key: str
) -> TElem:
selected = args.get(key, default)
selected = args.get(key, default_key)
if selected in values:
return values[selected]
else:
return values[default]
return values[default_key]
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
@ -195,6 +209,8 @@ def load_config(file: str) -> Dict:
return load_yaml(file)
elif ext in [".json"]:
return load_json(file)
else:
raise ValueError("unknown config file extension")
def load_config_str(raw: str) -> Dict:

View File

@ -25,6 +25,7 @@ class WorkerContext:
idle: "Value[bool]"
timeout: float
retries: int
initial_retries: int
def __init__(
self,
@ -36,6 +37,8 @@ class WorkerContext:
progress: "Queue[ProgressCommand]",
active_pid: "Value[int]",
idle: "Value[bool]",
retries: int,
timeout: float,
):
self.job = None
self.name = name
@ -47,12 +50,13 @@ class WorkerContext:
self.active_pid = active_pid
self.last_progress = None
self.idle = idle
self.timeout = 1.0
self.retries = 3 # TODO: get from env
self.initial_retries = retries
self.retries = retries
self.timeout = timeout
def start(self, job: str) -> None:
self.job = job
self.retries = 3
self.retries = self.initial_retries
self.set_cancel(cancel=False)
self.set_idle(idle=False)
@ -82,7 +86,7 @@ class WorkerContext:
return 0
def get_progress_callback(self) -> ProgressCallback:
from ..chain.base import ChainProgress
from ..chain.pipeline import ChainProgress
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step

View File

@ -86,15 +86,15 @@ class DevicePoolExecutor:
self.logs = Queue(self.max_pending_per_worker)
self.rlock = Lock()
def start(self) -> None:
def start(self, *args) -> None:
self.create_health_worker()
self.create_logger_worker()
self.create_progress_worker()
for device in self.devices:
self.create_device_worker(device)
self.create_device_worker(device, *args)
def create_device_worker(self, device: DeviceParams) -> None:
def create_device_worker(self, device: DeviceParams, *args) -> None:
name = device.device
# always recreate queues
@ -124,15 +124,17 @@ class DevicePoolExecutor:
pending=self.pending[name],
active_pid=current,
idle=self.worker_idle[name],
retries=self.server.worker_retries,
timeout=self.progress_interval,
)
self.context[name] = context
worker = Process(
name=f"onnx-web worker: {name}",
target=worker_main,
args=(context, self.server),
args=(context, self.server, *args),
daemon=True,
)
worker.daemon = True
self.workers[name] = worker
logger.debug("starting worker for device %s", device)

View File

@ -27,10 +27,14 @@ MEMORY_ERRORS = [
]
def worker_main(worker: WorkerContext, server: ServerContext):
apply_patches(server)
def worker_main(
worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
):
setproctitle("onnx-web worker: %s" % (worker.device.device))
if patch:
apply_patches(server)
logger.trace(
"checking in from worker with providers: %s", get_available_providers()
)
@ -46,7 +50,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
getpid(),
worker.get_active(),
)
exit(EXIT_REPLACED)
return exit(EXIT_REPLACED)
# wait briefly for the next job
job = worker.pending.get(timeout=worker.timeout)
@ -69,15 +73,15 @@ def worker_main(worker: WorkerContext, server: ServerContext):
except KeyboardInterrupt:
logger.debug("worker got keyboard interrupt")
worker.fail()
exit(EXIT_INTERRUPT)
return exit(EXIT_INTERRUPT)
except RetryException:
logger.exception("retry error in worker, exiting")
worker.fail()
exit(EXIT_ERROR)
return exit(EXIT_ERROR)
except ValueError:
logger.exception("value error in worker, exiting")
worker.fail()
exit(EXIT_ERROR)
return exit(EXIT_ERROR)
except Exception as e:
e_str = str(e)
# restart the worker on memory errors
@ -85,7 +89,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
if e_mem in e_str:
logger.error("detected out-of-memory error, exiting: %s", e)
worker.fail()
exit(EXIT_MEMORY)
return exit(EXIT_MEMORY)
# carry on for other errors
logger.exception(

View File

@ -98,7 +98,7 @@
"highresSteps": {
"default": 0,
"min": 1,
"max": 200,
"max": 500,
"step": 1
},
"highresStrength": {
@ -141,12 +141,6 @@
"max": 4,
"step": 1
},
"overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"pipeline": {
"default": "",
"keys": [
@ -188,7 +182,7 @@
"steps": {
"default": 25,
"min": 1,
"max": 200,
"max": 300,
"step": 1
},
"strength": {
@ -197,21 +191,9 @@
"max": 1,
"step": 0.01
},
"stride": {
"default": 128,
"min": 64,
"max": 512,
"step": 64
},
"tiledVAE": {
"tiled_vae": {
"default": false
},
"tiles": {
"default": 512,
"min": 128,
"max": 2048,
"step": 128
},
"tileOrder": {
"default": "spiral",
"keys": [
@ -225,6 +207,18 @@
"max": 1024,
"step": 8
},
"unet_overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"unet_tile": {
"default": 512,
"min": 128,
"max": 2048,
"step": 128
},
"upscaleOrder": {
"default": "correction-first",
"keys": [
@ -237,6 +231,18 @@
"default": "",
"keys": []
},
"vae_overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"vae_tile": {
"default": 512,
"min": 256,
"max": 1024,
"step": 128
},
"width": {
"default": 512,
"min": 128,

View File

@ -9,7 +9,9 @@ skip_glob = ["*/lpw.py"]
[tool.mypy]
# ignore_missing_imports = true
exclude = [
"onnx_web.diffusers.lpw_stable_diffusion_onnx"
"onnx_web.diffusers.pipelines.controlnet",
"onnx_web.diffusers.pipelines.lpw",
"onnx_web.diffusers.pipelines.pix2pix"
]
[[tool.mypy.overrides]]
@ -27,8 +29,10 @@ module = [
"compel",
"controlnet_aux",
"cv2",
"debugpy",
"diffusers",
"diffusers.configuration_utils",
"diffusers.image_processor",
"diffusers.loaders",
"diffusers.models.attention_processor",
"diffusers.models.autoencoder_kl",
@ -41,9 +45,10 @@ module = [
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
"diffusers.pipelines.onnx_utils",
"diffusers.pipelines.paint_by_example",
"diffusers.pipelines.pipeline_utils",
"diffusers.pipelines.stable_diffusion",
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
"diffusers.pipeline_utils",
"diffusers.pipelines.stable_diffusion_xl",
"diffusers.schedulers",
"diffusers.utils.logging",
"facexlib.utils",
@ -56,11 +61,17 @@ module = [
"mediapipe",
"onnxruntime",
"onnxruntime.transformers.float16",
"optimum.exporters.onnx",
"optimum.onnxruntime",
"optimum.onnxruntime.modeling_diffusion",
"optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img",
"optimum.pipelines.diffusers.pipeline_utils",
"piexif",
"piexif.helper",
"realesrgan",
"realesrgan.archs.srvgg_arch",
"safetensors",
"scipy",
"timm.models.layers",
"transformers",
"win10toast"

View File

@ -46,17 +46,31 @@ $defs:
patternProperties:
"^[-_A-Za-z]+$":
oneOf:
- type: boolean
- type: number
- type: string
- type: "null"
request_chain:
type: array
items:
$ref: "#/$defs/request_stage"
request_defaults:
type: object
properties:
txt2img:
$ref: "#/$defs/image_params"
img2img:
$ref: "#/$defs/image_params"
type: object
additionalProperties: False
required: [stages]
properties:
defaults:
$ref: "#/$defs/request_defaults"
platform:
type: string
stages:
$ref: "#/$defs/request_chain"

View File

@ -10,34 +10,53 @@ $defs:
- type: number
- type: string
lora_network:
tensor_format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
embedding_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
$ref: "#/$defs/tensor_format"
label:
type: string
weight:
type: number
textual_inversion_network:
type: object
required: [name, source]
properties:
name:
type: string
source:
type: string
format:
model:
type: string
enum: [concept, embeddings]
label:
name:
type: string
source:
type: string
token:
type: string
type:
type: string
const: inversion # TODO: add embedding
weight:
type: number
lora_network:
type: object
required: [name, source, type]
properties:
label:
type: string
model:
type: string
enum: [cloneofsimo, sd-scripts]
name:
type: string
source:
type: string
tokens:
type: array
items:
type: string
type:
type: string
const: lora
weight:
type: number
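Taken together, the split embedding_network and lora_network definitions describe extras entries like the following sketch, shown as a Python literal since extras files may be JSON as well as YAML; the names and URLs are placeholders:
networks = [
    {
        "name": "my-style",
        "source": "https://example.com/my-style.safetensors",
        "type": "lora",
        "model": "sd-scripts",
        "weight": 0.8,
    },
    {
        "name": "my-concept",
        "source": "https://example.com/my-concept.bin",
        "type": "inversion",
        "model": "concept",
    },
]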
@ -46,8 +65,7 @@ $defs:
required: [name, source]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
$ref: "#/$defs/tensor_format"
half:
type: boolean
label:
@ -85,7 +103,7 @@ $defs:
inversions:
type: array
items:
$ref: "#/$defs/textual_inversion_network"
$ref: "#/$defs/embedding_network"
loras:
type: array
items:
@ -100,6 +118,7 @@ $defs:
panorama,
pix2pix,
txt2img,
txt2img-sdxl,
upscale,
]
vae:
@ -141,31 +160,6 @@ $defs:
source:
type: string
source_network:
type: object
required: [name, source, type]
properties:
format:
type: string
enum: [bin, ckpt, onnx, pt, pth, safetensors]
model:
type: string
enum: [
# inversion
concept,
embeddings,
# lora
cloneofsimo,
sd-scripts
]
name:
type: string
source:
type: string
type:
type: string
enum: [inversion, lora]
translation:
type: object
additionalProperties: False
@ -193,7 +187,9 @@ properties:
networks:
type: array
items:
$ref: "#/$defs/source_network"
oneOf:
- $ref: "#/$defs/lora_network"
- $ref: "#/$defs/embedding_network"
sources:
type: array
items:

api/scripts/onnx-lora.py Normal file
View File

@ -0,0 +1,74 @@
from argparse import ArgumentParser
from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from os import path
from onnx.checker import check_model
from onnx.external_data_helper import (
convert_model_to_external_data,
write_external_data_tensors,
)
from onnxruntime import InferenceSession, SessionOptions
from logging import getLogger
from onnx_web.convert.utils import ConversionContext
logger = getLogger(__name__)
if __name__ == "__main__":
context = ConversionContext.from_environ()
parser = ArgumentParser()
parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args()
logger.info(
"merging %s with %s with weights: %s",
args.lora_models,
args.base,
args.lora_weights,
)
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras(
context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
)
if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(
bare_model.SerializeToString(),
sess_options=opts,
providers=["CPUExecutionProvider"],
)
logger.info(
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
)
else:
convert_model_to_external_data(
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
)
bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")

Binary files not shown.

View File

@ -30,6 +30,10 @@ FAST_TEST = 10
SLOW_TEST = 25
VERY_SLOW_TEST = 75
STRICT_TEST = 1e-4
LOOSE_TEST = 1e-2
VERY_LOOSE_TEST = 0.025
def test_path(relpath: str) -> str:
return path.join(path.dirname(__file__), relpath)
@ -41,7 +45,7 @@ class TestCase:
name: str,
query: str,
max_attempts: int = FAST_TEST,
mse_threshold: float = 1e-4,
mse_threshold: float = STRICT_TEST,
source: Union[Image.Image, List[Image.Image]] = None,
mask: Image.Image = None,
) -> None:
@ -65,6 +69,7 @@ TEST_DATA = [
TestCase(
"txt2img-sd-v1-5-512-muffin-deis",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
mse_threshold=LOOSE_TEST,
),
TestCase(
"txt2img-sd-v1-5-512-muffin-dpm",
@ -73,10 +78,12 @@ TEST_DATA = [
TestCase(
"txt2img-sd-v1-5-512-muffin-heun",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
mse_threshold=LOOSE_TEST,
),
TestCase(
"txt2img-sd-v1-5-512-muffin-unipc",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi",
mse_threshold=LOOSE_TEST,
),
TestCase(
"txt2img-sd-v2-1-512-muffin",
@ -84,7 +91,7 @@ TEST_DATA = [
),
TestCase(
"txt2img-sd-v2-1-768-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768",
max_attempts=SLOW_TEST,
),
TestCase(
@ -106,7 +113,7 @@ TEST_DATA = [
),
TestCase(
"img2img-sd-v1-5-256-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256",
source="txt2img-sd-v1-5-256-muffin-0",
),
TestCase(
@ -130,7 +137,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.025,
mse_threshold=VERY_LOOSE_TEST,
),
TestCase(
"outpaint-vertical-512",
@ -141,7 +148,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010,
mse_threshold=LOOSE_TEST,
),
TestCase(
"outpaint-horizontal-512",
@ -152,7 +159,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=SLOW_TEST,
mse_threshold=0.010,
mse_threshold=LOOSE_TEST,
),
TestCase(
"upscale-resrgan-x2-1024-muffin",
@ -229,7 +236,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=VERY_SLOW_TEST,
mse_threshold=0.025,
mse_threshold=VERY_LOOSE_TEST,
),
TestCase(
"outpaint-panorama-vertical-512",
@ -240,7 +247,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=VERY_SLOW_TEST,
mse_threshold=0.025,
mse_threshold=VERY_LOOSE_TEST,
),
TestCase(
"outpaint-panorama-horizontal-512",
@ -251,7 +258,7 @@ TEST_DATA = [
source="txt2img-sd-v1-5-512-muffin-0",
mask="mask-black",
max_attempts=VERY_SLOW_TEST,
mse_threshold=0.025,
mse_threshold=VERY_LOOSE_TEST,
),
TestCase(
"upscale-resrgan-x4-codeformer-2048-muffin",
@ -260,6 +267,7 @@ TEST_DATA = [
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
),
source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
),
TestCase(
"upscale-resrgan-x4-gfpgan-2048-muffin",
@ -268,6 +276,7 @@ TEST_DATA = [
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
),
source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
),
TestCase(
"upscale-swinir-x4-codeformer-2048-muffin",
@ -276,6 +285,7 @@ TEST_DATA = [
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
),
source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
),
TestCase(
"upscale-swinir-x4-gfpgan-2048-muffin",
@ -284,6 +294,7 @@ TEST_DATA = [
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
),
source="txt2img-sd-v1-5-512-muffin-0",
max_attempts=SLOW_TEST,
),
TestCase(
"upscale-sd-x4-codeformer-2048-muffin",
@ -305,18 +316,18 @@ TEST_DATA = [
),
TestCase(
"txt2img-panorama-1024x768-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true",
max_attempts=VERY_SLOW_TEST,
),
TestCase(
"img2img-panorama-1024x768-pumpkin",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true",
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true",
source="txt2img-panorama-1024x768-muffin-0",
max_attempts=VERY_SLOW_TEST,
),
TestCase(
"txt2img-sd-v1-5-tall-muffin",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768",
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768",
),
TestCase(
"upscale-resrgan-x4-tall-muffin",
@ -325,6 +336,7 @@ TEST_DATA = [
"&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0"
),
source="txt2img-sd-v1-5-tall-muffin-0",
max_attempts=SLOW_TEST,
),
# TODO: non-square controlnet
]
@ -335,6 +347,39 @@ class TestError(Exception):
return super().__str__()
class TestResult:
error: Optional[str]
mse: Optional[float]
name: str
passed: bool
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
self.error = error
self.mse = mse
self.name = name
self.passed = passed
def __repr__(self) -> str:
if self.passed:
if self.mse is not None:
return f"{self.name} ({self.mse})"
else:
return self.name
else:
if self.mse is not None:
return f"{self.name}: {self.error} ({self.mse})"
else:
return f"{self.name}: {self.error}"
@classmethod
def passed(self, name: str, mse = None):
return TestResult(name, mse=mse)
@classmethod
def failed(self, name: str, error: str, mse = None):
return TestResult(name, error=error, mse=mse, passed=False)
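For reference, the two helper constructors read like this; the names and values are illustrative only:
ok = TestResult.passed("txt2img-sd-v1-5-512-muffin", mse=2.1e-4)
bad = TestResult.failed("img2img-sd-v1-5-256-pumpkin", "MSE above threshold", mse=0.02)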
def parse_args(args: List[str]):
parser = ArgumentParser(
prog="onnx-web release tests",
@ -441,14 +486,14 @@ def run_test(
host: str,
test: TestCase,
mse_mult: float = 1.0,
) -> bool:
) -> TestResult:
"""
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
"""
keys = generate_images(host, test)
if keys is None:
raise ValueError("could not generate image")
return TestResult.failed(test.name, "could not generate image")
ready = False
for attempt in tqdm(range(test.max_attempts)):
@ -461,13 +506,13 @@ def run_test(
sleep(6)
if not ready:
raise ValueError("image was not ready in time")
return TestResult.failed(test.name, "image was not ready in time")
results = download_images(host, keys)
if results is None:
raise ValueError("could not download image")
if results is None or len(results) == 0:
return TestResult.failed(test.name, "could not download image")
passed = True
passed = False
for i in range(len(results)):
result = results[i]
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
@ -476,14 +521,19 @@ def run_test(
ref = Image.open(ref_name) if path.exists(ref_name) else None
mse = find_mse(result, ref)
threshold = test.mse_threshold * mse_mult
if mse < (test.mse_threshold * mse_mult):
logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold)
if mse < threshold:
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
passed = True
else:
logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold)
passed = False
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
return passed
if passed:
return TestResult.passed(test.name)
else:
return TestResult.failed(test.name, "no images tested")
def main():
@ -504,24 +554,26 @@ def main():
passed = []
failed = []
for test in tests:
test_passed = False
result = None
for _i in range(3):
try:
logger.info("starting test: %s", test.name)
if run_test(args.host, test, mse_mult=args.mse):
result = run_test(args.host, test, mse_mult=args.mse)
if result.passed:
logger.info("test passed: %s", test.name)
test_passed = True
break
else:
logger.warning("test failed: %s", test.name)
except Exception:
logger.exception("error running test for %s", test.name)
result = TestResult.failed(test.name, "TODO: exception message")
if test_passed:
passed.append(test.name)
if result is not None:
if result.passed:
passed.append(result)
else:
failed.append(test.name)
failed.append(result)
logger.info("%s of %s tests passed", len(passed), len(tests))
failed = list(set(failed))

View File

@ -0,0 +1,26 @@
import unittest
from onnx_web.chain.pipeline import ChainProgress
class ChainProgressTests(unittest.TestCase):
def test_accumulate_with_reset(self):
def parent(step, timestep, latents):
pass
progress = ChainProgress(parent)
progress(5, 1, None)
progress(0, 1, None)
progress(5, 1, None)
self.assertEqual(progress.get_total(), 10)
def test_start_value(self):
def parent(step, timestep, latents):
pass
progress = ChainProgress(parent, 5)
self.assertEqual(progress.get_total(), 5)
progress(10, 1, None)
self.assertEqual(progress.get_total(), 10)

View File

@ -0,0 +1,23 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.result import StageResult
class BlendGridStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendGridStage()
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
]
)
result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))

View File

@ -0,0 +1,47 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.chain.result import StageResult
from onnx_web.params import ImageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
class BlendImg2ImgStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_stage(self):
stage = BlendImg2ImgStage()
params = ImageParams(
TEST_MODEL_DIFFUSION_SD15,
"txt2img",
"euler-a",
"an astronaut eating a hamburger",
3.0,
1,
1,
)
server = ServerContext(model_path="../models", output_path="../outputs")
worker = WorkerContext(
"test",
test_device(),
None,
None,
None,
None,
None,
None,
0,
0.1,
)
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
)
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))

View File

@ -0,0 +1,23 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage
from onnx_web.chain.result import StageResult
class BlendLinearStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendLinearStage()
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
)
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
)
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -0,0 +1,25 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_mask import BlendMaskStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
class BlendMaskStageTests(unittest.TestCase):
def test_empty(self):
stage = BlendMaskStage()
sources = StageResult.empty()
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
stage_mask=Image.new("RGBA", (64, 64)),
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,46 @@
import unittest
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
from tests.helpers import (
TEST_MODEL_CORRECTION_CODEFORMER,
test_device,
test_needs_models,
)
class CorrectCodeformerStageTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER])
def test_empty(self):
server = ServerContext(model_path="../models", output_path="../outputs")
apply_patches(server)
worker = WorkerContext(
"test",
test_device(),
None,
None,
None,
None,
None,
None,
0,
0.1,
)
stage = CorrectCodeformerStage()
sources = StageResult.empty()
result = stage.run(
worker,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,44 @@
import unittest
from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models
TEST_MODEL = "../models/correction-gfpgan-v1-3"
class CorrectGFPGANStageTests(unittest.TestCase):
@test_needs_onnx_models([TEST_MODEL])
def test_empty(self):
server = ServerContext(model_path="../models", output_path="../outputs")
apply_patches(server)
worker = WorkerContext(
"test",
test_device(),
None,
None,
None,
None,
None,
None,
0,
0.1,
)
stage = CorrectGFPGANStage()
sources = StageResult.empty()
result = stage.run(
worker,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL),
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,24 @@
import unittest
from onnx_web.chain.reduce_crop import ReduceCropStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceCropStageTests(unittest.TestCase):
def test_empty(self):
stage = ReduceCropStage()
sources = StageResult.empty()
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
)
self.assertEqual(len(result), 0)

View File

@ -0,0 +1,28 @@
import unittest
from PIL import Image
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceThumbnailStageTests(unittest.TestCase):
def test_empty(self):
stage_source = Image.new("RGB", (64, 64))
stage = ReduceThumbnailStage()
sources = StageResult.empty()
result = stage.run(
None,
None,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
origin=Size(0, 0),
size=Size(128, 128),
stage_source=stage_source,
)
self.assertEqual(len(result), 0)

Some files were not shown because too many files have changed in this diff.