Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards
This commit is contained in:
commit
17f28aba62
19
README.md
19
README.md
|
@ -17,12 +17,14 @@ with a CPU fallback capable of running on laptop-class machines.
|
||||||
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
|
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
|
||||||
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
|
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
|
||||||
|
|
||||||
![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
|
![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
This is an incomplete list of new and interesting features, with links to the user guide:
|
This is an incomplete list of new and interesting features, with links to the user guide:
|
||||||
|
|
||||||
|
- SDXL support
|
||||||
|
- LCM support
|
||||||
- hardware acceleration on both AMD and Nvidia
|
- hardware acceleration on both AMD and Nvidia
|
||||||
- tested on CUDA, DirectML, and ROCm
|
- tested on CUDA, DirectML, and ROCm
|
||||||
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
|
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
|
||||||
|
@ -37,6 +39,7 @@ This is an incomplete list of new and interesting features, with links to the us
|
||||||
- [txt2img](docs/user-guide.md#txt2img-tab)
|
- [txt2img](docs/user-guide.md#txt2img-tab)
|
||||||
- [img2img](docs/user-guide.md#img2img-tab)
|
- [img2img](docs/user-guide.md#img2img-tab)
|
||||||
- [inpainting](docs/user-guide.md#inpaint-tab), with mask drawing and upload
|
- [inpainting](docs/user-guide.md#inpaint-tab), with mask drawing and upload
|
||||||
|
- [panorama](docs/user-guide.md#panorama-pipeline), for both SD v1.5 and SDXL
|
||||||
- [upscaling](docs/user-guide.md#upscale-tab), with ONNX acceleration
|
- [upscaling](docs/user-guide.md#upscale-tab), with ONNX acceleration
|
||||||
- [add and use your own models](docs/user-guide.md#adding-your-own-models)
|
- [add and use your own models](docs/user-guide.md#adding-your-own-models)
|
||||||
- [convert models from diffusers and SD checkpoints](docs/converting-models.md)
|
- [convert models from diffusers and SD checkpoints](docs/converting-models.md)
|
||||||
|
@ -45,20 +48,24 @@ This is an incomplete list of new and interesting features, with links to the us
|
||||||
- [permanent and prompt-based blending](docs/user-guide.md#permanently-blending-additional-networks)
|
- [permanent and prompt-based blending](docs/user-guide.md#permanently-blending-additional-networks)
|
||||||
- [supports LoRA and LyCORIS weights](docs/user-guide.md#lora-tokens)
|
- [supports LoRA and LyCORIS weights](docs/user-guide.md#lora-tokens)
|
||||||
- [supports Textual Inversion concepts and embeddings](docs/user-guide.md#textual-inversion-tokens)
|
- [supports Textual Inversion concepts and embeddings](docs/user-guide.md#textual-inversion-tokens)
|
||||||
|
- each layer of the embeddings can be controlled and used individually
|
||||||
- ControlNet
|
- ControlNet
|
||||||
- image filters for edge detection and other methods
|
- image filters for edge detection and other methods
|
||||||
- with ONNX acceleration
|
- with ONNX acceleration
|
||||||
- highres mode
|
- highres mode
|
||||||
- runs img2img on the results of the other pipelines
|
- runs img2img on the results of the other pipelines
|
||||||
- multiple iterations can produce 8k images and larger
|
- multiple iterations can produce 8k images and larger
|
||||||
|
- [multi-stage](docs/user-guide.md#prompt-stages) and [region prompts](docs/user-guide.md#region-tokens)
|
||||||
|
- seamlessly combine multiple prompts in the same image
|
||||||
|
- provide prompts for different areas in the image and blend them together
|
||||||
|
- change the prompt for highres mode and refine details without recursion
|
||||||
- infinite prompt length
|
- infinite prompt length
|
||||||
- [with long prompt weighting](docs/user-guide.md#long-prompt-weighting)
|
- [with long prompt weighting](docs/user-guide.md#long-prompt-weighting)
|
||||||
- expand and control Textual Inversions per-layer
|
|
||||||
- [image blending mode](docs/user-guide.md#blend-tab)
|
- [image blending mode](docs/user-guide.md#blend-tab)
|
||||||
- combine images from history
|
- combine images from history
|
||||||
- upscaling and face correction
|
- upscaling and correction
|
||||||
- upscaling with Real ESRGAN or Stable Diffusion
|
- upscaling with Real ESRGAN, SwinIR, and Stable Diffusion
|
||||||
- face correction with CodeFormer or GFPGAN
|
- face correction with CodeFormer and GFPGAN
|
||||||
- [API server can be run remotely](docs/server-admin.md)
|
- [API server can be run remotely](docs/server-admin.md)
|
||||||
- REST API can be served over HTTPS or HTTP
|
- REST API can be served over HTTPS or HTTP
|
||||||
- background processing for all image pipelines
|
- background processing for all image pipelines
|
||||||
|
@ -66,7 +73,7 @@ This is an incomplete list of new and interesting features, with links to the us
|
||||||
- OCI containers provided
|
- OCI containers provided
|
||||||
- for all supported hardware accelerators
|
- for all supported hardware accelerators
|
||||||
- includes both the API and GUI bundle in a single container
|
- includes both the API and GUI bundle in a single container
|
||||||
- runs well on [RunPod](https://www.runpod.io/) and other GPU container hosting services
|
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
|
||||||
|
|
||||||
## Contents
|
## Contents
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
{
|
||||||
|
"python.testing.unittestArgs": [
|
||||||
|
"-v",
|
||||||
|
"-s",
|
||||||
|
"./tests",
|
||||||
|
"-p",
|
||||||
|
"test_*.py"
|
||||||
|
],
|
||||||
|
"python.testing.pytestEnabled": false,
|
||||||
|
"python.testing.unittestEnabled": true
|
||||||
|
}
|
17
api/Makefile
17
api/Makefile
|
@ -1,4 +1,4 @@
|
||||||
.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload
|
.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload style
|
||||||
|
|
||||||
onnx_env: ## create virtual env
|
onnx_env: ## create virtual env
|
||||||
python -v venv onnx_env
|
python -v venv onnx_env
|
||||||
|
@ -18,9 +18,10 @@ pip-dev: check-venv
|
||||||
|
|
||||||
test:
|
test:
|
||||||
python -m coverage erase
|
python -m coverage erase
|
||||||
python -m coverage run -m unittest discover -s tests/
|
python -m coverage run -m unittest discover -v -s tests/
|
||||||
python -m coverage html -i
|
python -m coverage html -i
|
||||||
python -m coverage xml -i
|
python -m coverage xml -i
|
||||||
|
python -m coverage report -i
|
||||||
|
|
||||||
package: package-dist package-upload
|
package: package-dist package-upload
|
||||||
|
|
||||||
|
@ -32,13 +33,21 @@ package-upload:
|
||||||
|
|
||||||
lint-check:
|
lint-check:
|
||||||
black --check onnx_web/
|
black --check onnx_web/
|
||||||
isort --check-only --skip __init__.py --filter-files onnx_web
|
black --check tests/
|
||||||
flake8 onnx_web
|
flake8 onnx_web
|
||||||
|
flake8 tests
|
||||||
|
isort --check-only --skip __init__.py --filter-files onnx_web
|
||||||
|
isort --check-only --skip __init__.py --filter-files tests
|
||||||
|
|
||||||
lint-fix:
|
lint-fix:
|
||||||
black onnx_web/
|
black onnx_web/
|
||||||
isort --skip __init__.py --filter-files onnx_web
|
black tests/
|
||||||
flake8 onnx_web
|
flake8 onnx_web
|
||||||
|
flake8 tests
|
||||||
|
isort --skip __init__.py --filter-files onnx_web
|
||||||
|
isort --skip __init__.py --filter-files tests
|
||||||
|
|
||||||
|
style: lint-fix
|
||||||
|
|
||||||
typecheck:
|
typecheck:
|
||||||
mypy onnx_web
|
mypy onnx_web
|
||||||
|
|
|
@ -1,45 +1,2 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageParams
|
from .pipeline import ChainPipeline, PipelineStage, StageParams
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
from .stages import * # NOQA
|
||||||
from .blend_linear import BlendLinearStage
|
|
||||||
from .blend_mask import BlendMaskStage
|
|
||||||
from .correct_codeformer import CorrectCodeformerStage
|
|
||||||
from .correct_gfpgan import CorrectGFPGANStage
|
|
||||||
from .persist_disk import PersistDiskStage
|
|
||||||
from .persist_s3 import PersistS3Stage
|
|
||||||
from .reduce_crop import ReduceCropStage
|
|
||||||
from .reduce_thumbnail import ReduceThumbnailStage
|
|
||||||
from .source_noise import SourceNoiseStage
|
|
||||||
from .source_s3 import SourceS3Stage
|
|
||||||
from .source_txt2img import SourceTxt2ImgStage
|
|
||||||
from .source_url import SourceURLStage
|
|
||||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
|
||||||
from .upscale_highres import UpscaleHighresStage
|
|
||||||
from .upscale_outpaint import UpscaleOutpaintStage
|
|
||||||
from .upscale_resrgan import UpscaleRealESRGANStage
|
|
||||||
from .upscale_simple import UpscaleSimpleStage
|
|
||||||
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
|
||||||
from .upscale_swinir import UpscaleSwinIRStage
|
|
||||||
|
|
||||||
CHAIN_STAGES = {
|
|
||||||
"blend-img2img": BlendImg2ImgStage,
|
|
||||||
"blend-inpaint": UpscaleOutpaintStage,
|
|
||||||
"blend-linear": BlendLinearStage,
|
|
||||||
"blend-mask": BlendMaskStage,
|
|
||||||
"correct-codeformer": CorrectCodeformerStage,
|
|
||||||
"correct-gfpgan": CorrectGFPGANStage,
|
|
||||||
"persist-disk": PersistDiskStage,
|
|
||||||
"persist-s3": PersistS3Stage,
|
|
||||||
"reduce-crop": ReduceCropStage,
|
|
||||||
"reduce-thumbnail": ReduceThumbnailStage,
|
|
||||||
"source-noise": SourceNoiseStage,
|
|
||||||
"source-s3": SourceS3Stage,
|
|
||||||
"source-txt2img": SourceTxt2ImgStage,
|
|
||||||
"source-url": SourceURLStage,
|
|
||||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
|
||||||
"upscale-highres": UpscaleHighresStage,
|
|
||||||
"upscale-outpaint": UpscaleOutpaintStage,
|
|
||||||
"upscale-resrgan": UpscaleRealESRGANStage,
|
|
||||||
"upscale-simple": UpscaleSimpleStage,
|
|
||||||
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
|
||||||
"upscale-swinir": UpscaleSwinIRStage,
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,240 +1,39 @@
|
||||||
from datetime import timedelta
|
from typing import Optional
|
||||||
from logging import getLogger
|
|
||||||
from time import monotonic
|
|
||||||
from typing import Any, List, Optional, Tuple
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..errors import RetryException
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..output import save_image
|
from ..server.context import ServerContext
|
||||||
from ..params import ImageParams, StageParams
|
from ..worker.context import WorkerContext
|
||||||
from ..server import ServerContext
|
from .result import StageResult
|
||||||
from ..utils import is_debug, run_gc
|
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
|
||||||
from .stage import BaseStage
|
|
||||||
from .tile import needs_tile, process_tile_order
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
class BaseStage:
|
||||||
|
max_tile = SizeChart.auto
|
||||||
|
|
||||||
class ChainProgress:
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
|
||||||
self.parent = parent
|
|
||||||
self.step = start
|
|
||||||
self.total = 0
|
|
||||||
|
|
||||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
|
||||||
if step < self.step:
|
|
||||||
# accumulate on resets
|
|
||||||
self.total += self.step
|
|
||||||
|
|
||||||
self.step = step
|
|
||||||
self.parent(self.get_total(), timestep, latents)
|
|
||||||
|
|
||||||
def get_total(self) -> int:
|
|
||||||
return self.step + self.total
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_progress(cls, parent: ProgressCallback):
|
|
||||||
start = parent.step if hasattr(parent, "step") else 0
|
|
||||||
return ChainProgress(parent, start=start)
|
|
||||||
|
|
||||||
|
|
||||||
class ChainPipeline:
|
|
||||||
"""
|
|
||||||
Run many stages in series, passing the image results from each to the next, and processing
|
|
||||||
tiles as needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stages: Optional[List[PipelineStage]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new pipeline that will run the given stages.
|
|
||||||
"""
|
|
||||||
self.stages = list(stages or [])
|
|
||||||
|
|
||||||
def append(self, stage: Optional[PipelineStage]):
|
|
||||||
"""
|
|
||||||
Append an additional stage to this pipeline.
|
|
||||||
|
|
||||||
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
|
|
||||||
assemble the stage from loose arguments.
|
|
||||||
"""
|
|
||||||
if stage is not None:
|
|
||||||
self.stages.append(stage)
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
_worker: WorkerContext,
|
||||||
server: ServerContext,
|
_server: ServerContext,
|
||||||
params: ImageParams,
|
_stage: StageParams,
|
||||||
sources: List[Image.Image],
|
_params: ImageParams,
|
||||||
callback: Optional[ProgressCallback],
|
_sources: StageResult,
|
||||||
**kwargs
|
*,
|
||||||
) -> List[Image.Image]:
|
stage_source: Optional[Image.Image] = None,
|
||||||
return self(
|
**kwargs,
|
||||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
) -> StageResult:
|
||||||
)
|
raise NotImplementedError() # noqa
|
||||||
|
|
||||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
def steps(
|
||||||
self.stages.append((callback, params, kwargs))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
_params: ImageParams,
|
||||||
server: ServerContext,
|
_size: Size,
|
||||||
params: ImageParams,
|
) -> int:
|
||||||
sources: List[Image.Image],
|
return 1 # noqa
|
||||||
callback: Optional[ProgressCallback] = None,
|
|
||||||
**pipeline_kwargs
|
|
||||||
) -> List[Image.Image]:
|
|
||||||
"""
|
|
||||||
DEPRECATED: use `run` instead
|
|
||||||
"""
|
|
||||||
if callback is not None:
|
|
||||||
callback = ChainProgress.from_progress(callback)
|
|
||||||
|
|
||||||
start = monotonic()
|
def outputs(
|
||||||
|
self,
|
||||||
if len(sources) > 0:
|
_params: ImageParams,
|
||||||
logger.info(
|
sources: int,
|
||||||
"running pipeline on %s source images",
|
) -> int:
|
||||||
len(sources),
|
return sources
|
||||||
)
|
|
||||||
else:
|
|
||||||
sources = [None]
|
|
||||||
logger.info("running pipeline without source images")
|
|
||||||
|
|
||||||
stage_sources = sources
|
|
||||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
|
||||||
name = stage_params.name or stage_pipe.__class__.__name__
|
|
||||||
kwargs = stage_kwargs or {}
|
|
||||||
kwargs = {**pipeline_kwargs, **kwargs}
|
|
||||||
logger.debug(
|
|
||||||
"running stage %s with %s source images, parameters: %s",
|
|
||||||
name,
|
|
||||||
len(stage_sources) - stage_sources.count(None),
|
|
||||||
kwargs.keys(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# the stage must be split and tiled if any image is larger than the selected/max tile size
|
|
||||||
must_tile = any(
|
|
||||||
[
|
|
||||||
needs_tile(
|
|
||||||
stage_pipe.max_tile,
|
|
||||||
stage_params.tile_size,
|
|
||||||
size=kwargs.get("size", None),
|
|
||||||
source=source,
|
|
||||||
)
|
|
||||||
for source in stage_sources
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
tile = stage_params.tile_size
|
|
||||||
if stage_pipe.max_tile > 0:
|
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
|
||||||
|
|
||||||
if must_tile:
|
|
||||||
stage_outputs = []
|
|
||||||
for source in stage_sources:
|
|
||||||
logger.info(
|
|
||||||
"image larger than tile size of %s, tiling stage",
|
|
||||||
tile,
|
|
||||||
)
|
|
||||||
|
|
||||||
def stage_tile(
|
|
||||||
source_tile: Image.Image,
|
|
||||||
tile_mask: Image.Image,
|
|
||||||
dims: Tuple[int, int, int],
|
|
||||||
) -> Image.Image:
|
|
||||||
for i in range(worker.retries):
|
|
||||||
try:
|
|
||||||
output_tile = stage_pipe.run(
|
|
||||||
worker,
|
|
||||||
server,
|
|
||||||
stage_params,
|
|
||||||
params,
|
|
||||||
[source_tile],
|
|
||||||
tile_mask=tile_mask,
|
|
||||||
callback=callback,
|
|
||||||
dims=dims,
|
|
||||||
**kwargs,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-tile.png", output_tile)
|
|
||||||
|
|
||||||
return output_tile
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"error while running stage pipeline for tile, retry %s of 3",
|
|
||||||
i,
|
|
||||||
)
|
|
||||||
server.cache.clear()
|
|
||||||
run_gc([worker.get_device()])
|
|
||||||
worker.retries = worker.retries - (i + 1)
|
|
||||||
|
|
||||||
raise RetryException("exhausted retries on tile")
|
|
||||||
|
|
||||||
output = process_tile_order(
|
|
||||||
stage_params.tile_order,
|
|
||||||
source,
|
|
||||||
tile,
|
|
||||||
stage_params.outscale,
|
|
||||||
[stage_tile],
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
stage_outputs.append(output)
|
|
||||||
|
|
||||||
stage_sources = stage_outputs
|
|
||||||
else:
|
|
||||||
logger.debug("image within tile size of %s, running stage", tile)
|
|
||||||
for i in range(worker.retries):
|
|
||||||
try:
|
|
||||||
stage_outputs = stage_pipe.run(
|
|
||||||
worker,
|
|
||||||
server,
|
|
||||||
stage_params,
|
|
||||||
params,
|
|
||||||
stage_sources,
|
|
||||||
callback=callback,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
|
|
||||||
# does not like, so it throws
|
|
||||||
stage_sources = stage_outputs
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"error while running stage pipeline, retry %s of 3", i
|
|
||||||
)
|
|
||||||
server.cache.clear()
|
|
||||||
run_gc([worker.get_device()])
|
|
||||||
worker.retries = worker.retries - (i + 1)
|
|
||||||
|
|
||||||
if worker.retries <= 0:
|
|
||||||
raise RetryException("exhausted retries on stage")
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"finished stage %s with %s results",
|
|
||||||
name,
|
|
||||||
len(stage_sources),
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-stage.png", stage_sources[0])
|
|
||||||
|
|
||||||
end = monotonic()
|
|
||||||
duration = timedelta(seconds=(end - start))
|
|
||||||
logger.info(
|
|
||||||
"finished pipeline in %s with %s results",
|
|
||||||
duration,
|
|
||||||
len(stage_sources),
|
|
||||||
)
|
|
||||||
return stage_sources
|
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BlendDenoiseStage(BaseStage):
|
||||||
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
_worker: WorkerContext,
|
||||||
|
_server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
sources: StageResult,
|
||||||
|
*,
|
||||||
|
strength: int = 3,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> StageResult:
|
||||||
|
logger.info("denoising source images")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for source in sources.as_numpy():
|
||||||
|
data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
|
||||||
|
data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
|
||||||
|
results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
|
return StageResult(arrays=results)
|
|
@ -0,0 +1,62 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BlendGridStage(BaseStage):
|
||||||
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
_worker: WorkerContext,
|
||||||
|
_server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
sources: StageResult,
|
||||||
|
*,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
# rows: Optional[List[str]] = None,
|
||||||
|
# columns: Optional[List[str]] = None,
|
||||||
|
# title: Optional[str] = None,
|
||||||
|
order: Optional[List[int]] = None,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> StageResult:
|
||||||
|
logger.info("combining source images using grid layout")
|
||||||
|
|
||||||
|
images = sources.as_image()
|
||||||
|
ref_image = images[0]
|
||||||
|
size = Size(*ref_image.size)
|
||||||
|
|
||||||
|
output = Image.new(ref_image.mode, (size.width * width, size.height * height))
|
||||||
|
|
||||||
|
# TODO: labels
|
||||||
|
if order is None:
|
||||||
|
order = range(len(images))
|
||||||
|
|
||||||
|
for i in range(len(order)):
|
||||||
|
x = i % width
|
||||||
|
y = i // width
|
||||||
|
|
||||||
|
n = order[i]
|
||||||
|
output.paste(images[n], (x * size.width, y * size.height))
|
||||||
|
|
||||||
|
return StageResult(images=[*images, output])
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
_params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,13 +10,14 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BlendImg2ImgStage(BaseStage):
|
class BlendImg2ImgStage(BaseStage):
|
||||||
max_tile = SizeChart.unlimited
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
|
@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
strength: float,
|
strength: float,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
prompt_index: Optional[int] = None,
|
prompt_index: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
|
|
||||||
# multi-stage prompting
|
# multi-stage prompting
|
||||||
|
@ -52,7 +53,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
embeddings=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
pipe_params["strength"] = strength
|
pipe_params["strength"] = strength
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
if params.is_lpw():
|
if params.is_lpw():
|
||||||
logger.debug("using LPW pipeline for img2img")
|
logger.debug("using LPW pipeline for img2img")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
|
@ -81,11 +82,10 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# encode and record alternative prompts outside of LPW
|
# encode and record alternative prompts outside of LPW
|
||||||
prompt_embeds = encode_prompt(
|
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not params.is_xl():
|
if not params.is_xl():
|
||||||
|
prompt_embeds = encode_prompt(
|
||||||
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
|
)
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -102,4 +102,18 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
|
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
|
def steps(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
*args,
|
||||||
|
) -> int:
|
||||||
|
return params.steps # TODO: multiply by strength
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,13 +19,18 @@ class BlendLinearStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
alpha: float,
|
alpha: float,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
_callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
||||||
return [Image.blend(source, stage_source, alpha) for source in sources]
|
return StageResult(
|
||||||
|
images=[
|
||||||
|
Image.blend(source, stage_source, alpha)
|
||||||
|
for source in sources.as_image()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,16 +21,17 @@ class BlendMaskStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
_callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("blending image using mask")
|
logger.info("blending image using mask")
|
||||||
|
|
||||||
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
# TODO: does this need an alpha channel?
|
||||||
|
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
|
||||||
mult_mask.alpha_composite(stage_mask)
|
mult_mask.alpha_composite(stage_mask)
|
||||||
mult_mask = mult_mask.convert("L")
|
mult_mask = mult_mask.convert("L")
|
||||||
|
|
||||||
|
@ -37,4 +39,9 @@ class BlendMaskStage(BaseStage):
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
save_image(server, "last-mult-mask.png", mult_mask)
|
save_image(server, "last-mult-mask.png", mult_mask)
|
||||||
|
|
||||||
return [Image.composite(stage_source, source, mult_mask) for source in sources]
|
return StageResult(
|
||||||
|
images=[
|
||||||
|
Image.composite(stage_source, source, mult_mask)
|
||||||
|
for source in sources.as_image()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
# must be within the load function for patch to take effect
|
# must be within the load function for patch to take effect
|
||||||
# TODO: rewrite and remove
|
# TODO: rewrite and remove
|
||||||
from codeformer import CodeFormer
|
from codeformer import CodeFormer
|
||||||
|
@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
|
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||||
return [pipe(source) for source in sources]
|
return StageResult(images=[pipe(source) for source in sources.as_image()])
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -57,12 +57,12 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.correction_model is None:
|
if upscale.correction_model is None:
|
||||||
|
@ -73,16 +73,15 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
gfpgan = self.load(server, stage, upscale, device)
|
gfpgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = [
|
||||||
for source in sources:
|
gfpgan.enhance(
|
||||||
output = np.array(source)
|
source,
|
||||||
_, _, output = gfpgan.enhance(
|
|
||||||
output,
|
|
||||||
has_aligned=False,
|
has_aligned=False,
|
||||||
only_center_face=False,
|
only_center_face=False,
|
||||||
paste_back=True,
|
paste_back=True,
|
||||||
weight=upscale.face_strength,
|
weight=upscale.face_strength,
|
||||||
)
|
)
|
||||||
outputs.append(Image.fromarray(output, "RGB"))
|
for source in sources.as_numpy()
|
||||||
|
]
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..chain.base import ChainPipeline
|
|
||||||
from ..chain.blend_img2img import BlendImg2ImgStage
|
from ..chain.blend_img2img import BlendImg2ImgStage
|
||||||
from ..chain.upscale import stage_upscale_correction
|
from ..chain.upscale import stage_upscale_correction
|
||||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
|
from .pipeline import ChainPipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ def stage_highres(
|
||||||
outscale=highres.scale,
|
outscale=highres.scale,
|
||||||
),
|
),
|
||||||
chain=chain,
|
chain=chain,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("using simple upscaling for highres")
|
logger.debug("using simple upscaling for highres")
|
||||||
|
@ -51,14 +51,14 @@ def stage_highres(
|
||||||
UpscaleSimpleStage(),
|
UpscaleSimpleStage(),
|
||||||
stage,
|
stage,
|
||||||
method=highres.method,
|
method=highres.method,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
|
upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendImg2ImgStage(),
|
BlendImg2ImgStage(),
|
||||||
stage,
|
stage.with_args(outscale=1),
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
prompt_index=prompt_index + i,
|
prompt_index=prompt_index + i,
|
||||||
strength=highres.strength,
|
strength=highres.strength,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,33 +1,38 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PersistDiskStage(BaseStage):
|
class PersistDiskStage(BaseStage):
|
||||||
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
_worker: WorkerContext,
|
_worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
output: str,
|
output: List[str],
|
||||||
stage_source: Image.Image,
|
size: Optional[Size] = None,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
for source in sources:
|
logger.info("persisting %s images to disk: %s", len(sources), output)
|
||||||
# TODO: append index to output name
|
|
||||||
dest = save_image(server, output, source, params=params)
|
for source, name in zip(sources.as_image(), output):
|
||||||
|
dest = save_image(server, name, source, params=params, size=size)
|
||||||
logger.info("saved image to %s", dest)
|
logger.info("saved image to %s", dest)
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
|
@ -8,7 +8,8 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,26 +21,26 @@ class PersistS3Stage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
output: str,
|
output: List[str],
|
||||||
bucket: str,
|
bucket: str,
|
||||||
endpoint_url: Optional[str] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
for source in sources:
|
for source, name in zip(sources.as_image(), output):
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source.save(data, format=server.image_format)
|
source.save(data, format=server.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
s3.upload_fileobj(data, bucket, output)
|
s3.upload_fileobj(data, bucket, name)
|
||||||
logger.info("saved image to s3://%s/%s", bucket, output)
|
logger.info("saved image to s3://%s/%s", bucket, name)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error saving image to S3")
|
logger.exception("error saving image to S3")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,281 @@
|
||||||
|
from datetime import timedelta
|
||||||
|
from logging import getLogger
|
||||||
|
from time import monotonic
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..errors import CancelledException, RetryException
|
||||||
|
from ..output import save_image
|
||||||
|
from ..params import ImageParams, Size, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..utils import is_debug, run_gc
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
from .tile import needs_tile, process_tile_order
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
class ChainProgress:
|
||||||
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
|
self.parent = parent
|
||||||
|
self.step = start
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
|
if step < self.step:
|
||||||
|
# accumulate on resets
|
||||||
|
self.total += self.step
|
||||||
|
|
||||||
|
self.step = step
|
||||||
|
self.parent(self.get_total(), timestep, latents)
|
||||||
|
|
||||||
|
def get_total(self) -> int:
|
||||||
|
return self.step + self.total
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_progress(cls, parent: ProgressCallback):
|
||||||
|
start = parent.step if hasattr(parent, "step") else 0
|
||||||
|
return ChainProgress(parent, start=start)
|
||||||
|
|
||||||
|
|
||||||
|
class ChainPipeline:
|
||||||
|
"""
|
||||||
|
Run many stages in series, passing the image results from each to the next, and processing
|
||||||
|
tiles as needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stages: Optional[List[PipelineStage]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new pipeline that will run the given stages.
|
||||||
|
"""
|
||||||
|
self.stages = list(stages or [])
|
||||||
|
|
||||||
|
def append(self, stage: Optional[PipelineStage]):
|
||||||
|
"""
|
||||||
|
Append an additional stage to this pipeline.
|
||||||
|
|
||||||
|
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
|
||||||
|
assemble the stage from loose arguments.
|
||||||
|
"""
|
||||||
|
if stage is not None:
|
||||||
|
self.stages.append(stage)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
worker: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: StageResult,
|
||||||
|
callback: Optional[ProgressCallback],
|
||||||
|
**kwargs,
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
result = self(
|
||||||
|
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||||
|
)
|
||||||
|
return result.as_image()
|
||||||
|
|
||||||
|
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||||
|
self.stages.append((callback, params, kwargs))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def steps(self, params: ImageParams, size: Size) -> int:
|
||||||
|
steps = 0
|
||||||
|
for callback, _params, kwargs in self.stages:
|
||||||
|
steps += callback.steps(kwargs.get("params", params), size)
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
def outputs(self, params: ImageParams, sources: int) -> int:
|
||||||
|
outputs = sources
|
||||||
|
for callback, _params, kwargs in self.stages:
|
||||||
|
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
worker: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: StageResult,
|
||||||
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
**pipeline_kwargs,
|
||||||
|
) -> StageResult:
|
||||||
|
"""
|
||||||
|
DEPRECATED: use `.run()` instead
|
||||||
|
"""
|
||||||
|
if callback is None:
|
||||||
|
callback = worker.get_progress_callback()
|
||||||
|
else:
|
||||||
|
callback = ChainProgress.from_progress(callback)
|
||||||
|
|
||||||
|
start = monotonic()
|
||||||
|
|
||||||
|
if len(sources) > 0:
|
||||||
|
logger.info(
|
||||||
|
"running pipeline on %s source images",
|
||||||
|
len(sources),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("running pipeline without source images")
|
||||||
|
|
||||||
|
stage_sources = sources
|
||||||
|
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||||
|
name = stage_params.name or stage_pipe.__class__.__name__
|
||||||
|
kwargs = stage_kwargs or {}
|
||||||
|
kwargs = {**pipeline_kwargs, **kwargs}
|
||||||
|
logger.debug(
|
||||||
|
"running stage %s with %s source images, parameters: %s",
|
||||||
|
name,
|
||||||
|
len(stage_sources),
|
||||||
|
kwargs.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
|
per_stage_params = params
|
||||||
|
if "params" in kwargs:
|
||||||
|
per_stage_params = kwargs["params"]
|
||||||
|
kwargs.pop("params")
|
||||||
|
|
||||||
|
# the stage must be split and tiled if any image is larger than the selected/max tile size
|
||||||
|
must_tile = has_mask(stage_kwargs) or any(
|
||||||
|
[
|
||||||
|
needs_tile(
|
||||||
|
stage_pipe.max_tile,
|
||||||
|
stage_params.tile_size,
|
||||||
|
size=kwargs.get("size", None),
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
for source in stage_sources.as_image()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
tile = stage_params.tile_size
|
||||||
|
if stage_pipe.max_tile > 0:
|
||||||
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
|
if must_tile:
|
||||||
|
logger.info(
|
||||||
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
|
tile,
|
||||||
|
)
|
||||||
|
|
||||||
|
def stage_tile(
|
||||||
|
source_tile: List[Image.Image],
|
||||||
|
tile_mask: Image.Image,
|
||||||
|
dims: Tuple[int, int, int],
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
for _i in range(worker.retries):
|
||||||
|
try:
|
||||||
|
tile_result = stage_pipe.run(
|
||||||
|
worker,
|
||||||
|
server,
|
||||||
|
stage_params,
|
||||||
|
per_stage_params,
|
||||||
|
StageResult(images=source_tile),
|
||||||
|
tile_mask=tile_mask,
|
||||||
|
callback=callback,
|
||||||
|
dims=dims,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_debug():
|
||||||
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
|
return tile_result
|
||||||
|
except CancelledException as err:
|
||||||
|
worker.retries = 0
|
||||||
|
logger.exception("job was cancelled while tiling")
|
||||||
|
raise err
|
||||||
|
except Exception:
|
||||||
|
worker.retries = worker.retries - 1
|
||||||
|
logger.exception(
|
||||||
|
"error while running stage pipeline for tile, %s retries left",
|
||||||
|
worker.retries,
|
||||||
|
)
|
||||||
|
server.cache.clear()
|
||||||
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
|
raise RetryException("exhausted retries on tile")
|
||||||
|
|
||||||
|
stage_results = process_tile_order(
|
||||||
|
stage_params.tile_order,
|
||||||
|
stage_sources,
|
||||||
|
tile,
|
||||||
|
stage_params.outscale,
|
||||||
|
[stage_tile],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
stage_sources = StageResult(images=stage_results)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"image does not contain sources and is within tile size of %s, running stage",
|
||||||
|
tile,
|
||||||
|
)
|
||||||
|
for _i in range(worker.retries):
|
||||||
|
try:
|
||||||
|
stage_result = stage_pipe.run(
|
||||||
|
worker,
|
||||||
|
server,
|
||||||
|
stage_params,
|
||||||
|
per_stage_params,
|
||||||
|
stage_sources,
|
||||||
|
callback=callback,
|
||||||
|
dims=(0, 0, tile),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
|
||||||
|
# does not like, so it throws
|
||||||
|
stage_sources = stage_result
|
||||||
|
break
|
||||||
|
except CancelledException as err:
|
||||||
|
worker.retries = 0
|
||||||
|
logger.exception("job was cancelled during stage")
|
||||||
|
raise err
|
||||||
|
except Exception:
|
||||||
|
worker.retries = worker.retries - 1
|
||||||
|
logger.exception(
|
||||||
|
"error while running stage pipeline, %s retries left",
|
||||||
|
worker.retries,
|
||||||
|
)
|
||||||
|
server.cache.clear()
|
||||||
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
|
if worker.retries <= 0:
|
||||||
|
raise RetryException("exhausted retries on stage")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"finished stage %s with %s results",
|
||||||
|
name,
|
||||||
|
len(stage_sources),
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_debug():
|
||||||
|
for j, image in enumerate(stage_sources.as_image()):
|
||||||
|
save_image(server, f"last-stage-{j}.png", image)
|
||||||
|
|
||||||
|
end = monotonic()
|
||||||
|
duration = timedelta(seconds=(end - start))
|
||||||
|
logger.info(
|
||||||
|
"finished pipeline in %s with %s results",
|
||||||
|
duration,
|
||||||
|
len(stage_sources),
|
||||||
|
)
|
||||||
|
return stage_sources
|
||||||
|
|
||||||
|
|
||||||
|
MASK_KEYS = ["mask", "stage_mask", "tile_mask"]
|
||||||
|
|
||||||
|
|
||||||
|
def has_mask(args: List[str]) -> bool:
|
||||||
|
return any([key in args for key in MASK_KEYS])
|
|
@ -1,12 +1,13 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
origin: Size,
|
origin: Size,
|
||||||
size: Size,
|
size: Size,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||||
logger.info(
|
logger.info(
|
||||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||||
)
|
)
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,15 +18,15 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
stage_source: Image.Image,
|
stage_source: Image.Image,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
image = source.copy()
|
image = source.copy()
|
||||||
|
|
||||||
image = image.thumbnail((size.width, size.height))
|
image = image.thumbnail((size.width, size.height))
|
||||||
|
@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):
|
||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class StageResult:
|
||||||
|
"""
|
||||||
|
Chain pipeline stage result.
|
||||||
|
Can contain PIL images or numpy arrays, with helpers to convert between them.
|
||||||
|
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
|
||||||
|
they are expected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
arrays: Optional[List[np.ndarray]]
|
||||||
|
images: Optional[List[Image.Image]]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty():
|
||||||
|
return StageResult(images=[])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_arrays(arrays: List[np.ndarray]):
|
||||||
|
return StageResult(arrays=arrays)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_images(images: List[Image.Image]):
|
||||||
|
return StageResult(images=images)
|
||||||
|
|
||||||
|
def __init__(self, arrays=None, images=None) -> None:
|
||||||
|
if arrays is not None and images is not None:
|
||||||
|
raise ValueError("stages must only return one type of result")
|
||||||
|
elif arrays is None and images is None:
|
||||||
|
raise ValueError("stages must return results")
|
||||||
|
|
||||||
|
self.arrays = arrays
|
||||||
|
self.images = images
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self.arrays is not None:
|
||||||
|
return len(self.arrays)
|
||||||
|
elif self.images is not None:
|
||||||
|
return len(self.images)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def as_numpy(self) -> List[np.ndarray]:
|
||||||
|
if self.arrays is not None:
|
||||||
|
return self.arrays
|
||||||
|
elif self.images is not None:
|
||||||
|
return [np.array(i) for i in self.images]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def as_image(self) -> List[Image.Image]:
|
||||||
|
if self.images is not None:
|
||||||
|
return self.images
|
||||||
|
elif self.arrays is not None:
|
||||||
|
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def shape_mode(arr: np.ndarray) -> str:
|
||||||
|
if len(arr.shape) != 3:
|
||||||
|
raise ValueError("unknown array format")
|
||||||
|
|
||||||
|
if arr.shape[-1] == 3:
|
||||||
|
return "RGB"
|
||||||
|
elif arr.shape[-1] == 4:
|
||||||
|
return "RGBA"
|
||||||
|
|
||||||
|
raise ValueError("unknown image format")
|
|
@ -1,12 +1,13 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Callable, List
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, Size, StageParams
|
from ..params import ImageParams, Size, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,25 +19,34 @@ class SourceNoiseStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
noise_source: Callable,
|
noise_source: Callable,
|
||||||
stage_source: Image.Image,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("generating image from noise source")
|
logger.info("generating image from noise source")
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
logger.warning(
|
logger.info(
|
||||||
"source images were passed to a noise stage and will be discarded"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
|
||||||
|
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
||||||
|
for source in sources.as_image():
|
||||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||||
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
||||||
|
|
|
@ -8,7 +8,8 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,18 +21,23 @@ class SourceS3Stage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
_sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
source_keys: List[str],
|
source_keys: List[str],
|
||||||
bucket: str,
|
bucket: str,
|
||||||
endpoint_url: Optional[str] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
profile_name: Optional[str] = None,
|
profile_name: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
outputs = []
|
if len(sources) > 0:
|
||||||
|
logger.info(
|
||||||
|
"source images were passed to a source stage, new images will be appended"
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = sources.as_image()
|
||||||
for key in source_keys:
|
for key in source_keys:
|
||||||
try:
|
try:
|
||||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||||
|
@ -43,4 +49,11 @@ class SourceS3Stage(BaseStage):
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error loading image from S3")
|
logger.exception("error loading image from S3")
|
||||||
|
|
||||||
return outputs
|
return StageResult(outputs)
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1 # TODO: len(source_keys)
|
||||||
|
|
|
@ -3,26 +3,28 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
from ..constants import LATENT_FACTOR
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
from ..diffusers.utils import (
|
from ..diffusers.utils import (
|
||||||
encode_prompt,
|
encode_prompt,
|
||||||
get_latents_from_seed,
|
get_latents_from_seed,
|
||||||
get_tile_latents,
|
get_tile_latents,
|
||||||
parse_prompt,
|
parse_prompt,
|
||||||
|
parse_reseed,
|
||||||
slice_prompt,
|
slice_prompt,
|
||||||
)
|
)
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SourceTxt2ImgStage(BaseStage):
|
class SourceTxt2ImgStage(BaseStage):
|
||||||
max_tile = SizeChart.unlimited
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
|
@ -30,15 +32,15 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
_source: Image.Image,
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int] = None,
|
||||||
size: Size,
|
size: Size,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
latents: Optional[np.ndarray] = None,
|
latents: Optional[np.ndarray] = None,
|
||||||
prompt_index: Optional[int] = None,
|
prompt_index: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> StageResult:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
size = size.with_args(**kwargs)
|
size = size.with_args(**kwargs)
|
||||||
|
|
||||||
|
@ -47,31 +49,58 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))
|
params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
|
"generating image using txt2img, %s steps of %s: %s",
|
||||||
|
params.steps,
|
||||||
|
params.model,
|
||||||
|
params.prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stage_source" in kwargs:
|
if len(sources):
|
||||||
logger.warning(
|
logger.info(
|
||||||
"a source image was passed to a txt2img stage, and will be discarded"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||||
params
|
params
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.is_xl():
|
if params.is_panorama() or params.is_xl():
|
||||||
tile_size = max(stage.tile_size, params.tiles)
|
tile_size = max(stage.tile_size, params.unet_tile)
|
||||||
else:
|
else:
|
||||||
tile_size = params.tiles
|
tile_size = params.unet_tile
|
||||||
|
|
||||||
# this works for panorama as well, because tile_size is already max(tile_size, *size)
|
# this works for panorama as well, because tile_size is already max(tile_size, *size)
|
||||||
latent_size = size.min(tile_size, tile_size)
|
latent_size = size.min(tile_size, tile_size)
|
||||||
|
|
||||||
# generate new latents or slice existing
|
# generate new latents or slice existing
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
latents = get_latents_from_seed(int(params.seed), latent_size, params.batch)
|
||||||
else:
|
else:
|
||||||
latents = get_tile_latents(latents, params.seed, latent_size, dims)
|
latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
|
||||||
|
|
||||||
|
# reseed latents as needed
|
||||||
|
reseed_rng = np.random.RandomState(params.seed)
|
||||||
|
prompt, reseed = parse_reseed(prompt)
|
||||||
|
for top, left, bottom, right, region_seed in reseed:
|
||||||
|
if region_seed == -1:
|
||||||
|
region_seed = reseed_rng.random_integers(2**32 - 1)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"reseed latent region: [:, :, %s:%s, %s:%s] with %s",
|
||||||
|
top,
|
||||||
|
left,
|
||||||
|
bottom,
|
||||||
|
right,
|
||||||
|
region_seed,
|
||||||
|
)
|
||||||
|
latents[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
top // LATENT_FACTOR : bottom // LATENT_FACTOR,
|
||||||
|
left // LATENT_FACTOR : right // LATENT_FACTOR,
|
||||||
|
] = get_latents_from_seed(
|
||||||
|
region_seed, Size(right - left, bottom - top), params.batch
|
||||||
|
)
|
||||||
|
|
||||||
pipe_type = params.get_valid_pipeline("txt2img")
|
pipe_type = params.get_valid_pipeline("txt2img")
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
|
@ -79,7 +108,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
embeddings=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -101,11 +130,14 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# encode and record alternative prompts outside of LPW
|
# encode and record alternative prompts outside of LPW
|
||||||
prompt_embeds = encode_prompt(
|
if params.is_panorama() or params.is_xl():
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
logger.debug(
|
||||||
)
|
"prompt alternatives are not supported for panorama or SDXL"
|
||||||
|
)
|
||||||
if not params.is_xl():
|
else:
|
||||||
|
prompt_embeds = encode_prompt(
|
||||||
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
|
)
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
@ -123,4 +155,21 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result.images
|
outputs = sources.as_image()
|
||||||
|
outputs.extend(result.images)
|
||||||
|
logger.debug("produced %s outputs", len(outputs))
|
||||||
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
|
def steps(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
size: Size,
|
||||||
|
) -> int:
|
||||||
|
return params.steps
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -8,7 +8,8 @@ from PIL import Image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,20 +21,20 @@ class SourceURLStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
source_urls: List[str],
|
source_urls: List[str],
|
||||||
stage_source: Image.Image,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("loading image from URL source")
|
logger.info("loading image from URL source")
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
logger.warning(
|
logger.info(
|
||||||
"a source image was passed to a source stage, and will be discarded"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = sources.as_image()
|
||||||
for url in source_urls:
|
for url in source_urls:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
output = Image.open(BytesIO(response.content))
|
output = Image.open(BytesIO(response.content))
|
||||||
|
@ -41,4 +42,11 @@ class SourceURLStage(BaseStage):
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
|
def outputs(
|
||||||
|
self,
|
||||||
|
params: ImageParams,
|
||||||
|
sources: int,
|
||||||
|
) -> int:
|
||||||
|
return sources + 1
|
||||||
|
|
|
@ -1,31 +0,0 @@
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
|
||||||
from ..server.context import ServerContext
|
|
||||||
from ..worker.context import WorkerContext
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStage:
|
|
||||||
max_tile = SizeChart.auto
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
worker: WorkerContext,
|
|
||||||
server: ServerContext,
|
|
||||||
stage: StageParams,
|
|
||||||
_params: ImageParams,
|
|
||||||
sources: List[Image.Image],
|
|
||||||
*args,
|
|
||||||
stage_source: Optional[Image.Image] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> List[Image.Image]:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def steps(
|
|
||||||
self,
|
|
||||||
_params: ImageParams,
|
|
||||||
size: Size,
|
|
||||||
) -> int:
|
|
||||||
raise NotImplementedError()
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from .base import BaseStage
|
||||||
|
from .blend_denoise import BlendDenoiseStage
|
||||||
|
from .blend_grid import BlendGridStage
|
||||||
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
|
from .blend_linear import BlendLinearStage
|
||||||
|
from .blend_mask import BlendMaskStage
|
||||||
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
|
from .correct_gfpgan import CorrectGFPGANStage
|
||||||
|
from .persist_disk import PersistDiskStage
|
||||||
|
from .persist_s3 import PersistS3Stage
|
||||||
|
from .reduce_crop import ReduceCropStage
|
||||||
|
from .reduce_thumbnail import ReduceThumbnailStage
|
||||||
|
from .source_noise import SourceNoiseStage
|
||||||
|
from .source_s3 import SourceS3Stage
|
||||||
|
from .source_txt2img import SourceTxt2ImgStage
|
||||||
|
from .source_url import SourceURLStage
|
||||||
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||||
|
from .upscale_highres import UpscaleHighresStage
|
||||||
|
from .upscale_outpaint import UpscaleOutpaintStage
|
||||||
|
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||||
|
from .upscale_simple import UpscaleSimpleStage
|
||||||
|
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||||
|
from .upscale_swinir import UpscaleSwinIRStage
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
CHAIN_STAGES = {
|
||||||
|
"blend-denoise": BlendDenoiseStage,
|
||||||
|
"blend-img2img": BlendImg2ImgStage,
|
||||||
|
"blend-inpaint": UpscaleOutpaintStage,
|
||||||
|
"blend-grid": BlendGridStage,
|
||||||
|
"blend-linear": BlendLinearStage,
|
||||||
|
"blend-mask": BlendMaskStage,
|
||||||
|
"correct-codeformer": CorrectCodeformerStage,
|
||||||
|
"correct-gfpgan": CorrectGFPGANStage,
|
||||||
|
"persist-disk": PersistDiskStage,
|
||||||
|
"persist-s3": PersistS3Stage,
|
||||||
|
"reduce-crop": ReduceCropStage,
|
||||||
|
"reduce-thumbnail": ReduceThumbnailStage,
|
||||||
|
"source-noise": SourceNoiseStage,
|
||||||
|
"source-s3": SourceS3Stage,
|
||||||
|
"source-txt2img": SourceTxt2ImgStage,
|
||||||
|
"source-url": SourceURLStage,
|
||||||
|
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||||
|
"upscale-highres": UpscaleHighresStage,
|
||||||
|
"upscale-outpaint": UpscaleOutpaintStage,
|
||||||
|
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||||
|
"upscale-simple": UpscaleSimpleStage,
|
||||||
|
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
||||||
|
"upscale-swinir": UpscaleSwinIRStage,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_stage(name: str, stage: BaseStage) -> bool:
|
||||||
|
global CHAIN_STAGES
|
||||||
|
|
||||||
|
if name in CHAIN_STAGES:
|
||||||
|
logger.warning("cannot replace stage: %s", name)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
CHAIN_STAGES[name] = stage
|
||||||
|
return True
|
|
@ -2,13 +2,14 @@ import itertools
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import List, Optional, Protocol, Tuple
|
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..image.noise_source import noise_source_histogram
|
from ..image.noise_source import noise_source_histogram
|
||||||
from ..params import Size, TileOrder
|
from ..params import Size, TileOrder
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
# from skimage.exposure import match_histograms
|
# from skimage.exposure import match_histograms
|
||||||
|
|
||||||
|
@ -16,12 +17,15 @@ from ..params import Size, TileOrder
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
|
||||||
|
|
||||||
|
|
||||||
class TileCallback(Protocol):
|
class TileCallback(Protocol):
|
||||||
"""
|
"""
|
||||||
Definition for a tile job function.
|
Definition for a tile job function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
|
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
||||||
"""
|
"""
|
||||||
Run this stage against a single tile.
|
Run this stage against a single tile.
|
||||||
"""
|
"""
|
||||||
|
@ -32,6 +36,9 @@ def complete_tile(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
tile: int,
|
tile: int,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
TODO: clean up
|
||||||
|
"""
|
||||||
if source is None:
|
if source is None:
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
@ -50,6 +57,12 @@ def needs_tile(
|
||||||
source: Optional[Image.Image] = None,
|
source: Optional[Image.Image] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
tile = min(max_tile, stage_tile)
|
tile = min(max_tile, stage_tile)
|
||||||
|
logger.trace(
|
||||||
|
"checking image tile dimensions: %s, %s, %s",
|
||||||
|
tile,
|
||||||
|
source.width > tile or source.height > tile if source is not None else False,
|
||||||
|
size.width > tile or size.height > tile if size is not None else False,
|
||||||
|
)
|
||||||
|
|
||||||
if source is not None:
|
if source is not None:
|
||||||
return source.width > tile or source.height > tile
|
return source.width > tile or source.height > tile
|
||||||
|
@ -60,7 +73,7 @@ def needs_tile(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_tile_grads(
|
def make_tile_grads(
|
||||||
left: int,
|
left: int,
|
||||||
top: int,
|
top: int,
|
||||||
tile: int,
|
tile: int,
|
||||||
|
@ -85,6 +98,60 @@ def get_tile_grads(
|
||||||
return (grad_x, grad_y)
|
return (grad_x, grad_y)
|
||||||
|
|
||||||
|
|
||||||
|
def make_tile_mask(
|
||||||
|
shape: Any,
|
||||||
|
tile: Tuple[int, int],
|
||||||
|
overlap: float,
|
||||||
|
edges: Tuple[bool, bool, bool, bool],
|
||||||
|
) -> np.ndarray:
|
||||||
|
mask = np.ones(shape)
|
||||||
|
|
||||||
|
tile_h, tile_w = tile
|
||||||
|
|
||||||
|
adj_tile_h = int(float(tile_h) * (1.0 - overlap))
|
||||||
|
adj_tile_w = int(float(tile_w) * (1.0 - overlap))
|
||||||
|
|
||||||
|
# sort gradient points
|
||||||
|
p1_h = adj_tile_h - 1
|
||||||
|
p2_h = tile_h - adj_tile_h
|
||||||
|
points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]
|
||||||
|
|
||||||
|
p1_w = adj_tile_w - 1
|
||||||
|
p2_w = tile_w - adj_tile_w
|
||||||
|
points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]
|
||||||
|
|
||||||
|
# build gradients
|
||||||
|
edge_t, edge_l, edge_b, edge_r = edges
|
||||||
|
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
|
||||||
|
int(not edge_t),
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
int(not edge_b),
|
||||||
|
]
|
||||||
|
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
|
||||||
|
|
||||||
|
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
|
||||||
|
mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)]
|
||||||
|
|
||||||
|
mask = ((mask * mult_x).T * mult_y).T
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
return image.shape[-1]
|
||||||
|
|
||||||
|
if image.mode == "RGBA":
|
||||||
|
return 4
|
||||||
|
elif image.mode == "RGB":
|
||||||
|
return 3
|
||||||
|
elif image.mode == "L":
|
||||||
|
return 1
|
||||||
|
|
||||||
|
raise ValueError("unknown image format")
|
||||||
|
|
||||||
|
|
||||||
def blend_tiles(
|
def blend_tiles(
|
||||||
tiles: List[Tuple[int, int, Image.Image]],
|
tiles: List[Tuple[int, int, Image.Image]],
|
||||||
scale: int,
|
scale: int,
|
||||||
|
@ -98,23 +165,24 @@ def blend_tiles(
|
||||||
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
|
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
|
||||||
)
|
)
|
||||||
|
|
||||||
scaled_size = (height * scale, width * scale, 3)
|
channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
|
||||||
|
scaled_size = (height * scale, width * scale, channels)
|
||||||
|
|
||||||
count = np.zeros(scaled_size)
|
count = np.zeros(scaled_size)
|
||||||
value = np.zeros(scaled_size)
|
value = np.zeros(scaled_size)
|
||||||
|
|
||||||
for left, top, tile_image in tiles:
|
for left, top, tile_image in tiles:
|
||||||
# histogram equalization
|
|
||||||
equalized = np.array(tile_image).astype(np.float32)
|
equalized = np.array(tile_image).astype(np.float32)
|
||||||
mask = np.ones_like(equalized[:, :, 0])
|
mask = np.ones_like(equalized[:, :, 0])
|
||||||
|
|
||||||
if adj_tile < tile:
|
if adj_tile < tile:
|
||||||
# sort gradient points
|
# sort gradient points
|
||||||
p1 = adj_tile * scale
|
p1 = (adj_tile * scale) - 1
|
||||||
p2 = (tile - adj_tile) * scale
|
p2 = (tile - adj_tile - 1) * scale
|
||||||
points = [0, min(p1, p2), max(p1, p2), tile * scale]
|
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
|
||||||
|
|
||||||
# gradient blending
|
# gradient blending
|
||||||
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
|
grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
|
||||||
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
|
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
|
||||||
|
|
||||||
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
|
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
|
||||||
|
@ -169,7 +237,7 @@ def blend_tiles(
|
||||||
margin_left : equalized.shape[1] + margin_right,
|
margin_left : equalized.shape[1] + margin_right,
|
||||||
np.newaxis,
|
np.newaxis,
|
||||||
],
|
],
|
||||||
3,
|
channels,
|
||||||
axis=2,
|
axis=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -178,60 +246,18 @@ def blend_tiles(
|
||||||
return Image.fromarray(np.uint8(pixels))
|
return Image.fromarray(np.uint8(pixels))
|
||||||
|
|
||||||
|
|
||||||
def process_tile_grid(
|
def process_tile_stack(
|
||||||
source: Image.Image,
|
stack: StageResult,
|
||||||
tile: int,
|
|
||||||
scale: int,
|
|
||||||
filters: List[TileCallback],
|
|
||||||
overlap: float = 0.0,
|
|
||||||
**kwargs,
|
|
||||||
) -> Image.Image:
|
|
||||||
width, height = kwargs.get("size", source.size if source else None)
|
|
||||||
|
|
||||||
adj_tile = int(float(tile) * (1.0 - overlap))
|
|
||||||
tiles_x = ceil(width / adj_tile)
|
|
||||||
tiles_y = ceil(height / adj_tile)
|
|
||||||
total = tiles_x * tiles_y
|
|
||||||
logger.debug(
|
|
||||||
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
|
||||||
total,
|
|
||||||
tiles_x,
|
|
||||||
tiles_y,
|
|
||||||
adj_tile,
|
|
||||||
overlap,
|
|
||||||
)
|
|
||||||
|
|
||||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
|
||||||
|
|
||||||
for y in range(tiles_y):
|
|
||||||
for x in range(tiles_x):
|
|
||||||
idx = (y * tiles_x) + x
|
|
||||||
left = x * adj_tile
|
|
||||||
top = y * adj_tile
|
|
||||||
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
|
|
||||||
|
|
||||||
tile_image = (
|
|
||||||
source.crop((left, top, left + tile, top + tile)) if source else None
|
|
||||||
)
|
|
||||||
tile_image = complete_tile(tile_image, tile)
|
|
||||||
|
|
||||||
for filter in filters:
|
|
||||||
tile_image = filter(tile_image, (left, top, tile))
|
|
||||||
|
|
||||||
tiles.append((left, top, tile_image))
|
|
||||||
|
|
||||||
return blend_tiles(tiles, scale, width, height, tile, overlap)
|
|
||||||
|
|
||||||
|
|
||||||
def process_tile_spiral(
|
|
||||||
source: Image.Image,
|
|
||||||
tile: int,
|
tile: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
|
tile_generator: TileGenerator,
|
||||||
overlap: float = 0.5,
|
overlap: float = 0.5,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> List[Image.Image]:
|
||||||
width, height = kwargs.get("size", source.size if source else None)
|
sources = stack.as_image()
|
||||||
|
|
||||||
|
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
||||||
mask = kwargs.get("mask", None)
|
mask = kwargs.get("mask", None)
|
||||||
noise_source = kwargs.get("noise_source", noise_source_histogram)
|
noise_source = kwargs.get("noise_source", noise_source_histogram)
|
||||||
fill_color = kwargs.get("fill_color", None)
|
fill_color = kwargs.get("fill_color", None)
|
||||||
|
@ -239,18 +265,10 @@ def process_tile_spiral(
|
||||||
tile_mask = None
|
tile_mask = None
|
||||||
|
|
||||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||||
|
tile_coords = tile_generator(width, height, tile, overlap)
|
||||||
|
single_tile = len(tile_coords) == 1
|
||||||
|
|
||||||
# tile tuples is source, multiply by scale for dest
|
for counter, (left, top) in enumerate(tile_coords):
|
||||||
counter = 0
|
|
||||||
tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
|
|
||||||
|
|
||||||
if len(tile_coords) == 1:
|
|
||||||
single_tile = True
|
|
||||||
else:
|
|
||||||
single_tile = False
|
|
||||||
|
|
||||||
for left, top in tile_coords:
|
|
||||||
counter += 1
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
|
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
|
||||||
)
|
)
|
||||||
|
@ -274,26 +292,36 @@ def process_tile_spiral(
|
||||||
needs_margin = True
|
needs_margin = True
|
||||||
bottom_margin = height - bottom
|
bottom_margin = height - bottom
|
||||||
|
|
||||||
# if no source given, we don't have a source image
|
if single_tile:
|
||||||
if not source:
|
logger.debug("using single tile")
|
||||||
tile_image = None
|
tile_stack = sources
|
||||||
|
if mask:
|
||||||
|
tile_mask = mask
|
||||||
elif needs_margin:
|
elif needs_margin:
|
||||||
# in the special case where the image is smaller than the specified tile size, just use the image
|
logger.debug(
|
||||||
if single_tile:
|
"tiling with added margins: %s, %s, %s, %s",
|
||||||
logger.debug("creating and processing single-tile subtile")
|
left_margin,
|
||||||
tile_image = source
|
top_margin,
|
||||||
if mask:
|
right_margin,
|
||||||
tile_mask = mask
|
bottom_margin,
|
||||||
# otherwise use add histogram noise outside of the image border
|
)
|
||||||
else:
|
tile_stack = add_margin(
|
||||||
logger.debug(
|
stack.as_image(),
|
||||||
"tiling and adding margins: %s, %s, %s, %s",
|
left,
|
||||||
left_margin,
|
top,
|
||||||
top_margin,
|
right,
|
||||||
right_margin,
|
bottom,
|
||||||
bottom_margin,
|
left_margin,
|
||||||
)
|
top_margin,
|
||||||
base_image = source.crop(
|
right_margin,
|
||||||
|
bottom_margin,
|
||||||
|
tile,
|
||||||
|
noise_source,
|
||||||
|
fill_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask:
|
||||||
|
base_mask = mask.crop(
|
||||||
(
|
(
|
||||||
left + left_margin,
|
left + left_margin,
|
||||||
top + top_margin,
|
top + top_margin,
|
||||||
|
@ -301,57 +329,60 @@ def process_tile_spiral(
|
||||||
bottom + bottom_margin,
|
bottom + bottom_margin,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tile_image = noise_source(
|
tile_mask = Image.new("L", (tile, tile), color=0)
|
||||||
base_image, (tile, tile), (0, 0), fill=fill_color
|
tile_mask.paste(base_mask, (left_margin, top_margin))
|
||||||
)
|
|
||||||
tile_image.paste(base_image, (left_margin, top_margin))
|
|
||||||
|
|
||||||
if mask:
|
|
||||||
base_mask = mask.crop(
|
|
||||||
(
|
|
||||||
left + left_margin,
|
|
||||||
top + top_margin,
|
|
||||||
right + right_margin,
|
|
||||||
bottom + bottom_margin,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tile_mask = Image.new("L", (tile, tile), color=0)
|
|
||||||
tile_mask.paste(base_mask, (left_margin, top_margin))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("tiling normally")
|
logger.debug("tiling normally")
|
||||||
tile_image = source.crop((left, top, right, bottom))
|
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
|
||||||
if mask:
|
if mask:
|
||||||
tile_mask = mask.crop((left, top, right, bottom))
|
tile_mask = mask.crop((left, top, right, bottom))
|
||||||
|
|
||||||
for image_filter in filters:
|
for image_filter in filters:
|
||||||
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
|
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
|
||||||
|
|
||||||
tiles.append((left, top, tile_image))
|
if isinstance(tile_stack, list):
|
||||||
|
tile_stack = StageResult.from_images(tile_stack)
|
||||||
|
|
||||||
if single_tile:
|
tiles.append((left, top, tile_stack.as_image()))
|
||||||
return tile_image
|
|
||||||
else:
|
lefts, tops, stacks = list(zip(*tiles))
|
||||||
return blend_tiles(tiles, scale, width, height, tile, overlap)
|
coords = list(zip(lefts, tops))
|
||||||
|
stacks = list(zip(*stacks))
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for stack in stacks:
|
||||||
|
stack_tiles = zip(coords, stack)
|
||||||
|
stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
|
||||||
|
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def process_tile_order(
|
def process_tile_order(
|
||||||
order: TileOrder,
|
order: TileOrder,
|
||||||
source: Image.Image,
|
stack: StageResult,
|
||||||
tile: int,
|
tile: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
TODO: needs to handle more than one image
|
||||||
|
"""
|
||||||
if order == TileOrder.grid:
|
if order == TileOrder.grid:
|
||||||
logger.debug("using grid tile order with tile size: %s", tile)
|
logger.debug("using grid tile order with tile size: %s", tile)
|
||||||
return process_tile_grid(source, tile, scale, filters, **kwargs)
|
return process_tile_stack(
|
||||||
|
stack, tile, scale, filters, generate_tile_grid, **kwargs
|
||||||
|
)
|
||||||
elif order == TileOrder.kernel:
|
elif order == TileOrder.kernel:
|
||||||
logger.debug("using kernel tile order with tile size: %s", tile)
|
logger.debug("using kernel tile order with tile size: %s", tile)
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
elif order == TileOrder.spiral:
|
elif order == TileOrder.spiral:
|
||||||
logger.debug("using spiral tile order with tile size: %s", tile)
|
logger.debug("using spiral tile order with tile size: %s", tile)
|
||||||
return process_tile_spiral(source, tile, scale, filters, **kwargs)
|
return process_tile_stack(
|
||||||
|
stack, tile, scale, filters, generate_tile_spiral, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown tile order: %s", order)
|
logger.warning("unknown tile order: %s", order)
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
@ -445,3 +476,77 @@ def generate_tile_spiral(
|
||||||
height_tile_target -= abs(state.value[1])
|
height_tile_target -= abs(state.value[1])
|
||||||
|
|
||||||
return tile_coords
|
return tile_coords
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tile_grid(
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
tile: int,
|
||||||
|
overlap: float = 0.0,
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
adj_tile = int(float(tile) * (1.0 - overlap))
|
||||||
|
tiles_x = ceil(width / adj_tile)
|
||||||
|
tiles_y = ceil(height / adj_tile)
|
||||||
|
total = tiles_x * tiles_y
|
||||||
|
logger.debug(
|
||||||
|
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
||||||
|
total,
|
||||||
|
tiles_x,
|
||||||
|
tiles_y,
|
||||||
|
adj_tile,
|
||||||
|
overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||||
|
|
||||||
|
for y in range(tiles_y):
|
||||||
|
for x in range(tiles_x):
|
||||||
|
left = x * adj_tile
|
||||||
|
top = y * adj_tile
|
||||||
|
|
||||||
|
tiles.append((int(left), int(top)))
|
||||||
|
|
||||||
|
return tiles
|
||||||
|
|
||||||
|
|
||||||
|
def get_result_tile(
|
||||||
|
result: StageResult,
|
||||||
|
origin: Tuple[int, int],
|
||||||
|
tile: Size,
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
top, left = origin
|
||||||
|
return [
|
||||||
|
layer.crop((top, left, top + tile.height, left + tile.width))
|
||||||
|
for layer in result.as_image()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_margin(
|
||||||
|
stack: List[Image.Image],
|
||||||
|
left: int,
|
||||||
|
top: int,
|
||||||
|
right: int,
|
||||||
|
bottom: int,
|
||||||
|
left_margin: int,
|
||||||
|
top_margin: int,
|
||||||
|
right_margin: int,
|
||||||
|
bottom_margin: int,
|
||||||
|
tile: int,
|
||||||
|
noise_source,
|
||||||
|
fill_color,
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
results = []
|
||||||
|
for source in stack:
|
||||||
|
base_image = source.crop(
|
||||||
|
(
|
||||||
|
left + left_margin,
|
||||||
|
top + top_margin,
|
||||||
|
right + right_margin,
|
||||||
|
bottom + bottom_margin,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
|
||||||
|
tile_image.paste(base_image, (left_margin, top_margin))
|
||||||
|
results.append(tile_image)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
|
@ -1,22 +1,30 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..models.onnx import OnnxModel
|
from ..models.onnx import OnnxModel
|
||||||
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
from ..params import (
|
||||||
|
DeviceParams,
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
SizeChart,
|
||||||
|
StageParams,
|
||||||
|
UpscaleParams,
|
||||||
|
)
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleBSRGANStage(BaseStage):
|
class UpscaleBSRGANStage(BaseStage):
|
||||||
max_tile = 64
|
max_tile = SizeChart.micro
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self,
|
self,
|
||||||
|
@ -54,12 +62,12 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.upscale_model is None:
|
if upscale.upscale_model is None:
|
||||||
|
@ -71,40 +79,38 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
bsrgan = self.load(server, stage, upscale, device)
|
bsrgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_numpy():
|
||||||
image = np.array(source) / 255.0
|
image = source / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
logger.trace(
|
||||||
|
"BSRGAN output shape: %s",
|
||||||
(
|
(
|
||||||
image.shape[0],
|
image.shape[0],
|
||||||
image.shape[1],
|
image.shape[1],
|
||||||
image.shape[2] * scale,
|
image.shape[2] * scale,
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
|
||||||
|
|
||||||
dest = bsrgan(image)
|
output = bsrgan(image)
|
||||||
|
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
output = Image.fromarray(dest, "RGB")
|
|
||||||
logger.debug("output image size: %s x %s", output.width, output.height)
|
|
||||||
|
|
||||||
|
logger.debug("output image shape: %s", output.shape)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(arrays=outputs)
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
) -> int:
|
) -> int:
|
||||||
tile = min(params.tiles, self.max_tile)
|
tile = min(params.unet_tile, self.max_tile)
|
||||||
return size.width // tile * size.height // tile
|
return size.width // tile * size.height // tile
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from ..worker.context import ProgressCallback
|
from ..worker.context import ProgressCallback
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*args,
|
*,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
if highres.scale <= 1:
|
if highres.scale <= 1:
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
chain = stage_highres(stage, params, highres, upscale)
|
chain = stage_highres(stage, params, highres, upscale)
|
||||||
|
|
||||||
return [
|
outputs = [
|
||||||
chain(
|
chain(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
|
@ -41,5 +42,7 @@ class UpscaleHighresStage(BaseStage):
|
||||||
source,
|
source,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
for source in sources
|
for source in sources.as_image()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -18,13 +18,14 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleOutpaintStage(BaseStage):
|
class UpscaleOutpaintStage(BaseStage):
|
||||||
max_tile = SizeChart.unlimited
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
|
@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
border: Border,
|
border: Border,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int],
|
||||||
|
@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||||
params
|
params
|
||||||
)
|
)
|
||||||
|
@ -56,12 +57,12 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
embeddings=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "tile-source.png", source)
|
save_image(server, "tile-source.png", source)
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
save_image(server, "tile-mask.png", tile_mask)
|
||||||
|
@ -71,7 +72,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
outputs.append(source)
|
outputs.append(source)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tile_size = params.tiles
|
tile_size = params.unet_tile
|
||||||
size = Size(*source.size)
|
size = Size(*source.size)
|
||||||
latent_size = size.min(tile_size, tile_size)
|
latent_size = size.min(tile_size, tile_size)
|
||||||
|
|
||||||
|
@ -99,10 +100,11 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# encode and record alternative prompts outside of LPW
|
# encode and record alternative prompts outside of LPW
|
||||||
prompt_embeds = encode_prompt(
|
if not params.is_xl():
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
prompt_embeds = encode_prompt(
|
||||||
)
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
)
|
||||||
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = pipe(
|
result = pipe(
|
||||||
|
@ -121,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
|
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..onnx import OnnxRRDBNet
|
from ..onnx import OnnxRRDBNet
|
||||||
|
@ -10,7 +9,8 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -77,25 +77,22 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||||
|
|
||||||
|
upsampler = self.load(
|
||||||
|
server, upscale, worker.get_device(), tile=stage.tile_size
|
||||||
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_numpy():
|
||||||
output = np.array(source)
|
output, _ = upsampler.enhance(source, outscale=upscale.outscale)
|
||||||
upsampler = self.load(
|
logger.info("final output image size: %s", output.shape)
|
||||||
server, upscale, worker.get_device(), tile=stage.tile_size
|
|
||||||
)
|
|
||||||
|
|
||||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
|
||||||
|
|
||||||
output = Image.fromarray(output, "RGB")
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(arrays=outputs)
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
method: str,
|
method: str,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
if upscale.scale <= 1:
|
if upscale.scale <= 1:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"simple upscale stage run with scale of %s, skipping", upscale.scale
|
"simple upscale stage run with scale of %s, skipping", upscale.scale
|
||||||
|
@ -32,18 +33,20 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||||
|
|
||||||
if method == "bilinear":
|
if method == "bilinear":
|
||||||
logger.debug("using bilinear interpolation for highres")
|
logger.debug("using bilinear interpolation for highres")
|
||||||
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
outputs.append(
|
||||||
|
source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||||
|
)
|
||||||
elif method == "lanczos":
|
elif method == "lanczos":
|
||||||
logger.debug("using Lanczos interpolation for highres")
|
logger.debug("using Lanczos interpolation for highres")
|
||||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
outputs.append(
|
||||||
|
source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown upscaling method: %s", method)
|
logger.warning("unknown upscaling method: %s", method)
|
||||||
|
|
||||||
outputs.append(source)
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
|
@ -10,7 +10,8 @@ from ..diffusers.utils import encode_prompt, parse_prompt
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -46,22 +47,23 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
model=path.join(server.model_path, upscale.upscale_model),
|
model=path.join(server.model_path, upscale.upscale_model),
|
||||||
)
|
)
|
||||||
generator = torch.manual_seed(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
|
||||||
prompt_embeds = encode_prompt(
|
if not params.is_xl():
|
||||||
pipeline,
|
prompt_embeds = encode_prompt(
|
||||||
prompt_pairs,
|
pipeline,
|
||||||
num_images_per_prompt=params.batch,
|
prompt_pairs,
|
||||||
do_classifier_free_guidance=params.do_cfg(),
|
num_images_per_prompt=params.batch,
|
||||||
)
|
do_classifier_free_guidance=params.do_cfg(),
|
||||||
pipeline.unet.set_prompts(prompt_embeds)
|
)
|
||||||
|
pipeline.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_image():
|
||||||
result = pipeline(
|
result = pipeline(
|
||||||
prompt,
|
prompt,
|
||||||
source,
|
source,
|
||||||
generator=generator,
|
generator=rng,
|
||||||
guidance_scale=params.cfg,
|
guidance_scale=params.cfg,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
|
@ -71,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
)
|
)
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,22 +1,23 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..models.onnx import OnnxModel
|
from ..models.onnx import OnnxModel
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from ..server import ModelTypes, ServerContext
|
from ..server import ModelTypes, ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .stage import BaseStage
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleSwinIRStage(BaseStage):
|
class UpscaleSwinIRStage(BaseStage):
|
||||||
max_tile = 64
|
max_tile = SizeChart.micro
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self,
|
self,
|
||||||
|
@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: StageResult,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> StageResult:
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
if upscale.upscale_model is None:
|
if upscale.upscale_model is None:
|
||||||
|
@ -71,31 +72,30 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
swinir = self.load(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources.as_numpy():
|
||||||
# TODO: add support for grayscale (1-channel) images
|
# TODO: add support for grayscale (1-channel) images
|
||||||
image = np.array(source) / 255.0
|
image = source / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.trace("SwinIR input shape: %s", image.shape)
|
logger.trace("SwinIR input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
logger.trace(
|
||||||
|
"SwinIR output shape: %s",
|
||||||
(
|
(
|
||||||
image.shape[0],
|
image.shape[0],
|
||||||
image.shape[1],
|
image.shape[1],
|
||||||
image.shape[2] * scale,
|
image.shape[2] * scale,
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
logger.trace("SwinIR output shape: %s", dest.shape)
|
|
||||||
|
|
||||||
dest = swinir(image)
|
output = swinir(image)
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
output = Image.fromarray(dest, "RGB")
|
logger.info("output image size: %s", output.shape)
|
||||||
logger.info("output image size: %s x %s", output.width, output.height)
|
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,2 +1,5 @@
|
||||||
ONNX_MODEL = "model.onnx"
|
ONNX_MODEL = "model.onnx"
|
||||||
ONNX_WEIGHTS = "weights.pb"
|
ONNX_WEIGHTS = "weights.pb"
|
||||||
|
|
||||||
|
LATENT_FACTOR = 8
|
||||||
|
LATENT_CHANNELS = 4
|
||||||
|
|
|
@ -15,7 +15,8 @@ from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
from ..utils import load_config
|
from ..utils import load_config
|
||||||
from .correction.gfpgan import convert_correction_gfpgan
|
from .correction.gfpgan import convert_correction_gfpgan
|
||||||
from .diffusion.control import convert_diffusion_control
|
from .diffusion.control import convert_diffusion_control
|
||||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
from .diffusion.diffusion import convert_diffusion_diffusers
|
||||||
|
from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
|
||||||
from .diffusion.lora import blend_loras
|
from .diffusion.lora import blend_loras
|
||||||
from .diffusion.textual_inversion import blend_textual_inversions
|
from .diffusion.textual_inversion import blend_textual_inversions
|
||||||
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
from .upscaling.bsrgan import convert_upscaling_bsrgan
|
||||||
|
@ -357,13 +358,23 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
conversion, name, model["source"], format=model_format
|
conversion, name, model["source"], format=model_format
|
||||||
)
|
)
|
||||||
|
|
||||||
converted, dest = convert_diffusion_diffusers(
|
pipeline = model.get("pipeline", "txt2img")
|
||||||
conversion,
|
if pipeline.endswith("-sdxl"):
|
||||||
model,
|
converted, dest = convert_diffusion_diffusers_xl(
|
||||||
source,
|
conversion,
|
||||||
model_format,
|
model,
|
||||||
hf=hf,
|
source,
|
||||||
)
|
model_format,
|
||||||
|
hf=hf,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
converted, dest = convert_diffusion_diffusers(
|
||||||
|
conversion,
|
||||||
|
model,
|
||||||
|
source,
|
||||||
|
model_format,
|
||||||
|
hf=hf,
|
||||||
|
)
|
||||||
|
|
||||||
# make sure blending only happens once, not every run
|
# make sure blending only happens once, not every run
|
||||||
if converted:
|
if converted:
|
||||||
|
@ -588,7 +599,7 @@ def main(args=None) -> int:
|
||||||
logger.info("CLI arguments: %s", args)
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
server = ConversionContext.from_environ()
|
server = ConversionContext.from_environ()
|
||||||
server.half = args.half or "onnx-fp16" in server.optimizations
|
server.half = args.half or server.has_optimization("onnx-fp16")
|
||||||
server.opset = args.opset
|
server.opset = args.opset
|
||||||
server.token = args.token
|
server.token = args.token
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -20,7 +20,6 @@ from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
OnnxRuntimeModel,
|
OnnxRuntimeModel,
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionControlNetPipeline,
|
|
||||||
StableDiffusionInstructPix2PixPipeline,
|
StableDiffusionInstructPix2PixPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
|
@ -32,17 +31,25 @@ from onnx import load_model, save_model
|
||||||
|
|
||||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
from ...diffusers.load import optimize_pipeline
|
from ...diffusers.load import optimize_pipeline
|
||||||
|
from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||||
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||||
from ...models.cnet import UNet2DConditionModel_CNet
|
from ...models.cnet import UNet2DConditionModel_CNet
|
||||||
from ...utils import run_gc
|
from ...utils import run_gc
|
||||||
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
|
from ..utils import (
|
||||||
|
RESOLVE_FORMATS,
|
||||||
|
ConversionContext,
|
||||||
|
check_ext,
|
||||||
|
is_torch_2_0,
|
||||||
|
load_tensor,
|
||||||
|
onnx_export,
|
||||||
|
)
|
||||||
from .checkpoint import convert_extract_checkpoint
|
from .checkpoint import convert_extract_checkpoint
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
available_pipelines = {
|
CONVERT_PIPELINES = {
|
||||||
"controlnet": StableDiffusionControlNetPipeline,
|
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
||||||
"img2img": StableDiffusionPipeline,
|
"img2img": StableDiffusionPipeline,
|
||||||
"inpaint": StableDiffusionPipeline,
|
"inpaint": StableDiffusionPipeline,
|
||||||
"lpw": StableDiffusionPipeline,
|
"lpw": StableDiffusionPipeline,
|
||||||
|
@ -96,7 +103,6 @@ def get_model_version(
|
||||||
opts["prediction_type"] = "epsilon"
|
opts["prediction_type"] = "epsilon"
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("unable to load tensor for version check")
|
logger.debug("unable to load tensor for version check")
|
||||||
pass
|
|
||||||
|
|
||||||
return (v2, opts)
|
return (v2, opts)
|
||||||
|
|
||||||
|
@ -314,7 +320,7 @@ def convert_diffusion_diffusers(
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
pipe_class = available_pipelines.get(pipe_type)
|
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||||
v2, pipe_args = get_model_version(
|
v2, pipe_args = get_model_version(
|
||||||
source, conversion.map_location, size=image_size, version=version
|
source, conversion.map_location, size=image_size, version=version
|
||||||
)
|
)
|
||||||
|
@ -360,7 +366,6 @@ def convert_diffusion_diffusers(
|
||||||
source,
|
source,
|
||||||
original_config_file=config_path,
|
original_config_file=config_path,
|
||||||
pipeline_class=pipe_class,
|
pipeline_class=pipe_class,
|
||||||
vae_path=replace_vae,
|
|
||||||
**pipe_args,
|
**pipe_args,
|
||||||
).to(device, torch_dtype=dtype)
|
).to(device, torch_dtype=dtype)
|
||||||
elif hf:
|
elif hf:
|
||||||
|
@ -374,6 +379,17 @@ def convert_diffusion_diffusers(
|
||||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
logger.warning("pipeline source not found or not recognized: %s", source)
|
||||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
||||||
|
|
||||||
|
if replace_vae is not None:
|
||||||
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
|
if check_ext(replace_vae, RESOLVE_FORMATS):
|
||||||
|
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||||
|
else:
|
||||||
|
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
||||||
|
|
||||||
|
if is_torch_2_0:
|
||||||
|
pipeline.unet.set_attn_processor(AttnProcessor())
|
||||||
|
pipeline.vae.set_attn_processor(AttnProcessor())
|
||||||
|
|
||||||
optimize_pipeline(conversion, pipeline)
|
optimize_pipeline(conversion, pipeline)
|
||||||
|
|
||||||
output_path = Path(dest_path)
|
output_path = Path(dest_path)
|
||||||
|
@ -424,9 +440,6 @@ def convert_diffusion_diffusers(
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||||
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
|
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
|
||||||
|
|
||||||
if is_torch_2_0:
|
|
||||||
pipeline.unet.set_attn_processor(AttnProcessor())
|
|
||||||
|
|
||||||
unet_in_channels = pipeline.unet.config.in_channels
|
unet_in_channels = pipeline.unet.config.in_channels
|
||||||
unet_sample_size = pipeline.unet.config.sample_size
|
unet_sample_size = pipeline.unet.config.sample_size
|
||||||
unet_path = output_path / "unet" / ONNX_MODEL
|
unet_path = output_path / "unet" / ONNX_MODEL
|
||||||
|
@ -526,19 +539,6 @@ def convert_diffusion_diffusers(
|
||||||
del unet
|
del unet
|
||||||
run_gc()
|
run_gc()
|
||||||
|
|
||||||
# VAE
|
|
||||||
if replace_vae is not None:
|
|
||||||
if replace_vae.startswith("."):
|
|
||||||
logger.debug(
|
|
||||||
"custom VAE appears to be a local path, making it relative to the model path"
|
|
||||||
)
|
|
||||||
replace_vae = path.join(conversion.model_path, replace_vae)
|
|
||||||
|
|
||||||
logger.info("loading custom VAE: %s", replace_vae)
|
|
||||||
vae = AutoencoderKL.from_pretrained(replace_vae)
|
|
||||||
pipeline.vae = vae
|
|
||||||
run_gc()
|
|
||||||
|
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.debug("VAE config: %s", pipeline.vae.config)
|
logger.debug("VAE config: %s", pipeline.vae.config)
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
|
||||||
|
from onnx.shape_inference import infer_shapes_path
|
||||||
|
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||||
|
from optimum.exporters.onnx import main_export
|
||||||
|
|
||||||
|
from ...constants import ONNX_MODEL
|
||||||
|
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_diffusion_diffusers_xl(
|
||||||
|
conversion: ConversionContext,
|
||||||
|
model: Dict,
|
||||||
|
source: str,
|
||||||
|
format: Optional[str],
|
||||||
|
hf: bool = False,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
|
"""
|
||||||
|
name = model.get("name")
|
||||||
|
replace_vae = model.get("vae", None)
|
||||||
|
|
||||||
|
device = conversion.training_device
|
||||||
|
dtype = conversion.torch_dtype()
|
||||||
|
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||||
|
|
||||||
|
dest_path = path.join(conversion.model_path, name)
|
||||||
|
model_index = path.join(dest_path, "model_index.json")
|
||||||
|
model_hash = path.join(dest_path, "hash.txt")
|
||||||
|
|
||||||
|
# diffusers go into a directory rather than .onnx file
|
||||||
|
logger.info(
|
||||||
|
"converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if path.exists(dest_path) and path.exists(model_index):
|
||||||
|
logger.info("ONNX model already exists, skipping conversion")
|
||||||
|
|
||||||
|
if "hash" in model and not path.exists(model_hash):
|
||||||
|
logger.info("ONNX model does not have hash file, adding one")
|
||||||
|
with open(model_hash, "w") as f:
|
||||||
|
f.write(model["hash"])
|
||||||
|
|
||||||
|
return (False, dest_path)
|
||||||
|
|
||||||
|
# safetensors -> diffusers directory with torch models
|
||||||
|
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||||
|
|
||||||
|
if format == "safetensors":
|
||||||
|
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||||
|
source, use_safetensors=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
|
||||||
|
|
||||||
|
if replace_vae is not None:
|
||||||
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
|
if check_ext(replace_vae, RESOLVE_FORMATS):
|
||||||
|
logger.debug("loading VAE from single tensor file: %s", vae_path)
|
||||||
|
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||||
|
else:
|
||||||
|
logger.debug("loading pretrained VAE from path: %s", vae_path)
|
||||||
|
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
||||||
|
|
||||||
|
if path.exists(temp_path):
|
||||||
|
logger.debug("torch model already exists for %s: %s", source, temp_path)
|
||||||
|
else:
|
||||||
|
logger.debug("exporting torch model for %s: %s", source, temp_path)
|
||||||
|
pipeline.save_pretrained(temp_path)
|
||||||
|
|
||||||
|
# directory -> onnx using optimum exporters
|
||||||
|
main_export(
|
||||||
|
temp_path,
|
||||||
|
output=dest_path,
|
||||||
|
task="stable-diffusion-xl",
|
||||||
|
device=device,
|
||||||
|
fp16=conversion.has_optimization(
|
||||||
|
"torch-fp16"
|
||||||
|
), # optimum's fp16 mode only works on CUDA or ROCm
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
if "hash" in model:
|
||||||
|
logger.debug("adding hash file to ONNX model")
|
||||||
|
with open(model_hash, "w") as f:
|
||||||
|
f.write(model["hash"])
|
||||||
|
|
||||||
|
if conversion.half:
|
||||||
|
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
|
||||||
|
infer_shapes_path(unet_path)
|
||||||
|
unet = onnx.load(unet_path)
|
||||||
|
opt_model = convert_float_to_float16(
|
||||||
|
unet,
|
||||||
|
disable_shape_infer=True,
|
||||||
|
force_fp16_initializers=True,
|
||||||
|
keep_io_types=True,
|
||||||
|
op_block_list=["Attention", "MultiHeadAttention"],
|
||||||
|
)
|
||||||
|
onnx.save_model(
|
||||||
|
opt_model,
|
||||||
|
unet_path,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location="weights.pb",
|
||||||
|
)
|
||||||
|
|
||||||
|
return False, dest_path
|
|
@ -1,22 +1,15 @@
|
||||||
from argparse import ArgumentParser
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from onnx import ModelProto, load, numpy_helper
|
from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
|
||||||
from onnx.checker import check_model
|
from onnx.external_data_helper import set_external_data
|
||||||
from onnx.external_data_helper import (
|
from onnxruntime import OrtValue
|
||||||
convert_model_to_external_data,
|
|
||||||
set_external_data,
|
|
||||||
write_external_data_tensors,
|
|
||||||
)
|
|
||||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
|
||||||
from scipy import interpolate
|
from scipy import interpolate
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext, load_tensor
|
from ..utils import load_tensor
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -39,7 +32,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||||
lr = a
|
lr = a
|
||||||
|
|
||||||
if kernel == (1, 1):
|
if kernel == (1, 1):
|
||||||
lr = np.expand_dims(lr, axis=(2, 3))
|
lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis
|
||||||
|
|
||||||
return hr + lr
|
return hr + lr
|
||||||
|
|
||||||
|
@ -78,13 +71,15 @@ def fix_node_name(key: str):
|
||||||
return fixed_name
|
return fixed_name
|
||||||
|
|
||||||
|
|
||||||
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
|
||||||
fixed = {}
|
fixed = {}
|
||||||
|
names = [fix_node_name(node.name) for node in nodes]
|
||||||
|
|
||||||
for key, value in keys.items():
|
for key, value in keys.items():
|
||||||
root, *rest = key.split(".")
|
root, *_rest = key.split(".")
|
||||||
logger.debug("fixing XL node name: %s -> %s", key, root) # TODO: move to trace
|
logger.trace("fixing XL node name: %s -> %s", key, root)
|
||||||
|
|
||||||
|
simple = False
|
||||||
if root.startswith("input"):
|
if root.startswith("input"):
|
||||||
block = "down_blocks"
|
block = "down_blocks"
|
||||||
elif root.startswith("middle"):
|
elif root.startswith("middle"):
|
||||||
|
@ -93,6 +88,15 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
||||||
block = "up_blocks"
|
block = "up_blocks"
|
||||||
elif root.startswith("text_model"):
|
elif root.startswith("text_model"):
|
||||||
block = "text_model"
|
block = "text_model"
|
||||||
|
elif root.startswith("down_blocks"):
|
||||||
|
block = "down_blocks"
|
||||||
|
simple = True
|
||||||
|
elif root.startswith("mid_block"):
|
||||||
|
block = "mid_block"
|
||||||
|
simple = True
|
||||||
|
elif root.startswith("up_blocks"):
|
||||||
|
block = "up_blocks"
|
||||||
|
simple = True
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown XL key name: %s", key)
|
logger.warning("unknown XL key name: %s", key)
|
||||||
fixed[key] = value
|
fixed[key] = value
|
||||||
|
@ -100,6 +104,10 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
||||||
|
|
||||||
suffix = None
|
suffix = None
|
||||||
for s in [
|
for s in [
|
||||||
|
"conv",
|
||||||
|
"conv_shortcut",
|
||||||
|
"conv1",
|
||||||
|
"conv2",
|
||||||
"fc1",
|
"fc1",
|
||||||
"fc2",
|
"fc2",
|
||||||
"ff_net_0_proj",
|
"ff_net_0_proj",
|
||||||
|
@ -119,18 +127,21 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
||||||
logger.warning("new XL key type: %s", root)
|
logger.warning("new XL key type: %s", root)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug("searching for XL node: /%s/*/%s", block, suffix)
|
logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
|
||||||
match = None
|
match: Optional[str] = None
|
||||||
if block == "text_model":
|
if "conv" in suffix:
|
||||||
match = next(
|
match = next(node for node in names if node == f"{root}_Conv")
|
||||||
node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
|
elif "time_emb_proj" in root:
|
||||||
)
|
match = next(node for node in names if node == f"{root}_Gemm")
|
||||||
|
elif block == "text_model" or simple:
|
||||||
|
match = next(node for node in names if node == f"{root}_MatMul")
|
||||||
else:
|
else:
|
||||||
|
# search in order. one side has sparse indices, so they will not match.
|
||||||
match = next(
|
match = next(
|
||||||
node
|
node
|
||||||
for node in nodes
|
for node in names
|
||||||
if node.name.startswith(f"/{block}")
|
if node.startswith(block)
|
||||||
and fix_node_name(node.name).endswith(
|
and node.endswith(
|
||||||
f"{suffix}_MatMul"
|
f"{suffix}_MatMul"
|
||||||
) # needs to be fixed because some places use to_out.0
|
) # needs to be fixed because some places use to_out.0
|
||||||
)
|
)
|
||||||
|
@ -138,18 +149,28 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
|
||||||
if match is None:
|
if match is None:
|
||||||
logger.warning("no matches for XL key: %s", root)
|
logger.warning("no matches for XL key: %s", root)
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
|
logger.trace("matched key: %s -> %s", key, match)
|
||||||
|
|
||||||
name: str = match.name
|
name = match
|
||||||
name = fix_node_name(name.rstrip("/MatMul"))
|
if name.endswith("_MatMul"):
|
||||||
|
name = name[:-7]
|
||||||
|
elif name.endswith("_Gemm"):
|
||||||
|
name = name[:-5]
|
||||||
|
elif name.endswith("_Conv"):
|
||||||
|
name = name[:-5]
|
||||||
|
|
||||||
if name.endswith("proj_o"):
|
logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)
|
||||||
# wtf
|
|
||||||
name = f"{name}ut"
|
|
||||||
|
|
||||||
logger.debug("matching XL key with node: %s -> %s", key, match.name)
|
|
||||||
|
|
||||||
fixed[name] = value
|
fixed[name] = value
|
||||||
nodes.remove(match)
|
names.remove(match)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
|
||||||
|
len(fixed.keys()),
|
||||||
|
len(keys.keys()),
|
||||||
|
len(names),
|
||||||
|
)
|
||||||
|
|
||||||
return fixed
|
return fixed
|
||||||
|
|
||||||
|
@ -161,6 +182,245 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def blend_weights_loha(
|
||||||
|
key: str, lora_prefix: str, lora_model: Dict, dtype
|
||||||
|
) -> Tuple[str, np.ndarray]:
|
||||||
|
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
|
||||||
|
|
||||||
|
t1_key = key.replace("hada_w1_a", "hada_t1")
|
||||||
|
t2_key = key.replace("hada_w1_a", "hada_t2")
|
||||||
|
w1b_key = key.replace("hada_w1_a", "hada_w1_b")
|
||||||
|
w2a_key = key.replace("hada_w1_a", "hada_w2_a")
|
||||||
|
w2b_key = key.replace("hada_w1_a", "hada_w2_b")
|
||||||
|
alpha_key = key[: key.index("hada_w1_a")] + "alpha"
|
||||||
|
logger.trace(
|
||||||
|
"blending weights for LoHA keys: %s, %s, %s, %s, %s",
|
||||||
|
key,
|
||||||
|
w1b_key,
|
||||||
|
w2a_key,
|
||||||
|
w2b_key,
|
||||||
|
alpha_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
w1a_weight = lora_model[key].to(dtype=dtype)
|
||||||
|
w1b_weight = lora_model[w1b_key].to(dtype=dtype)
|
||||||
|
w2a_weight = lora_model[w2a_key].to(dtype=dtype)
|
||||||
|
w2b_weight = lora_model[w2b_key].to(dtype=dtype)
|
||||||
|
|
||||||
|
t1_weight = lora_model.get(t1_key, None)
|
||||||
|
t2_weight = lora_model.get(t2_key, None)
|
||||||
|
|
||||||
|
dim = w1b_weight.size()[0]
|
||||||
|
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
|
||||||
|
|
||||||
|
if t1_weight is not None and t2_weight is not None:
|
||||||
|
t1_weight = t1_weight.to(dtype=dtype)
|
||||||
|
t2_weight = t2_weight.to(dtype=dtype)
|
||||||
|
|
||||||
|
logger.trace(
|
||||||
|
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
|
||||||
|
t1_weight.shape,
|
||||||
|
w1a_weight.shape,
|
||||||
|
w1b_weight.shape,
|
||||||
|
t2_weight.shape,
|
||||||
|
w2a_weight.shape,
|
||||||
|
w2b_weight.shape,
|
||||||
|
)
|
||||||
|
weights_1 = torch.einsum(
|
||||||
|
"i j k l, j r, i p -> p r k l",
|
||||||
|
t1_weight,
|
||||||
|
w1b_weight,
|
||||||
|
w1a_weight,
|
||||||
|
)
|
||||||
|
weights_2 = torch.einsum(
|
||||||
|
"i j k l, j r, i p -> p r k l",
|
||||||
|
t2_weight,
|
||||||
|
w2b_weight,
|
||||||
|
w2a_weight,
|
||||||
|
)
|
||||||
|
weights = weights_1 * weights_2
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
else:
|
||||||
|
logger.trace(
|
||||||
|
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
|
||||||
|
w1a_weight.shape,
|
||||||
|
w1b_weight.shape,
|
||||||
|
w2a_weight.shape,
|
||||||
|
w2b_weight.shape,
|
||||||
|
)
|
||||||
|
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
|
||||||
|
return base_key, np_weights
|
||||||
|
|
||||||
|
|
||||||
|
def blend_weights_lora(
|
||||||
|
key: str, lora_prefix: str, lora_model: Dict, dtype
|
||||||
|
) -> Tuple[str, np.ndarray]:
|
||||||
|
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
||||||
|
|
||||||
|
mid_key = key.replace("lora_down", "lora_mid")
|
||||||
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
|
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||||
|
logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
|
||||||
|
|
||||||
|
down_weight = lora_model[key].to(dtype=dtype)
|
||||||
|
up_weight = lora_model[up_key].to(dtype=dtype)
|
||||||
|
|
||||||
|
mid_weight = None
|
||||||
|
if mid_key in lora_model:
|
||||||
|
mid_weight = lora_model[mid_key].to(dtype=dtype)
|
||||||
|
|
||||||
|
dim = down_weight.size()[0]
|
||||||
|
alpha = lora_model.get(alpha_key, dim)
|
||||||
|
|
||||||
|
if not isinstance(alpha, int):
|
||||||
|
alpha = alpha.to(dtype).numpy()
|
||||||
|
|
||||||
|
kernel = down_weight.shape[-2:]
|
||||||
|
if mid_weight is not None:
|
||||||
|
kernel = mid_weight.shape[-2:]
|
||||||
|
|
||||||
|
if len(down_weight.size()) == 2:
|
||||||
|
# blend for nn.Linear
|
||||||
|
logger.trace(
|
||||||
|
"blending weights for Linear node: (%s @ %s) * %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
weights = up_weight @ down_weight
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
elif len(down_weight.size()) == 4 and kernel == (
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
):
|
||||||
|
# blend for nn.Conv2d 1x1
|
||||||
|
logger.trace(
|
||||||
|
"blending weights for Conv 1x1 node: %s, %s, %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
weights = (
|
||||||
|
(up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
|
||||||
|
.unsqueeze(2)
|
||||||
|
.unsqueeze(3)
|
||||||
|
)
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
elif len(down_weight.size()) == 4 and kernel == (
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
):
|
||||||
|
if mid_weight is not None:
|
||||||
|
# blend for nn.Conv2d 3x3 with CP decomp
|
||||||
|
logger.trace(
|
||||||
|
"composing weights for Conv 3x3 node: %s, %s, %s, %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
mid_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
|
||||||
|
|
||||||
|
for w in range(kernel[0]):
|
||||||
|
for h in range(kernel[1]):
|
||||||
|
weights[:, :, w, h] = (
|
||||||
|
up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
|
||||||
|
) @ down_weight.squeeze(3).squeeze(2)
|
||||||
|
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
else:
|
||||||
|
# blend for nn.Conv2d 3x3
|
||||||
|
logger.trace(
|
||||||
|
"blending weights for Conv 3x3 node: %s, %s, %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
|
||||||
|
|
||||||
|
for w in range(kernel[0]):
|
||||||
|
for h in range(kernel[1]):
|
||||||
|
down_w, down_h = kernel_slice(w, h, down_weight.shape)
|
||||||
|
up_w, up_h = kernel_slice(w, h, up_weight.shape)
|
||||||
|
|
||||||
|
weights[:, :, w, h] = (
|
||||||
|
up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
|
||||||
|
)
|
||||||
|
|
||||||
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"unknown LoRA node type at %s: %s",
|
||||||
|
base_key,
|
||||||
|
up_weight.shape[-2:],
|
||||||
|
)
|
||||||
|
# TODO: should this be None?
|
||||||
|
np_weights = np.zeros((1, 1, 1, 1))
|
||||||
|
|
||||||
|
return base_key, np_weights
|
||||||
|
|
||||||
|
|
||||||
|
def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
|
||||||
|
# blending
|
||||||
|
onnx_weights = numpy_helper.to_array(weight_node)
|
||||||
|
logger.trace(
|
||||||
|
"found blended weights for conv: %s, %s",
|
||||||
|
onnx_weights.shape,
|
||||||
|
weights.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
if onnx_weights.shape[-2:] == (1, 1):
|
||||||
|
if weights.shape[-2:] == (1, 1):
|
||||||
|
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||||
|
else:
|
||||||
|
blended = onnx_weights.squeeze((3, 2)) + weights
|
||||||
|
|
||||||
|
blended = np.expand_dims(blended, (2, 3))
|
||||||
|
else:
|
||||||
|
if onnx_weights.shape != weights.shape:
|
||||||
|
logger.warning(
|
||||||
|
"reshaping weights for mismatched Conv node: %s, %s",
|
||||||
|
onnx_weights.shape,
|
||||||
|
weights.shape,
|
||||||
|
)
|
||||||
|
# TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
|
||||||
|
blended = onnx_weights + weights.reshape(onnx_weights.shape)
|
||||||
|
else:
|
||||||
|
blended = onnx_weights + weights
|
||||||
|
|
||||||
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
|
# replace the original initializer
|
||||||
|
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
|
||||||
|
|
||||||
|
|
||||||
|
def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
|
||||||
|
onnx_weights = numpy_helper.to_array(matmul_node)
|
||||||
|
logger.trace(
|
||||||
|
"found blended weights for matmul: %s, %s",
|
||||||
|
weights.shape,
|
||||||
|
onnx_weights.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_weights = weights.transpose()
|
||||||
|
if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
|
||||||
|
logger.warning(
|
||||||
|
"weight shapes do not match for %s: %s vs %s",
|
||||||
|
matmul_key,
|
||||||
|
weights.shape,
|
||||||
|
onnx_weights.shape,
|
||||||
|
)
|
||||||
|
t_weights = interp_to_match(weights, onnx_weights).transpose()
|
||||||
|
|
||||||
|
blended = onnx_weights + t_weights
|
||||||
|
logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
|
||||||
|
|
||||||
|
# replace the original initializer
|
||||||
|
return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
|
||||||
|
|
||||||
|
|
||||||
def blend_loras(
|
def blend_loras(
|
||||||
_conversion: ServerContext,
|
_conversion: ServerContext,
|
||||||
base_name: Union[str, ModelProto],
|
base_name: Union[str, ModelProto],
|
||||||
|
@ -184,246 +444,77 @@ def blend_loras(
|
||||||
else:
|
else:
|
||||||
lora_prefix = f"lora_{model_type}_"
|
lora_prefix = f"lora_{model_type}_"
|
||||||
|
|
||||||
blended: Dict[str, np.ndarray] = {}
|
layers = []
|
||||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||||
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||||
if lora_model is None:
|
if lora_model is None:
|
||||||
logger.warning("unable to load tensor for LoRA")
|
logger.warning("unable to load tensor for LoRA")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
blended: Dict[str, np.ndarray] = {}
|
||||||
|
layers.append(blended)
|
||||||
|
|
||||||
for key in lora_model.keys():
|
for key in lora_model.keys():
|
||||||
if ".hada_w1_a" in key and lora_prefix in key:
|
if ".hada_w1_a" in key and lora_prefix in key:
|
||||||
# LoHA
|
# LoHA
|
||||||
base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
|
base_key, np_weights = blend_weights_loha(
|
||||||
|
key, lora_prefix, lora_model, dtype
|
||||||
t1_key = key.replace("hada_w1_a", "hada_t1")
|
|
||||||
t2_key = key.replace("hada_w1_a", "hada_t2")
|
|
||||||
w1b_key = key.replace("hada_w1_a", "hada_w1_b")
|
|
||||||
w2a_key = key.replace("hada_w1_a", "hada_w2_a")
|
|
||||||
w2b_key = key.replace("hada_w1_a", "hada_w2_b")
|
|
||||||
alpha_key = key[: key.index("hada_w1_a")] + "alpha"
|
|
||||||
logger.trace(
|
|
||||||
"blending weights for LoHA keys: %s, %s, %s, %s, %s",
|
|
||||||
key,
|
|
||||||
w1b_key,
|
|
||||||
w2a_key,
|
|
||||||
w2b_key,
|
|
||||||
alpha_key,
|
|
||||||
)
|
)
|
||||||
|
np_weights = np_weights * lora_weight
|
||||||
w1a_weight = lora_model[key].to(dtype=dtype)
|
logger.trace(
|
||||||
w1b_weight = lora_model[w1b_key].to(dtype=dtype)
|
"adding LoHA weights: %s",
|
||||||
w2a_weight = lora_model[w2a_key].to(dtype=dtype)
|
np_weights.shape,
|
||||||
w2b_weight = lora_model[w2b_key].to(dtype=dtype)
|
)
|
||||||
|
blended[base_key] = np_weights
|
||||||
t1_weight = lora_model.get(t1_key, None)
|
|
||||||
t2_weight = lora_model.get(t2_key, None)
|
|
||||||
|
|
||||||
dim = w1b_weight.size()[0]
|
|
||||||
alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()
|
|
||||||
|
|
||||||
if t1_weight is not None and t2_weight is not None:
|
|
||||||
t1_weight = t1_weight.to(dtype=dtype)
|
|
||||||
t2_weight = t2_weight.to(dtype=dtype)
|
|
||||||
|
|
||||||
logger.trace(
|
|
||||||
"composing weights for LoHA node: (%s, %s, %s) * (%s, %s, %s)",
|
|
||||||
t1_weight.shape,
|
|
||||||
w1a_weight.shape,
|
|
||||||
w1b_weight.shape,
|
|
||||||
t2_weight.shape,
|
|
||||||
w2a_weight.shape,
|
|
||||||
w2b_weight.shape,
|
|
||||||
)
|
|
||||||
weights_1 = torch.einsum(
|
|
||||||
"i j k l, j r, i p -> p r k l",
|
|
||||||
t1_weight,
|
|
||||||
w1b_weight,
|
|
||||||
w1a_weight,
|
|
||||||
)
|
|
||||||
weights_2 = torch.einsum(
|
|
||||||
"i j k l, j r, i p -> p r k l",
|
|
||||||
t2_weight,
|
|
||||||
w2b_weight,
|
|
||||||
w2a_weight,
|
|
||||||
)
|
|
||||||
weights = weights_1 * weights_2
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
|
||||||
else:
|
|
||||||
logger.trace(
|
|
||||||
"blending weights for LoHA node: (%s @ %s) * (%s @ %s)",
|
|
||||||
w1a_weight.shape,
|
|
||||||
w1b_weight.shape,
|
|
||||||
w2a_weight.shape,
|
|
||||||
w2b_weight.shape,
|
|
||||||
)
|
|
||||||
weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
|
||||||
|
|
||||||
np_weights *= lora_weight
|
|
||||||
if base_key in blended:
|
|
||||||
logger.trace(
|
|
||||||
"summing LoHA weights: %s + %s",
|
|
||||||
blended[base_key].shape,
|
|
||||||
np_weights.shape,
|
|
||||||
)
|
|
||||||
blended[base_key] += sum_weights(blended[base_key], np_weights)
|
|
||||||
else:
|
|
||||||
blended[base_key] = np_weights
|
|
||||||
elif ".lora_down" in key and lora_prefix in key:
|
elif ".lora_down" in key and lora_prefix in key:
|
||||||
# LoRA or LoCON
|
# LoRA or LoCON
|
||||||
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
base_key, np_weights = blend_weights_lora(
|
||||||
|
key, lora_prefix, lora_model, dtype
|
||||||
mid_key = key.replace("lora_down", "lora_mid")
|
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
|
||||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
|
||||||
logger.trace(
|
|
||||||
"blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
|
|
||||||
)
|
)
|
||||||
|
np_weights = np_weights * lora_weight
|
||||||
|
logger.trace(
|
||||||
|
"adding LoRA weights: %s",
|
||||||
|
np_weights.shape,
|
||||||
|
)
|
||||||
|
blended[base_key] = np_weights
|
||||||
|
|
||||||
down_weight = lora_model[key].to(dtype=dtype)
|
# rewrite node names for XL and flatten layers
|
||||||
up_weight = lora_model[up_key].to(dtype=dtype)
|
weights: Dict[str, np.ndarray] = {}
|
||||||
|
|
||||||
mid_weight = None
|
for blended in layers:
|
||||||
if mid_key in lora_model:
|
if xl:
|
||||||
mid_weight = lora_model[mid_key].to(dtype=dtype)
|
nodes = list(base_model.graph.node)
|
||||||
|
blended = fix_xl_names(blended, nodes)
|
||||||
|
|
||||||
dim = down_weight.size()[0]
|
for key, value in blended.items():
|
||||||
alpha = lora_model.get(alpha_key, dim)
|
if key in weights:
|
||||||
|
weights[key] = sum_weights(weights[key], value)
|
||||||
if not isinstance(alpha, int):
|
else:
|
||||||
alpha = alpha.to(dtype).numpy()
|
weights[key] = value
|
||||||
|
|
||||||
kernel = down_weight.shape[-2:]
|
|
||||||
if mid_weight is not None:
|
|
||||||
kernel = mid_weight.shape[-2:]
|
|
||||||
|
|
||||||
if len(down_weight.size()) == 2:
|
|
||||||
# blend for nn.Linear
|
|
||||||
logger.trace(
|
|
||||||
"blending weights for Linear node: (%s @ %s) * %s",
|
|
||||||
down_weight.shape,
|
|
||||||
up_weight.shape,
|
|
||||||
alpha,
|
|
||||||
)
|
|
||||||
weights = up_weight @ down_weight
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
|
||||||
elif len(down_weight.size()) == 4 and kernel == (
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
):
|
|
||||||
# blend for nn.Conv2d 1x1
|
|
||||||
logger.trace(
|
|
||||||
"blending weights for Conv 1x1 node: %s, %s, %s",
|
|
||||||
down_weight.shape,
|
|
||||||
up_weight.shape,
|
|
||||||
alpha,
|
|
||||||
)
|
|
||||||
weights = (
|
|
||||||
(
|
|
||||||
up_weight.squeeze(3).squeeze(2)
|
|
||||||
@ down_weight.squeeze(3).squeeze(2)
|
|
||||||
)
|
|
||||||
.unsqueeze(2)
|
|
||||||
.unsqueeze(3)
|
|
||||||
)
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
|
||||||
elif len(down_weight.size()) == 4 and kernel == (
|
|
||||||
3,
|
|
||||||
3,
|
|
||||||
):
|
|
||||||
if mid_weight is not None:
|
|
||||||
# blend for nn.Conv2d 3x3 with CP decomp
|
|
||||||
logger.trace(
|
|
||||||
"composing weights for Conv 3x3 node: %s, %s, %s, %s",
|
|
||||||
down_weight.shape,
|
|
||||||
up_weight.shape,
|
|
||||||
mid_weight.shape,
|
|
||||||
alpha,
|
|
||||||
)
|
|
||||||
weights = torch.zeros(
|
|
||||||
(up_weight.shape[0], down_weight.shape[1], *kernel)
|
|
||||||
)
|
|
||||||
|
|
||||||
for w in range(kernel[0]):
|
|
||||||
for h in range(kernel[1]):
|
|
||||||
weights[:, :, w, h] = (
|
|
||||||
up_weight.squeeze(3).squeeze(2)
|
|
||||||
@ mid_weight[:, :, w, h]
|
|
||||||
) @ down_weight.squeeze(3).squeeze(2)
|
|
||||||
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
|
||||||
else:
|
|
||||||
# blend for nn.Conv2d 3x3
|
|
||||||
logger.trace(
|
|
||||||
"blending weights for Conv 3x3 node: %s, %s, %s",
|
|
||||||
down_weight.shape,
|
|
||||||
up_weight.shape,
|
|
||||||
alpha,
|
|
||||||
)
|
|
||||||
weights = torch.zeros(
|
|
||||||
(up_weight.shape[0], down_weight.shape[1], *kernel)
|
|
||||||
)
|
|
||||||
|
|
||||||
for w in range(kernel[0]):
|
|
||||||
for h in range(kernel[1]):
|
|
||||||
down_w, down_h = kernel_slice(w, h, down_weight.shape)
|
|
||||||
up_w, up_h = kernel_slice(w, h, up_weight.shape)
|
|
||||||
|
|
||||||
weights[:, :, w, h] = (
|
|
||||||
up_weight[:, :, up_w, up_h]
|
|
||||||
@ down_weight[:, :, down_w, down_h]
|
|
||||||
)
|
|
||||||
|
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"unknown LoRA node type at %s: %s",
|
|
||||||
base_key,
|
|
||||||
up_weight.shape[-2:],
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
np_weights *= lora_weight
|
|
||||||
if base_key in blended:
|
|
||||||
logger.trace(
|
|
||||||
"summing weights: %s + %s",
|
|
||||||
blended[base_key].shape,
|
|
||||||
np_weights.shape,
|
|
||||||
)
|
|
||||||
blended[base_key] = sum_weights(blended[base_key], np_weights)
|
|
||||||
else:
|
|
||||||
blended[base_key] = np_weights
|
|
||||||
|
|
||||||
# rewrite node names for XL
|
|
||||||
if xl:
|
|
||||||
nodes = list(base_model.graph.node)
|
|
||||||
blended = fix_xl_names(blended, nodes)
|
|
||||||
|
|
||||||
logger.trace(
|
|
||||||
"updating %s of %s initializers",
|
|
||||||
len(blended.keys()),
|
|
||||||
len(base_model.graph.initializer),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# fix node names once
|
||||||
fixed_initializer_names = [
|
fixed_initializer_names = [
|
||||||
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
fix_initializer_name(node.name) for node in base_model.graph.initializer
|
||||||
]
|
]
|
||||||
logger.trace("fixed initializer names: %s", fixed_initializer_names)
|
|
||||||
|
|
||||||
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||||
logger.trace("fixed node names: %s", fixed_node_names)
|
|
||||||
|
logger.debug(
|
||||||
|
"updating %s of %s initializers",
|
||||||
|
len(weights.keys()),
|
||||||
|
len(base_model.graph.initializer),
|
||||||
|
)
|
||||||
|
|
||||||
unmatched_keys = []
|
unmatched_keys = []
|
||||||
for base_key, weights in blended.items():
|
for base_key, weights in weights.items():
|
||||||
conv_key = base_key + "_Conv"
|
conv_key = base_key + "_Conv"
|
||||||
gemm_key = base_key + "_Gemm"
|
gemm_key = base_key + "_Gemm"
|
||||||
matmul_key = base_key + "_MatMul"
|
matmul_key = base_key + "_MatMul"
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"key %s has conv: %s, matmul: %s",
|
"key %s has conv: %s, gemm: %s, matmul: %s",
|
||||||
base_key,
|
base_key,
|
||||||
conv_key in fixed_node_names,
|
conv_key in fixed_node_names,
|
||||||
|
gemm_key in fixed_node_names,
|
||||||
matmul_key in fixed_node_names,
|
matmul_key in fixed_node_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -449,38 +540,9 @@ def blend_loras(
|
||||||
weight_node = base_model.graph.initializer[weight_idx]
|
weight_node = base_model.graph.initializer[weight_idx]
|
||||||
logger.trace("found weight initializer: %s", weight_node.name)
|
logger.trace("found weight initializer: %s", weight_node.name)
|
||||||
|
|
||||||
# blending
|
# replace the previous node
|
||||||
onnx_weights = numpy_helper.to_array(weight_node)
|
updated_node = blend_node_conv_gemm(weight_node, weights)
|
||||||
logger.trace(
|
|
||||||
"found blended weights for conv: %s, %s",
|
|
||||||
onnx_weights.shape,
|
|
||||||
weights.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
if onnx_weights.shape[-2:] == (1, 1):
|
|
||||||
if weights.shape[-2:] == (1, 1):
|
|
||||||
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
|
||||||
else:
|
|
||||||
blended = onnx_weights.squeeze((3, 2)) + weights
|
|
||||||
|
|
||||||
blended = np.expand_dims(blended, (2, 3))
|
|
||||||
else:
|
|
||||||
if onnx_weights.shape != weights.shape:
|
|
||||||
logger.warning(
|
|
||||||
"reshaping weights for mismatched Conv node: %s, %s",
|
|
||||||
onnx_weights.shape,
|
|
||||||
weights.shape,
|
|
||||||
)
|
|
||||||
blended = onnx_weights + weights.reshape(onnx_weights.shape)
|
|
||||||
else:
|
|
||||||
blended = onnx_weights + weights
|
|
||||||
|
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
|
||||||
|
|
||||||
# replace the original initializer
|
|
||||||
updated_node = numpy_helper.from_array(
|
|
||||||
blended.astype(onnx_weights.dtype), weight_node.name
|
|
||||||
)
|
|
||||||
del base_model.graph.initializer[weight_idx]
|
del base_model.graph.initializer[weight_idx]
|
||||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||||
elif matmul_key in fixed_node_names:
|
elif matmul_key in fixed_node_names:
|
||||||
|
@ -497,42 +559,15 @@ def blend_loras(
|
||||||
matmul_node = base_model.graph.initializer[matmul_idx]
|
matmul_node = base_model.graph.initializer[matmul_idx]
|
||||||
logger.trace("found matmul initializer: %s", matmul_node.name)
|
logger.trace("found matmul initializer: %s", matmul_node.name)
|
||||||
|
|
||||||
# blending
|
# replace the previous node
|
||||||
onnx_weights = numpy_helper.to_array(matmul_node)
|
updated_node = blend_node_matmul(matmul_node, weights, matmul_key)
|
||||||
logger.trace(
|
|
||||||
"found blended weights for matmul: %s, %s",
|
|
||||||
weights.shape,
|
|
||||||
onnx_weights.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
t_weights = weights.transpose()
|
|
||||||
if (
|
|
||||||
weights.shape != onnx_weights.shape
|
|
||||||
and t_weights.shape != onnx_weights.shape
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"weight shapes do not match for %s: %s vs %s",
|
|
||||||
matmul_key,
|
|
||||||
weights.shape,
|
|
||||||
onnx_weights.shape,
|
|
||||||
)
|
|
||||||
t_weights = interp_to_match(weights, onnx_weights).transpose()
|
|
||||||
|
|
||||||
blended = onnx_weights + t_weights
|
|
||||||
logger.debug(
|
|
||||||
"blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# replace the original initializer
|
|
||||||
updated_node = numpy_helper.from_array(
|
|
||||||
blended.astype(onnx_weights.dtype), matmul_node.name
|
|
||||||
)
|
|
||||||
del base_model.graph.initializer[matmul_idx]
|
del base_model.graph.initializer[matmul_idx]
|
||||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||||
else:
|
else:
|
||||||
unmatched_keys.append(base_key)
|
unmatched_keys.append(base_key)
|
||||||
|
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"node counts: %s -> %s, %s -> %s",
|
"node counts: %s -> %s, %s -> %s",
|
||||||
len(fixed_initializer_names),
|
len(fixed_initializer_names),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
|
@ -541,10 +576,7 @@ def blend_loras(
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(unmatched_keys) > 0:
|
if len(unmatched_keys) > 0:
|
||||||
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
|
logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)
|
||||||
|
|
||||||
# if model_type == "unet":
|
|
||||||
# save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
|
|
||||||
|
|
||||||
return base_model
|
return base_model
|
||||||
|
|
||||||
|
@ -568,63 +600,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
|
||||||
logger.debug("weights after interpolation: %s", output.shape)
|
logger.debug("weights after interpolation: %s", output.shape)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
context = ConversionContext.from_environ()
|
|
||||||
parser = ArgumentParser()
|
|
||||||
parser.add_argument("--base", type=str)
|
|
||||||
parser.add_argument("--dest", type=str)
|
|
||||||
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
|
||||||
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
|
|
||||||
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
logger.info(
|
|
||||||
"merging %s with %s with weights: %s",
|
|
||||||
args.lora_models,
|
|
||||||
args.base,
|
|
||||||
args.lora_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
default_weight = 1.0 / len(args.lora_models)
|
|
||||||
while len(args.lora_weights) < len(args.lora_models):
|
|
||||||
args.lora_weights.append(default_weight)
|
|
||||||
|
|
||||||
blend_model = blend_loras(
|
|
||||||
context,
|
|
||||||
args.base,
|
|
||||||
list(zip(args.lora_models, args.lora_weights)),
|
|
||||||
args.type,
|
|
||||||
)
|
|
||||||
if args.dest is None or args.dest == "" or args.dest == ":load":
|
|
||||||
# convert to external data and save to memory
|
|
||||||
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
|
||||||
logger.info("saved external data for %s nodes", len(external_data))
|
|
||||||
|
|
||||||
external_names, external_values = zip(*external_data)
|
|
||||||
opts = SessionOptions()
|
|
||||||
opts.add_external_initializers(list(external_names), list(external_values))
|
|
||||||
sess = InferenceSession(
|
|
||||||
bare_model.SerializeToString(),
|
|
||||||
sess_options=opts,
|
|
||||||
providers=["CPUExecutionProvider"],
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
convert_model_to_external_data(
|
|
||||||
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
|
|
||||||
)
|
|
||||||
bare_model = write_external_data_tensors(blend_model, args.dest)
|
|
||||||
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
|
||||||
|
|
||||||
with open(dest_file, "w+b") as model_file:
|
|
||||||
model_file.write(bare_model.SerializeToString())
|
|
||||||
|
|
||||||
logger.info("successfully saved blended model: %s", dest_file)
|
|
||||||
|
|
||||||
check_model(dest_file)
|
|
||||||
|
|
||||||
logger.info("checked blended model")
|
|
||||||
|
|
|
@ -14,19 +14,155 @@ from ..utils import ConversionContext, load_tensor
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_embedding_format(loaded_embeds) -> str:
|
||||||
|
keys: List[str] = list(loaded_embeds.keys())
|
||||||
|
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
||||||
|
logger.debug("detected Textual Inversion concept: %s", keys)
|
||||||
|
return "concept"
|
||||||
|
elif "emb_params" in keys:
|
||||||
|
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
|
||||||
|
return "parameters"
|
||||||
|
elif "string_to_token" in keys and "string_to_param" in keys:
|
||||||
|
logger.debug("detected Textual Inversion token embeddings: %s", keys)
|
||||||
|
return "embeddings"
|
||||||
|
else:
|
||||||
|
logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
|
||||||
|
# separate token and the embeds
|
||||||
|
token = list(loaded_embeds.keys())[0]
|
||||||
|
|
||||||
|
layer = loaded_embeds[token].numpy().astype(dtype)
|
||||||
|
layer *= weight
|
||||||
|
|
||||||
|
if base_token in embeds:
|
||||||
|
embeds[base_token] += layer
|
||||||
|
else:
|
||||||
|
embeds[base_token] = layer
|
||||||
|
|
||||||
|
if token in embeds:
|
||||||
|
embeds[token] += layer
|
||||||
|
else:
|
||||||
|
embeds[token] = layer
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
|
||||||
|
emb_params = loaded_embeds["emb_params"]
|
||||||
|
|
||||||
|
num_tokens = emb_params.shape[0]
|
||||||
|
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
|
||||||
|
|
||||||
|
sum_layer = np.zeros(emb_params[0, :].shape)
|
||||||
|
|
||||||
|
for i in range(num_tokens):
|
||||||
|
token = f"{base_token}-{i}"
|
||||||
|
layer = emb_params[i, :].numpy().astype(dtype)
|
||||||
|
layer *= weight
|
||||||
|
|
||||||
|
sum_layer += layer
|
||||||
|
if token in embeds:
|
||||||
|
embeds[token] += layer
|
||||||
|
else:
|
||||||
|
embeds[token] = layer
|
||||||
|
|
||||||
|
# add base and sum tokens to embeds
|
||||||
|
if base_token in embeds:
|
||||||
|
embeds[base_token] += sum_layer
|
||||||
|
else:
|
||||||
|
embeds[base_token] = sum_layer
|
||||||
|
|
||||||
|
sum_token = f"{base_token}-all"
|
||||||
|
if sum_token in embeds:
|
||||||
|
embeds[sum_token] += sum_layer
|
||||||
|
else:
|
||||||
|
embeds[sum_token] = sum_layer
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
|
||||||
|
string_to_token = loaded_embeds["string_to_token"]
|
||||||
|
string_to_param = loaded_embeds["string_to_param"]
|
||||||
|
|
||||||
|
# separate token and embeds
|
||||||
|
token = list(string_to_token.keys())[0]
|
||||||
|
trained_embeds = string_to_param[token]
|
||||||
|
|
||||||
|
num_tokens = trained_embeds.shape[0]
|
||||||
|
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
|
||||||
|
|
||||||
|
sum_layer = np.zeros(trained_embeds[0, :].shape)
|
||||||
|
|
||||||
|
for i in range(num_tokens):
|
||||||
|
token = f"{base_token}-{i}"
|
||||||
|
layer = trained_embeds[i, :].numpy().astype(dtype)
|
||||||
|
layer *= weight
|
||||||
|
|
||||||
|
sum_layer += layer
|
||||||
|
if token in embeds:
|
||||||
|
embeds[token] += layer
|
||||||
|
else:
|
||||||
|
embeds[token] = layer
|
||||||
|
|
||||||
|
# add base and sum tokens to embeds
|
||||||
|
if base_token in embeds:
|
||||||
|
embeds[base_token] += sum_layer
|
||||||
|
else:
|
||||||
|
embeds[base_token] = sum_layer
|
||||||
|
|
||||||
|
sum_token = f"{base_token}-all"
|
||||||
|
if sum_token in embeds:
|
||||||
|
embeds[sum_token] += sum_layer
|
||||||
|
else:
|
||||||
|
embeds[sum_token] = sum_layer
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
|
||||||
|
# resize the token embeddings
|
||||||
|
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
embedding_node = [
|
||||||
|
n
|
||||||
|
for n in text_encoder.graph.initializer
|
||||||
|
if n.name == "text_model.embeddings.token_embedding.weight"
|
||||||
|
][0]
|
||||||
|
base_weights = numpy_helper.to_array(embedding_node)
|
||||||
|
|
||||||
|
weights_dim = base_weights.shape[1]
|
||||||
|
zero_weights = np.zeros((num_added_tokens, weights_dim))
|
||||||
|
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
|
||||||
|
|
||||||
|
for token, weights in embeds.items():
|
||||||
|
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||||
|
logger.trace("embedding %s weights for token %s", weights.shape, token)
|
||||||
|
embedding_weights[token_id] = weights
|
||||||
|
|
||||||
|
# replace embedding_node
|
||||||
|
for i in range(len(text_encoder.graph.initializer)):
|
||||||
|
if (
|
||||||
|
text_encoder.graph.initializer[i].name
|
||||||
|
== "text_model.embeddings.token_embedding.weight"
|
||||||
|
):
|
||||||
|
new_initializer = numpy_helper.from_array(
|
||||||
|
embedding_weights.astype(base_weights.dtype), embedding_node.name
|
||||||
|
)
|
||||||
|
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
||||||
|
del text_encoder.graph.initializer[i]
|
||||||
|
text_encoder.graph.initializer.insert(i, new_initializer)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def blend_textual_inversions(
|
def blend_textual_inversions(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
text_encoder: ModelProto,
|
text_encoder: ModelProto,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||||
# always load to CPU for blending
|
# always load to CPU for blending
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = np.float32
|
dtype = np.float32
|
||||||
embeds = {}
|
embeds = {}
|
||||||
|
|
||||||
for name, weight, base_token, inversion_format in inversions:
|
for name, weight, base_token, format in embeddings:
|
||||||
if base_token is None:
|
if base_token is None:
|
||||||
logger.debug("no base token provided, using name: %s", name)
|
logger.debug("no base token provided, using name: %s", name)
|
||||||
base_token = name
|
base_token = name
|
||||||
|
@ -43,153 +179,28 @@ def blend_textual_inversions(
|
||||||
logger.warning("unable to load tensor")
|
logger.warning("unable to load tensor")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if inversion_format is None:
|
if format is None:
|
||||||
keys: List[str] = list(loaded_embeds.keys())
|
format = detect_embedding_format(loaded_embeds)
|
||||||
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
|
||||||
logger.debug("detected Textual Inversion concept: %s", keys)
|
|
||||||
inversion_format = "concept"
|
|
||||||
elif "emb_params" in keys:
|
|
||||||
logger.debug(
|
|
||||||
"detected Textual Inversion parameter embeddings: %s", keys
|
|
||||||
)
|
|
||||||
inversion_format = "parameters"
|
|
||||||
elif "string_to_token" in keys and "string_to_param" in keys:
|
|
||||||
logger.debug("detected Textual Inversion token embeddings: %s", keys)
|
|
||||||
inversion_format = "embeddings"
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
"unknown Textual Inversion format, no recognized keys: %s", keys
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if inversion_format == "concept":
|
if format == "concept":
|
||||||
# separate token and the embeds
|
blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
|
||||||
token = list(loaded_embeds.keys())[0]
|
elif format == "parameters":
|
||||||
|
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
|
||||||
layer = loaded_embeds[token].numpy().astype(dtype)
|
elif format == "embeddings":
|
||||||
layer *= weight
|
blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
|
||||||
|
|
||||||
if base_token in embeds:
|
|
||||||
embeds[base_token] += layer
|
|
||||||
else:
|
|
||||||
embeds[base_token] = layer
|
|
||||||
|
|
||||||
if token in embeds:
|
|
||||||
embeds[token] += layer
|
|
||||||
else:
|
|
||||||
embeds[token] = layer
|
|
||||||
elif inversion_format == "parameters":
|
|
||||||
emb_params = loaded_embeds["emb_params"]
|
|
||||||
|
|
||||||
num_tokens = emb_params.shape[0]
|
|
||||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
|
||||||
|
|
||||||
sum_layer = np.zeros(emb_params[0, :].shape)
|
|
||||||
|
|
||||||
for i in range(num_tokens):
|
|
||||||
token = f"{base_token}-{i}"
|
|
||||||
layer = emb_params[i, :].numpy().astype(dtype)
|
|
||||||
layer *= weight
|
|
||||||
|
|
||||||
sum_layer += layer
|
|
||||||
if token in embeds:
|
|
||||||
embeds[token] += layer
|
|
||||||
else:
|
|
||||||
embeds[token] = layer
|
|
||||||
|
|
||||||
# add base and sum tokens to embeds
|
|
||||||
if base_token in embeds:
|
|
||||||
embeds[base_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[base_token] = sum_layer
|
|
||||||
|
|
||||||
sum_token = f"{base_token}-all"
|
|
||||||
if sum_token in embeds:
|
|
||||||
embeds[sum_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[sum_token] = sum_layer
|
|
||||||
elif inversion_format == "embeddings":
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
|
||||||
string_to_param = loaded_embeds["string_to_param"]
|
|
||||||
|
|
||||||
# separate token and embeds
|
|
||||||
token = list(string_to_token.keys())[0]
|
|
||||||
trained_embeds = string_to_param[token]
|
|
||||||
|
|
||||||
num_tokens = trained_embeds.shape[0]
|
|
||||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
|
||||||
|
|
||||||
sum_layer = np.zeros(trained_embeds[0, :].shape)
|
|
||||||
|
|
||||||
for i in range(num_tokens):
|
|
||||||
token = f"{base_token}-{i}"
|
|
||||||
layer = trained_embeds[i, :].numpy().astype(dtype)
|
|
||||||
layer *= weight
|
|
||||||
|
|
||||||
sum_layer += layer
|
|
||||||
if token in embeds:
|
|
||||||
embeds[token] += layer
|
|
||||||
else:
|
|
||||||
embeds[token] = layer
|
|
||||||
|
|
||||||
# add base and sum tokens to embeds
|
|
||||||
if base_token in embeds:
|
|
||||||
embeds[base_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[base_token] = sum_layer
|
|
||||||
|
|
||||||
sum_token = f"{base_token}-all"
|
|
||||||
if sum_token in embeds:
|
|
||||||
embeds[sum_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[sum_token] = sum_layer
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
|
raise ValueError(f"unknown Textual Inversion format: {format}")
|
||||||
|
|
||||||
# add the tokens to the tokenizer
|
# add the tokens to the tokenizer
|
||||||
logger.debug(
|
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||||
"found embeddings for %s tokens: %s",
|
if num_added_tokens == 0:
|
||||||
len(embeds.keys()),
|
raise ValueError(
|
||||||
list(embeds.keys()),
|
"The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
|
||||||
)
|
)
|
||||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
|
||||||
if num_added_tokens == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.trace("added %s tokens", num_added_tokens)
|
logger.trace("added %s tokens", num_added_tokens)
|
||||||
|
|
||||||
# resize the token embeddings
|
blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
|
||||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
|
||||||
embedding_node = [
|
|
||||||
n
|
|
||||||
for n in text_encoder.graph.initializer
|
|
||||||
if n.name == "text_model.embeddings.token_embedding.weight"
|
|
||||||
][0]
|
|
||||||
base_weights = numpy_helper.to_array(embedding_node)
|
|
||||||
|
|
||||||
weights_dim = base_weights.shape[1]
|
|
||||||
zero_weights = np.zeros((num_added_tokens, weights_dim))
|
|
||||||
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
|
|
||||||
|
|
||||||
for token, weights in embeds.items():
|
|
||||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
|
||||||
logger.trace("embedding %s weights for token %s", weights.shape, token)
|
|
||||||
embedding_weights[token_id] = weights
|
|
||||||
|
|
||||||
# replace embedding_node
|
|
||||||
for i in range(len(text_encoder.graph.initializer)):
|
|
||||||
if (
|
|
||||||
text_encoder.graph.initializer[i].name
|
|
||||||
== "text_model.embeddings.token_embedding.weight"
|
|
||||||
):
|
|
||||||
new_initializer = numpy_helper.from_array(
|
|
||||||
embedding_weights.astype(base_weights.dtype), embedding_node.name
|
|
||||||
)
|
|
||||||
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
|
||||||
del text_encoder.graph.initializer[i]
|
|
||||||
text_encoder.graph.initializer.insert(i, new_initializer)
|
|
||||||
|
|
||||||
return (text_encoder, tokenizer)
|
return (text_encoder, tokenizer)
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
|
||||||
class ConversionContext(ServerContext):
|
class ConversionContext(ServerContext):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: Optional[str] = None,
|
model_path: str = ".",
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
half: bool = False,
|
half: bool = False,
|
||||||
|
@ -69,7 +69,7 @@ class ConversionContext(ServerContext):
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
context = super().from_environ()
|
context = super().from_environ()
|
||||||
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
|
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
|
||||||
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", True)
|
context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False)
|
||||||
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
|
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
|
||||||
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
|
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
|
||||||
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
|
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
|
||||||
|
@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):
|
||||||
|
|
||||||
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *_rest = model
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
||||||
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
scale = rest[0] if len(rest) > 0 else 1
|
scale = rest.pop(0) if len(rest) > 0 else 1
|
||||||
half = rest[0] if len(rest) > 0 else False
|
half = rest.pop(0) if len(rest) > 0 else False
|
||||||
opset = rest[0] if len(rest) > 0 else None
|
opset = rest.pop(0) if len(rest) > 0 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
||||||
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
single_vae = rest[0] if len(rest) > 0 else False
|
single_vae = rest.pop(0) if len(rest) > 0 else False
|
||||||
half = rest[0] if len(rest) > 0 else False
|
half = rest.pop(0) if len(rest) > 0 else False
|
||||||
opset = rest[0] if len(rest) > 0 else None
|
opset = rest.pop(0) if len(rest) > 0 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
||||||
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
scale = rest[0] if len(rest) > 0 else 1
|
scale = rest.pop(0) if len(rest) > 0 else 1
|
||||||
half = rest[0] if len(rest) > 0 else False
|
half = rest.pop(0) if len(rest) > 0 else False
|
||||||
opset = rest[0] if len(rest) > 0 else None
|
opset = rest.pop(0) if len(rest) > 0 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
|
|
||||||
|
|
||||||
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
|
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
||||||
|
|
||||||
|
|
||||||
|
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
|
||||||
|
_name, ext = path.splitext(name)
|
||||||
|
ext = ext.strip(".")
|
||||||
|
|
||||||
|
return (ext in exts, ext)
|
||||||
|
|
||||||
|
|
||||||
def source_format(model: Dict) -> Optional[str]:
|
def source_format(model: Dict) -> Optional[str]:
|
||||||
|
@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
return model["format"]
|
return model["format"]
|
||||||
|
|
||||||
if "source" in model:
|
if "source" in model:
|
||||||
_name, ext = path.splitext(model["source"])
|
valid, ext = check_ext(model["source"], MODEL_FORMATS)
|
||||||
if ext in MODEL_FORMATS:
|
if valid:
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -298,6 +305,7 @@ def onnx_export(
|
||||||
half=False,
|
half=False,
|
||||||
external_data=False,
|
external_data=False,
|
||||||
v2=False,
|
v2=False,
|
||||||
|
op_block_list=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
|
@ -316,8 +324,7 @@ def onnx_export(
|
||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
op_block_list = None
|
if v2 and op_block_list is None:
|
||||||
if v2:
|
|
||||||
op_block_list = ["Attention", "MultiHeadAttention"]
|
op_block_list = ["Attention", "MultiHeadAttention"]
|
||||||
|
|
||||||
if half:
|
if half:
|
||||||
|
|
|
@ -1,16 +1,15 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
from onnx import load_model
|
from onnx import load_model
|
||||||
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
||||||
ORTStableDiffusionXLImg2ImgPipeline,
|
ORTStableDiffusionXLImg2ImgPipeline,
|
||||||
ORTStableDiffusionXLPipeline,
|
ORTStableDiffusionXLPipeline,
|
||||||
)
|
)
|
||||||
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
|
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from ..constants import ONNX_MODEL
|
from ..constants import LATENT_FACTOR, ONNX_MODEL
|
||||||
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||||
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||||
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
|
@ -24,6 +23,7 @@ from .patches.vae import VAEWrapper
|
||||||
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||||
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
|
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
|
||||||
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
|
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
|
||||||
|
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
|
||||||
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
|
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
|
||||||
from .version_safe_diffusers import (
|
from .version_safe_diffusers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
|
@ -38,6 +38,7 @@ from .version_safe_diffusers import (
|
||||||
KarrasVeScheduler,
|
KarrasVeScheduler,
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
|
LCMScheduler,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
OnnxRuntimeModel,
|
OnnxRuntimeModel,
|
||||||
OnnxStableDiffusionImg2ImgPipeline,
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
|
@ -58,6 +59,7 @@ available_pipelines = {
|
||||||
# "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
|
# "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
|
||||||
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
|
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
|
||||||
"panorama": OnnxStableDiffusionPanoramaPipeline,
|
"panorama": OnnxStableDiffusionPanoramaPipeline,
|
||||||
|
"panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline,
|
||||||
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
|
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
|
||||||
"txt2img-sdxl": ORTStableDiffusionXLPipeline,
|
"txt2img-sdxl": ORTStableDiffusionXLPipeline,
|
||||||
"txt2img": OnnxStableDiffusionPipeline,
|
"txt2img": OnnxStableDiffusionPipeline,
|
||||||
|
@ -77,12 +79,25 @@ pipeline_schedulers = {
|
||||||
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
||||||
"k-dpm-2": KDPM2DiscreteScheduler,
|
"k-dpm-2": KDPM2DiscreteScheduler,
|
||||||
"karras-ve": KarrasVeScheduler,
|
"karras-ve": KarrasVeScheduler,
|
||||||
|
"lcm": LCMScheduler,
|
||||||
"lms-discrete": LMSDiscreteScheduler,
|
"lms-discrete": LMSDiscreteScheduler,
|
||||||
"pndm": PNDMScheduler,
|
"pndm": PNDMScheduler,
|
||||||
"unipc-multi": UniPCMultistepScheduler,
|
"unipc-multi": UniPCMultistepScheduler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_pipeline(name: str, pipeline: Any) -> bool:
|
||||||
|
global available_pipelines
|
||||||
|
|
||||||
|
if name in available_pipelines:
|
||||||
|
# TODO: decide if this should be allowed or not
|
||||||
|
logger.warning("cannot replace existing pipeline: %s", name)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
available_pipelines[name] = pipeline
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_available_pipelines() -> List[str]:
|
def get_available_pipelines() -> List[str]:
|
||||||
return list(available_pipelines.keys())
|
return list(available_pipelines.keys())
|
||||||
|
|
||||||
|
@ -99,16 +114,19 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(
|
def load_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
pipeline: str,
|
pipeline: str,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
embeddings: Optional[List[Tuple[str, float]]] = None,
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
):
|
):
|
||||||
inversions = inversions or []
|
embeddings = embeddings or []
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
model = model or params.model
|
model = model or params.model
|
||||||
|
|
||||||
|
@ -122,7 +140,7 @@ def load_pipeline(
|
||||||
device.device,
|
device.device,
|
||||||
device.provider,
|
device.provider,
|
||||||
control_key,
|
control_key,
|
||||||
inversions,
|
embeddings,
|
||||||
loras,
|
loras,
|
||||||
)
|
)
|
||||||
scheduler_key = (params.scheduler, model)
|
scheduler_key = (params.scheduler, model)
|
||||||
|
@ -159,211 +177,376 @@ def load_pipeline(
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
logger.debug("loading new diffusion pipeline from %s", model)
|
logger.debug("loading new diffusion pipeline from %s", model)
|
||||||
|
scheduler = scheduler_type.from_pretrained(
|
||||||
|
model,
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
subfolder="scheduler",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
components = {
|
components = {
|
||||||
"scheduler": scheduler_type.from_pretrained(
|
"scheduler": scheduler,
|
||||||
model,
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
subfolder="scheduler",
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# shared components
|
# shared components
|
||||||
text_encoder = None
|
|
||||||
unet_type = "unet"
|
unet_type = "unet"
|
||||||
|
|
||||||
# ControlNet component
|
# ControlNet component
|
||||||
if params.is_control() and params.control is not None:
|
if params.is_control() and params.control is not None:
|
||||||
cnet_path = path.join(
|
logger.debug("loading ControlNet components")
|
||||||
server.model_path, "control", f"{params.control.name}.onnx"
|
control_components = load_controlnet(server, device, params)
|
||||||
)
|
components.update(control_components)
|
||||||
logger.debug("loading ControlNet weights from %s", cnet_path)
|
|
||||||
components["controlnet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
cnet_path,
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
unet_type = "cnet"
|
unet_type = "cnet"
|
||||||
|
|
||||||
# Textual Inversion blending
|
# load various pipeline components
|
||||||
if inversions is not None and len(inversions) > 0:
|
encoder_components = load_text_encoders(
|
||||||
logger.debug("blending Textual Inversions from %s", inversions)
|
server, device, model, embeddings, loras, torch_dtype, params
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
)
|
||||||
|
components.update(encoder_components)
|
||||||
|
|
||||||
inversion_models = [
|
unet_components = load_unet(server, device, model, loras, unet_type, params)
|
||||||
path.join(server.model_path, "inversion", name)
|
components.update(unet_components)
|
||||||
for name in inversion_names
|
|
||||||
]
|
vae_components = load_vae(server, device, model, params)
|
||||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
components.update(vae_components)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
|
||||||
model,
|
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||||
subfolder="tokenizer",
|
|
||||||
torch_dtype=torch_dtype,
|
if params.is_xl():
|
||||||
|
logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
|
||||||
|
pipe = pipeline_class(
|
||||||
|
components["vae_decoder_session"],
|
||||||
|
components["text_encoder_session"],
|
||||||
|
components["unet_session"],
|
||||||
|
{
|
||||||
|
"force_zeros_for_empty_prompt": True,
|
||||||
|
"requires_aesthetics_score": False,
|
||||||
|
},
|
||||||
|
components["tokenizer"],
|
||||||
|
scheduler,
|
||||||
|
vae_encoder_session=components.get("vae_encoder_session", None),
|
||||||
|
text_encoder_2_session=components.get("text_encoder_2_session", None),
|
||||||
|
tokenizer_2=components.get("tokenizer_2", None),
|
||||||
)
|
)
|
||||||
text_encoder, tokenizer = blend_textual_inversions(
|
else:
|
||||||
|
if "vae" in components:
|
||||||
|
# upscale uses a single VAE
|
||||||
|
logger.debug(
|
||||||
|
"assembling SD pipeline for %s with single VAE",
|
||||||
|
pipeline_class.__name__,
|
||||||
|
)
|
||||||
|
pipe = pipeline_class(
|
||||||
|
components["vae"],
|
||||||
|
components["text_encoder"],
|
||||||
|
components["tokenizer"],
|
||||||
|
components["unet"],
|
||||||
|
scheduler,
|
||||||
|
scheduler,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"assembling SD pipeline for %s with VAE codec",
|
||||||
|
pipeline_class.__name__,
|
||||||
|
)
|
||||||
|
pipe = pipeline_class(
|
||||||
|
components["vae_encoder"],
|
||||||
|
components["vae_decoder"],
|
||||||
|
components["text_encoder"],
|
||||||
|
components["tokenizer"],
|
||||||
|
components["unet"],
|
||||||
|
scheduler,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not server.show_progress:
|
||||||
|
pipe.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
|
optimize_pipeline(server, pipe)
|
||||||
|
patch_pipeline(server, pipe, pipeline_class, params)
|
||||||
|
|
||||||
|
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
||||||
|
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
|
||||||
|
|
||||||
|
for vae in VAE_COMPONENTS:
|
||||||
|
if hasattr(pipe, vae):
|
||||||
|
vae_model = getattr(pipe, vae)
|
||||||
|
if isinstance(vae_model, VAEWrapper):
|
||||||
|
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||||
|
vae_model.set_window_size(
|
||||||
|
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
# update panorama params
|
||||||
|
if params.is_panorama():
|
||||||
|
unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
|
||||||
|
logger.debug(
|
||||||
|
"setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
|
||||||
|
params.unet_tile,
|
||||||
|
unet_stride,
|
||||||
|
params.vae_tile,
|
||||||
|
params.vae_overlap,
|
||||||
|
)
|
||||||
|
pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)
|
||||||
|
|
||||||
|
run_gc([device])
|
||||||
|
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
|
||||||
|
cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
|
||||||
|
logger.debug("loading ControlNet weights from %s", cnet_path)
|
||||||
|
components = {}
|
||||||
|
components["controlnet"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
cnet_path,
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
def load_text_encoders(
|
||||||
|
server: ServerContext,
|
||||||
|
device: DeviceParams,
|
||||||
|
model: str,
|
||||||
|
embeddings: Optional[List[Tuple[str, float]]],
|
||||||
|
loras: Optional[List[Tuple[str, float]]],
|
||||||
|
torch_dtype,
|
||||||
|
params: ImageParams,
|
||||||
|
):
|
||||||
|
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
|
model,
|
||||||
|
subfolder="tokenizer",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
components = {
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
||||||
|
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||||
|
model,
|
||||||
|
subfolder="tokenizer_2",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
components["tokenizer_2"] = tokenizer_2
|
||||||
|
|
||||||
|
# blend embeddings, if any
|
||||||
|
if embeddings is not None and len(embeddings) > 0:
|
||||||
|
embedding_names, embedding_weights = zip(*embeddings)
|
||||||
|
embedding_models = [
|
||||||
|
path.join(server.model_path, "inversion", name) for name in embedding_names
|
||||||
|
]
|
||||||
|
logger.debug(
|
||||||
|
"blending base model %s with embeddings from %s", model, embedding_models
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: blend text_encoder_2 as well
|
||||||
|
text_encoder, tokenizer = blend_textual_inversions(
|
||||||
|
server,
|
||||||
|
text_encoder,
|
||||||
|
tokenizer,
|
||||||
|
list(
|
||||||
|
zip(
|
||||||
|
embedding_models,
|
||||||
|
embedding_weights,
|
||||||
|
embedding_names,
|
||||||
|
[None] * len(embedding_models),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
components["tokenizer"] = tokenizer
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
text_encoder_2, tokenizer_2 = blend_textual_inversions(
|
||||||
server,
|
server,
|
||||||
text_encoder,
|
text_encoder_2,
|
||||||
tokenizer,
|
tokenizer_2,
|
||||||
list(
|
list(
|
||||||
zip(
|
zip(
|
||||||
inversion_models,
|
embedding_models,
|
||||||
inversion_weights,
|
embedding_weights,
|
||||||
inversion_names,
|
embedding_names,
|
||||||
[None] * len(inversion_models),
|
[None] * len(embedding_models),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
components["tokenizer_2"] = tokenizer_2
|
||||||
|
|
||||||
components["tokenizer"] = tokenizer
|
# blend LoRAs, if any
|
||||||
|
if loras is not None and len(loras) > 0:
|
||||||
|
lora_names, lora_weights = zip(*loras)
|
||||||
|
lora_models = [
|
||||||
|
path.join(server.model_path, "lora", name) for name in lora_names
|
||||||
|
]
|
||||||
|
logger.info("blending base model %s with LoRAs from %s", model, lora_models)
|
||||||
|
|
||||||
# should be pretty small and should not need external data
|
# blend and load text encoder
|
||||||
if loras is None or len(loras) == 0:
|
text_encoder = blend_loras(
|
||||||
# TODO: handle XL encoders
|
server,
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
text_encoder,
|
||||||
OnnxRuntimeModel.load_model(
|
list(zip(lora_models, lora_weights)),
|
||||||
text_encoder.SerializeToString(),
|
"text_encoder",
|
||||||
provider=device.ort_provider("text-encoder"),
|
1 if params.is_xl() else None,
|
||||||
sess_options=device.sess_options(),
|
params.is_xl(),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# LoRA blending
|
if params.is_xl():
|
||||||
if loras is not None and len(loras) > 0:
|
text_encoder_2 = blend_loras(
|
||||||
lora_names, lora_weights = zip(*loras)
|
|
||||||
lora_models = [
|
|
||||||
path.join(server.model_path, "lora", name) for name in lora_names
|
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"blending base model %s with LoRA models: %s", model, lora_models
|
|
||||||
)
|
|
||||||
|
|
||||||
# blend and load text encoder
|
|
||||||
text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
|
|
||||||
text_encoder = blend_loras(
|
|
||||||
server,
|
server,
|
||||||
text_encoder,
|
text_encoder_2,
|
||||||
list(zip(lora_models, lora_weights)),
|
list(zip(lora_models, lora_weights)),
|
||||||
"text_encoder",
|
"text_encoder",
|
||||||
1 if params.is_xl() else None,
|
2,
|
||||||
params.is_xl(),
|
params.is_xl(),
|
||||||
)
|
)
|
||||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
|
||||||
text_encoder
|
# prepare external data for sessions
|
||||||
|
(text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
|
||||||
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
|
text_encoder_opts = device.sess_options(cache=False)
|
||||||
|
text_encoder_opts.add_external_initializers(
|
||||||
|
list(text_encoder_names), list(text_encoder_values)
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
# encoder 2 only exists in XL
|
||||||
|
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
||||||
|
text_encoder_2
|
||||||
|
)
|
||||||
|
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
|
||||||
|
text_encoder_2_opts = device.sess_options(cache=False)
|
||||||
|
text_encoder_2_opts.add_external_initializers(
|
||||||
|
list(text_encoder_2_names), list(text_encoder_2_values)
|
||||||
|
)
|
||||||
|
|
||||||
|
# session for te1
|
||||||
|
text_encoder_session = InferenceSession(
|
||||||
|
text_encoder.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder_opts,
|
||||||
|
)
|
||||||
|
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||||
|
components["text_encoder_session"] = text_encoder_session
|
||||||
|
|
||||||
|
# session for te2
|
||||||
|
text_encoder_2_session = InferenceSession(
|
||||||
|
text_encoder_2.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder_2_opts,
|
||||||
|
)
|
||||||
|
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
|
components["text_encoder_2_session"] = text_encoder_2_session
|
||||||
|
else:
|
||||||
|
# session for te
|
||||||
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
text_encoder.SerializeToString(),
|
||||||
|
provider=device.ort_provider("text-encoder"),
|
||||||
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
)
|
||||||
text_encoder_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_opts.add_external_initializers(
|
return components
|
||||||
list(text_encoder_names), list(text_encoder_values)
|
|
||||||
|
|
||||||
|
def load_unet(
|
||||||
|
server: ServerContext,
|
||||||
|
device: DeviceParams,
|
||||||
|
model: str,
|
||||||
|
loras: List[Tuple[str, float]],
|
||||||
|
unet_type: Literal["cnet", "unet"],
|
||||||
|
params: ImageParams,
|
||||||
|
):
|
||||||
|
components = {}
|
||||||
|
unet = load_model(path.join(model, unet_type, ONNX_MODEL))
|
||||||
|
|
||||||
|
# LoRA blending
|
||||||
|
if loras is not None and len(loras) > 0:
|
||||||
|
lora_names, lora_weights = zip(*loras)
|
||||||
|
lora_models = [
|
||||||
|
path.join(server.model_path, "lora", name) for name in lora_names
|
||||||
|
]
|
||||||
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
||||||
|
# blend and load unet
|
||||||
|
unet = blend_loras(
|
||||||
|
server,
|
||||||
|
unet,
|
||||||
|
list(zip(lora_models, lora_weights)),
|
||||||
|
"unet",
|
||||||
|
xl=params.is_xl(),
|
||||||
|
)
|
||||||
|
|
||||||
|
(unet_model, unet_data) = buffer_external_data_tensors(unet)
|
||||||
|
unet_names, unet_values = zip(*unet_data)
|
||||||
|
unet_opts = device.sess_options(cache=False)
|
||||||
|
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
unet_session = InferenceSession(
|
||||||
|
unet_model.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("unet")],
|
||||||
|
sess_options=unet_opts,
|
||||||
|
)
|
||||||
|
unet_session._model_path = path.join(model, "unet")
|
||||||
|
components["unet_session"] = unet_session
|
||||||
|
else:
|
||||||
|
components["unet"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
unet_model.SerializeToString(),
|
||||||
|
provider=device.ort_provider("unet"),
|
||||||
|
sess_options=unet_opts,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if params.is_xl():
|
return components
|
||||||
text_encoder_session = InferenceSession(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
|
||||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
|
||||||
components["text_encoder_session"] = text_encoder_session
|
|
||||||
else:
|
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
provider=device.ort_provider("text-encoder"),
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.is_xl():
|
|
||||||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
|
||||||
text_encoder_2 = blend_loras(
|
|
||||||
server,
|
|
||||||
text_encoder_2,
|
|
||||||
list(zip(lora_models, lora_weights)),
|
|
||||||
"text_encoder",
|
|
||||||
2,
|
|
||||||
params.is_xl(),
|
|
||||||
)
|
|
||||||
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
|
||||||
text_encoder_2
|
|
||||||
)
|
|
||||||
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
|
|
||||||
text_encoder_2_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_2_opts.add_external_initializers(
|
|
||||||
list(text_encoder_2_names), list(text_encoder_2_values)
|
|
||||||
)
|
|
||||||
|
|
||||||
text_encoder_2_session = InferenceSession(
|
def load_vae(
|
||||||
text_encoder_2.SerializeToString(),
|
_server: ServerContext, device: DeviceParams, model: str, params: ImageParams
|
||||||
providers=[device.ort_provider("text-encoder")],
|
):
|
||||||
sess_options=text_encoder_2_opts,
|
# one or more VAE models need to be loaded
|
||||||
)
|
vae = path.join(model, "vae", ONNX_MODEL)
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
|
||||||
components["text_encoder_2_session"] = text_encoder_2_session
|
vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
|
||||||
|
|
||||||
# blend and load unet
|
components = {}
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
if not params.is_xl() and path.exists(vae):
|
||||||
blended_unet = blend_loras(
|
logger.debug("loading VAE from %s", vae)
|
||||||
server,
|
components["vae"] = OnnxRuntimeModel(
|
||||||
unet,
|
OnnxRuntimeModel.load_model(
|
||||||
list(zip(lora_models, lora_weights)),
|
vae,
|
||||||
"unet",
|
provider=device.ort_provider("vae"),
|
||||||
xl=params.is_xl(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
elif path.exists(vae_decoder) and path.exists(vae_encoder):
|
||||||
unet_opts = device.sess_options(cache=False)
|
if params.is_xl():
|
||||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
|
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
if params.is_xl():
|
vae_decoder,
|
||||||
unet_session = InferenceSession(
|
provider=device.ort_provider("vae"),
|
||||||
unet_model.SerializeToString(),
|
sess_options=device.sess_options(),
|
||||||
providers=[device.ort_provider("unet")],
|
|
||||||
sess_options=unet_opts,
|
|
||||||
)
|
|
||||||
unet_session._model_path = path.join(model, "unet")
|
|
||||||
components["unet_session"] = unet_session
|
|
||||||
else:
|
|
||||||
components["unet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
unet_model.SerializeToString(),
|
|
||||||
provider=device.ort_provider("unet"),
|
|
||||||
sess_options=unet_opts,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# make sure a UNet has been loaded
|
|
||||||
if not params.is_xl() and "unet" not in components:
|
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
|
||||||
logger.debug("loading UNet (%s) from %s", unet_type, unet)
|
|
||||||
components["unet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
unet,
|
|
||||||
provider=device.ort_provider("unet"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
components["vae_decoder_session"]._model_path = vae_decoder
|
||||||
|
|
||||||
# one or more VAE models need to be loaded
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
vae = path.join(model, "vae", ONNX_MODEL)
|
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
|
vae_encoder,
|
||||||
vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
|
provider=device.ort_provider("vae"),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
if not params.is_xl() and path.exists(vae):
|
|
||||||
logger.debug("loading VAE from %s", vae)
|
|
||||||
components["vae"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
vae,
|
|
||||||
provider=device.ort_provider("vae"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
elif (
|
components["vae_encoder_session"]._model_path = vae_encoder
|
||||||
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
|
|
||||||
):
|
else:
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
components["vae_decoder"] = OnnxRuntimeModel(
|
components["vae_decoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
|
@ -382,119 +565,44 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# additional options for panorama pipeline
|
return components
|
||||||
if params.is_panorama():
|
|
||||||
components["window"] = params.tiles // 8
|
|
||||||
components["stride"] = params.stride // 8
|
|
||||||
|
|
||||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
|
||||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
|
||||||
pipe = pipeline_class.from_pretrained(
|
|
||||||
model,
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
safety_checker=None,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
**components,
|
|
||||||
)
|
|
||||||
|
|
||||||
# make sure XL models are actually being used
|
|
||||||
# TODO: why is this needed?
|
|
||||||
if "text_encoder_session" in components:
|
|
||||||
logger.info(
|
|
||||||
"text encoder matches: %s, %s",
|
|
||||||
pipe.text_encoder.session == components["text_encoder_session"],
|
|
||||||
type(pipe.text_encoder),
|
|
||||||
)
|
|
||||||
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
|
||||||
|
|
||||||
if "text_encoder_2_session" in components:
|
|
||||||
logger.info(
|
|
||||||
"text encoder 2 matches: %s, %s",
|
|
||||||
pipe.text_encoder_2.session == components["text_encoder_2_session"],
|
|
||||||
type(pipe.text_encoder_2),
|
|
||||||
)
|
|
||||||
pipe.text_encoder_2 = ORTModelTextEncoder(
|
|
||||||
text_encoder_2_session, text_encoder_2
|
|
||||||
)
|
|
||||||
|
|
||||||
if "unet_session" in components:
|
|
||||||
logger.info(
|
|
||||||
"unet matches: %s, %s",
|
|
||||||
pipe.unet.session == components["unet_session"],
|
|
||||||
type(pipe.unet),
|
|
||||||
)
|
|
||||||
pipe.unet = ORTModelUnet(unet_session, unet_model)
|
|
||||||
|
|
||||||
if not server.show_progress:
|
|
||||||
pipe.set_progress_bar_config(disable=True)
|
|
||||||
|
|
||||||
optimize_pipeline(server, pipe)
|
|
||||||
|
|
||||||
if not params.is_xl():
|
|
||||||
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
|
|
||||||
|
|
||||||
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
|
||||||
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
|
||||||
|
|
||||||
if not params.is_xl() and hasattr(pipe, "vae_decoder"):
|
|
||||||
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)
|
|
||||||
|
|
||||||
if not params.is_xl() and hasattr(pipe, "vae_encoder"):
|
|
||||||
pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)
|
|
||||||
|
|
||||||
# update panorama params
|
|
||||||
if params.is_panorama():
|
|
||||||
latent_window = params.tiles // 8
|
|
||||||
latent_stride = params.stride // 8
|
|
||||||
|
|
||||||
pipe.set_window_size(latent_window, latent_stride)
|
|
||||||
if hasattr(pipe, "vae_decoder"):
|
|
||||||
pipe.vae_decoder.set_window_size(latent_window, params.overlap)
|
|
||||||
if hasattr(pipe, "vae_encoder"):
|
|
||||||
pipe.vae_encoder.set_window_size(latent_window, params.overlap)
|
|
||||||
|
|
||||||
run_gc([device])
|
|
||||||
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
|
|
||||||
def optimize_pipeline(
|
def optimize_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
if (
|
if server.has_optimization(
|
||||||
"diffusers-attention-slicing" in server.optimizations
|
"diffusers-attention-slicing"
|
||||||
or "diffusers-attention-slicing-auto" in server.optimizations
|
) or server.has_optimization("diffusers-attention-slicing-auto"):
|
||||||
):
|
|
||||||
logger.debug("enabling auto attention slicing on SD pipeline")
|
logger.debug("enabling auto attention slicing on SD pipeline")
|
||||||
try:
|
try:
|
||||||
pipe.enable_attention_slicing(slice_size="auto")
|
pipe.enable_attention_slicing(slice_size="auto")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("error while enabling auto attention slicing: %s", e)
|
logger.warning("error while enabling auto attention slicing: %s", e)
|
||||||
|
|
||||||
if "diffusers-attention-slicing-max" in server.optimizations:
|
if server.has_optimization("diffusers-attention-slicing-max"):
|
||||||
logger.debug("enabling max attention slicing on SD pipeline")
|
logger.debug("enabling max attention slicing on SD pipeline")
|
||||||
try:
|
try:
|
||||||
pipe.enable_attention_slicing(slice_size="max")
|
pipe.enable_attention_slicing(slice_size="max")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("error while enabling max attention slicing: %s", e)
|
logger.warning("error while enabling max attention slicing: %s", e)
|
||||||
|
|
||||||
if "diffusers-vae-slicing" in server.optimizations:
|
if server.has_optimization("diffusers-vae-slicing"):
|
||||||
logger.debug("enabling VAE slicing on SD pipeline")
|
logger.debug("enabling VAE slicing on SD pipeline")
|
||||||
try:
|
try:
|
||||||
pipe.enable_vae_slicing()
|
pipe.enable_vae_slicing()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("error while enabling VAE slicing: %s", e)
|
logger.warning("error while enabling VAE slicing: %s", e)
|
||||||
|
|
||||||
if "diffusers-cpu-offload-sequential" in server.optimizations:
|
if server.has_optimization("diffusers-cpu-offload-sequential"):
|
||||||
logger.debug("enabling sequential CPU offload on SD pipeline")
|
logger.debug("enabling sequential CPU offload on SD pipeline")
|
||||||
try:
|
try:
|
||||||
pipe.enable_sequential_cpu_offload()
|
pipe.enable_sequential_cpu_offload()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("error while enabling sequential CPU offload: %s", e)
|
logger.warning("error while enabling sequential CPU offload: %s", e)
|
||||||
|
|
||||||
elif "diffusers-cpu-offload-model" in server.optimizations:
|
elif server.has_optimization("diffusers-cpu-offload-model"):
|
||||||
# TODO: check for accelerate
|
# TODO: check for accelerate
|
||||||
logger.debug("enabling model CPU offload on SD pipeline")
|
logger.debug("enabling model CPU offload on SD pipeline")
|
||||||
try:
|
try:
|
||||||
|
@ -502,7 +610,7 @@ def optimize_pipeline(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("error while enabling model CPU offload: %s", e)
|
logger.warning("error while enabling model CPU offload: %s", e)
|
||||||
|
|
||||||
if "diffusers-memory-efficient-attention" in server.optimizations:
|
if server.has_optimization("diffusers-memory-efficient-attention"):
|
||||||
# TODO: check for xformers
|
# TODO: check for xformers
|
||||||
logger.debug("enabling memory efficient attention for SD pipeline")
|
logger.debug("enabling memory efficient attention for SD pipeline")
|
||||||
try:
|
try:
|
||||||
|
@ -514,17 +622,17 @@ def optimize_pipeline(
|
||||||
def patch_pipeline(
|
def patch_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
pipe_type: str,
|
|
||||||
pipeline: Any,
|
pipeline: Any,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("patching SD pipeline")
|
logger.debug("patching SD pipeline")
|
||||||
|
|
||||||
if pipe_type != "lpw":
|
if not params.is_lpw() and not params.is_xl():
|
||||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||||
|
|
||||||
original_unet = pipe.unet
|
original_unet = pipe.unet
|
||||||
pipe.unet = UNetWrapper(server, original_unet)
|
pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
|
||||||
|
logger.debug("patched UNet with wrapper")
|
||||||
|
|
||||||
if hasattr(pipe, "vae_decoder"):
|
if hasattr(pipe, "vae_decoder"):
|
||||||
original_decoder = pipe.vae_decoder
|
original_decoder = pipe.vae_decoder
|
||||||
|
@ -532,18 +640,21 @@ def patch_pipeline(
|
||||||
server,
|
server,
|
||||||
original_decoder,
|
original_decoder,
|
||||||
decoder=True,
|
decoder=True,
|
||||||
window=params.tiles,
|
window=params.unet_tile,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
)
|
)
|
||||||
|
logger.debug("patched VAE decoder with wrapper")
|
||||||
|
|
||||||
|
if hasattr(pipe, "vae_encoder"):
|
||||||
original_encoder = pipe.vae_encoder
|
original_encoder = pipe.vae_encoder
|
||||||
pipe.vae_encoder = VAEWrapper(
|
pipe.vae_encoder = VAEWrapper(
|
||||||
server,
|
server,
|
||||||
original_encoder,
|
original_encoder,
|
||||||
decoder=False,
|
decoder=False,
|
||||||
window=params.tiles,
|
window=params.unet_tile,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
)
|
)
|
||||||
elif hasattr(pipe, "vae"):
|
logger.debug("patched VAE encoder with wrapper")
|
||||||
pass # TODO: current wrapper does not work with upscaling VAE
|
|
||||||
else:
|
if hasattr(pipe, "vae"):
|
||||||
logger.debug("no VAE found to patch")
|
logger.warning("not patching single VAE, tiled VAE may not work")
|
||||||
|
|
|
@ -14,20 +14,23 @@ class UNetWrapper(object):
|
||||||
prompt_index: int = 0
|
prompt_index: int = 0
|
||||||
server: ServerContext
|
server: ServerContext
|
||||||
wrapped: OnnxRuntimeModel
|
wrapped: OnnxRuntimeModel
|
||||||
|
xl: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
wrapped: OnnxRuntimeModel,
|
wrapped: OnnxRuntimeModel,
|
||||||
|
xl: bool,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
self.xl = xl
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
sample: np.ndarray = None,
|
sample: Optional[np.ndarray] = None,
|
||||||
timestep: np.ndarray = None,
|
timestep: Optional[np.ndarray] = None,
|
||||||
encoder_hidden_states: np.ndarray = None,
|
encoder_hidden_states: Optional[np.ndarray] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
logger.trace(
|
logger.trace(
|
||||||
|
@ -43,13 +46,21 @@ class UNetWrapper(object):
|
||||||
encoder_hidden_states = self.prompt_embeds[step_index]
|
encoder_hidden_states = self.prompt_embeds[step_index]
|
||||||
self.prompt_index += 1
|
self.prompt_index += 1
|
||||||
|
|
||||||
if sample.dtype != timestep.dtype:
|
if self.xl:
|
||||||
logger.trace("converting UNet sample to timestep dtype")
|
if sample.dtype != encoder_hidden_states.dtype:
|
||||||
sample = sample.astype(timestep.dtype)
|
logger.trace(
|
||||||
|
"converting UNet sample to hidden state dtype for XL: %s",
|
||||||
|
encoder_hidden_states.dtype,
|
||||||
|
)
|
||||||
|
sample = sample.astype(encoder_hidden_states.dtype)
|
||||||
|
else:
|
||||||
|
if sample.dtype != timestep.dtype:
|
||||||
|
logger.trace("converting UNet sample to timestep dtype")
|
||||||
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
||||||
if encoder_hidden_states.dtype != timestep.dtype:
|
if encoder_hidden_states.dtype != timestep.dtype:
|
||||||
logger.trace("converting UNet hidden states to timestep dtype")
|
logger.trace("converting UNet hidden states to timestep dtype")
|
||||||
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
||||||
|
|
||||||
return self.wrapped(
|
return self.wrapped(
|
||||||
sample=sample,
|
sample=sample,
|
||||||
|
|
|
@ -12,8 +12,6 @@ from ...server import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
LATENT_CHANNELS = 4
|
|
||||||
|
|
||||||
|
|
||||||
class VAEWrapper(object):
|
class VAEWrapper(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -39,11 +37,17 @@ class VAEWrapper(object):
|
||||||
self.tile_overlap_factor = overlap
|
self.tile_overlap_factor = overlap
|
||||||
|
|
||||||
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
def __call__(self, latent_sample=None, sample=None, **kwargs):
|
||||||
|
model = (
|
||||||
|
self.wrapped.model
|
||||||
|
if hasattr(self.wrapped, "model")
|
||||||
|
else self.wrapped.session
|
||||||
|
)
|
||||||
|
|
||||||
# set timestep dtype to input type
|
# set timestep dtype to input type
|
||||||
sample_dtype = next(
|
sample_dtype = next(
|
||||||
(
|
(
|
||||||
input.type
|
input.type
|
||||||
for input in self.wrapped.model.get_inputs()
|
for input in model.get_inputs()
|
||||||
if input.name == "sample" or input.name == "latent_sample"
|
if input.name == "sample" or input.name == "latent_sample"
|
||||||
),
|
),
|
||||||
"tensor(float)",
|
"tensor(float)",
|
||||||
|
|
|
@ -13,8 +13,8 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||||
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||||
|
|
|
@ -13,25 +13,36 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, List, Optional, Union
|
from math import ceil
|
||||||
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||||
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||||
|
|
||||||
|
from ...chain.tile import make_tile_mask
|
||||||
|
from ...constants import LATENT_CHANNELS, LATENT_FACTOR
|
||||||
|
from ...params import Size
|
||||||
|
from ..utils import (
|
||||||
|
expand_latents,
|
||||||
|
parse_regions,
|
||||||
|
random_seed,
|
||||||
|
repair_nan,
|
||||||
|
resize_latent_shape,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# inpaint constants
|
# inpaint constants
|
||||||
NUM_UNET_INPUT_CHANNELS = 9
|
NUM_UNET_INPUT_CHANNELS = 9
|
||||||
NUM_LATENT_CHANNELS = 4
|
|
||||||
|
|
||||||
DEFAULT_WINDOW = 32
|
DEFAULT_WINDOW = 32
|
||||||
DEFAULT_STRIDE = 8
|
DEFAULT_STRIDE = 8
|
||||||
|
@ -346,13 +357,16 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
f" {negative_prompt_embeds.shape}."
|
f" {negative_prompt_embeds.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_views(self, panorama_height, panorama_width, window_size, stride):
|
def get_views(
|
||||||
|
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
|
||||||
|
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
|
||||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||||
panorama_height /= 8
|
panorama_height /= 8
|
||||||
panorama_width /= 8
|
panorama_width /= 8
|
||||||
|
|
||||||
num_blocks_height = abs((panorama_height - window_size) // stride) + 1
|
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
|
||||||
num_blocks_width = abs((panorama_width - window_size) // stride) + 1
|
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
|
||||||
|
|
||||||
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"panorama generated %s views, %s by %s blocks",
|
"panorama generated %s views, %s by %s blocks",
|
||||||
|
@ -369,7 +383,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
w_end = w_start + window_size
|
w_end = w_start + window_size
|
||||||
views.append((h_start, h_end, w_start, w_end))
|
views.append((h_start, h_end, w_start, w_end))
|
||||||
|
|
||||||
return views
|
return (views, (h_end * 8, w_end * 8))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def text2img(
|
def text2img(
|
||||||
|
@ -479,6 +493,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
# corresponds to doing no classifier free guidance.
|
# corresponds to doing no classifier free guidance.
|
||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
prompt, regions = parse_regions(prompt)
|
||||||
|
|
||||||
prompt_embeds = self._encode_prompt(
|
prompt_embeds = self._encode_prompt(
|
||||||
prompt,
|
prompt,
|
||||||
num_images_per_prompt,
|
num_images_per_prompt,
|
||||||
|
@ -488,9 +504,30 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
negative_prompt_embeds=negative_prompt_embeds,
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 3.b. Encode region prompts
|
||||||
|
region_embeds: List[np.ndarray] = []
|
||||||
|
|
||||||
|
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
|
||||||
|
if region_prompt.endswith("+"):
|
||||||
|
region_prompt = region_prompt[:-1] + " " + prompt
|
||||||
|
|
||||||
|
region_prompt_embeds = self._encode_prompt(
|
||||||
|
region_prompt,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
region_embeds.append(region_prompt_embeds)
|
||||||
|
|
||||||
# get the initial random noise unless the user supplied it
|
# get the initial random noise unless the user supplied it
|
||||||
latents_dtype = prompt_embeds.dtype
|
latents_dtype = prompt_embeds.dtype
|
||||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
latents_shape = (
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
LATENT_CHANNELS,
|
||||||
|
height // LATENT_FACTOR,
|
||||||
|
width // LATENT_FACTOR,
|
||||||
|
)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||||
elif latents.shape != latents_shape:
|
elif latents.shape != latents_shape:
|
||||||
|
@ -525,11 +562,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||||
|
|
||||||
# panorama additions
|
# panorama additions
|
||||||
views = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
count = np.zeros_like(latents)
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
value = np.zeros_like(latents)
|
|
||||||
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
# adjust latents
|
||||||
|
latents = expand_latents(
|
||||||
|
latents,
|
||||||
|
random_seed(generator),
|
||||||
|
Size(resize[1], resize[0]),
|
||||||
|
sigma=self.scheduler.init_noise_sigma,
|
||||||
|
)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||||
|
last = i == (len(self.scheduler.timesteps) - 1)
|
||||||
count.fill(0)
|
count.fill(0)
|
||||||
value.fill(0)
|
value.fill(0)
|
||||||
|
|
||||||
|
@ -576,13 +624,115 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
|
if not last:
|
||||||
|
for r, region in enumerate(regions):
|
||||||
|
top, left, bottom, right, weight, feather, prompt = region
|
||||||
|
logger.debug(
|
||||||
|
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||||
|
top,
|
||||||
|
left,
|
||||||
|
bottom,
|
||||||
|
right,
|
||||||
|
weight,
|
||||||
|
feather,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert coordinates to latent space
|
||||||
|
h_start = top // LATENT_FACTOR
|
||||||
|
h_end = bottom // LATENT_FACTOR
|
||||||
|
w_start = left // LATENT_FACTOR
|
||||||
|
w_end = right // LATENT_FACTOR
|
||||||
|
|
||||||
|
# get the latents corresponding to the current view coordinates
|
||||||
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
logger.trace(
|
||||||
|
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
||||||
|
h_start,
|
||||||
|
h_end,
|
||||||
|
w_start,
|
||||||
|
w_end,
|
||||||
|
latents_for_region.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_region_input = (
|
||||||
|
np.concatenate([latents_for_region] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_region
|
||||||
|
)
|
||||||
|
latent_region_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_region_input), t
|
||||||
|
)
|
||||||
|
latent_region_input = latent_region_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
region_noise_pred = self.unet(
|
||||||
|
sample=latent_region_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=region_embeds[r],
|
||||||
|
)
|
||||||
|
region_noise_pred = region_noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||||
|
region_noise_pred, 2
|
||||||
|
)
|
||||||
|
region_noise_pred = (
|
||||||
|
region_noise_pred_uncond
|
||||||
|
+ guidance_scale
|
||||||
|
* (region_noise_pred_text - region_noise_pred_uncond)
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
scheduler_output = self.scheduler.step(
|
||||||
|
torch.from_numpy(region_noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_region),
|
||||||
|
**extra_step_kwargs,
|
||||||
|
)
|
||||||
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
if feather[0] > 0.0:
|
||||||
|
mask = make_tile_mask(
|
||||||
|
(h_end - h_start, w_end - w_start),
|
||||||
|
(h_end - h_start, w_end - w_start),
|
||||||
|
feather[0],
|
||||||
|
feather[1],
|
||||||
|
)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
mask = np.repeat(mask, 4, axis=0)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
else:
|
||||||
|
mask = 1
|
||||||
|
|
||||||
|
if weight >= 100.0:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] = (
|
||||||
|
latents_region_denoised * mask
|
||||||
|
)
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||||
|
else:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||||
|
latents_region_denoised * weight * mask
|
||||||
|
)
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
latents = repair_nan(latents)
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
# remove extra margins
|
||||||
|
latents = latents[
|
||||||
|
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
|
||||||
|
]
|
||||||
|
|
||||||
|
latents = np.clip(latents, -4, +4)
|
||||||
latents = 1 / 0.18215 * latents
|
latents = 1 / 0.18215 * latents
|
||||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
@ -828,9 +978,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||||
|
|
||||||
# panorama additions
|
# panorama additions
|
||||||
views = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
count = np.zeros_like(latents)
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
value = np.zeros_like(latents)
|
|
||||||
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
# adjust latents
|
||||||
|
latents = expand_latents(
|
||||||
|
latents,
|
||||||
|
random_seed(generator),
|
||||||
|
Size(resize[1], resize[0]),
|
||||||
|
sigma=self.scheduler.init_noise_sigma,
|
||||||
|
)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
count.fill(0)
|
count.fill(0)
|
||||||
|
@ -886,6 +1046,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
# remove extra margins
|
||||||
|
latents = latents[
|
||||||
|
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
|
||||||
|
]
|
||||||
|
|
||||||
latents = 1 / 0.18215 * latents
|
latents = 1 / 0.18215 * latents
|
||||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
@ -1053,12 +1218,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
negative_prompt_embeds=negative_prompt_embeds,
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_channels_latents = NUM_LATENT_CHANNELS
|
num_channels_latents = LATENT_CHANNELS
|
||||||
latents_shape = (
|
latents_shape = (
|
||||||
batch_size * num_images_per_prompt,
|
batch_size * num_images_per_prompt,
|
||||||
num_channels_latents,
|
num_channels_latents,
|
||||||
height // 8,
|
height // LATENT_FACTOR,
|
||||||
width // 8,
|
width // LATENT_FACTOR,
|
||||||
)
|
)
|
||||||
latents_dtype = prompt_embeds.dtype
|
latents_dtype = prompt_embeds.dtype
|
||||||
if latents is None:
|
if latents is None:
|
||||||
|
@ -1136,9 +1301,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||||
|
|
||||||
# panorama additions
|
# panorama additions
|
||||||
views = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
count = np.zeros_like(latents)
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
value = np.zeros_like(latents)
|
|
||||||
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
# adjust latents
|
||||||
|
latents = expand_latents(
|
||||||
|
latents,
|
||||||
|
random_seed(generator),
|
||||||
|
Size(resize[1], resize[0]),
|
||||||
|
sigma=self.scheduler.init_noise_sigma,
|
||||||
|
)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||||
count.fill(0)
|
count.fill(0)
|
||||||
|
@ -1201,6 +1376,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
if callback is not None and i % callback_steps == 0:
|
if callback is not None and i % callback_steps == 0:
|
||||||
callback(i, t, latents)
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
# remove extra margins
|
||||||
|
latents = latents[
|
||||||
|
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
|
||||||
|
]
|
||||||
|
|
||||||
latents = 1 / 0.18215 * latents
|
latents = 1 / 0.18215 * latents
|
||||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
|
|
@ -0,0 +1,955 @@
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from math import ceil
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
|
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||||
|
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
|
||||||
|
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
||||||
|
StableDiffusionXLImg2ImgPipelineMixin,
|
||||||
|
)
|
||||||
|
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
|
||||||
|
|
||||||
|
from ...chain.tile import make_tile_mask
|
||||||
|
from ...constants import LATENT_FACTOR
|
||||||
|
from ...params import Size
|
||||||
|
from ..utils import (
|
||||||
|
expand_latents,
|
||||||
|
parse_regions,
|
||||||
|
random_seed,
|
||||||
|
repair_nan,
|
||||||
|
resize_latent_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_WINDOW = 64
|
||||||
|
DEFAULT_STRIDE = 16
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
window: int = DEFAULT_WINDOW,
|
||||||
|
stride: int = DEFAULT_STRIDE,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
|
self.window = window
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def set_window_size(self, window: int, stride: int):
|
||||||
|
self.window = window
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def get_views(
|
||||||
|
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
|
||||||
|
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
|
||||||
|
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||||
|
panorama_height /= 8
|
||||||
|
panorama_width /= 8
|
||||||
|
|
||||||
|
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
|
||||||
|
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
|
||||||
|
|
||||||
|
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
||||||
|
logger.debug(
|
||||||
|
"panorama generated %s views, %s by %s blocks",
|
||||||
|
total_num_blocks,
|
||||||
|
num_blocks_height,
|
||||||
|
num_blocks_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
views = []
|
||||||
|
for i in range(total_num_blocks):
|
||||||
|
h_start = int((i // num_blocks_width) * stride)
|
||||||
|
h_end = h_start + window_size
|
||||||
|
w_start = int((i % num_blocks_width) * stride)
|
||||||
|
w_end = w_start + window_size
|
||||||
|
views.append((h_start, h_end, w_start, w_end))
|
||||||
|
|
||||||
|
return (views, (h_end * 8, w_end * 8))
|
||||||
|
|
||||||
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
|
def prepare_latents_img2img(
|
||||||
|
self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None
|
||||||
|
):
|
||||||
|
batch_size = batch_size * num_images_per_prompt
|
||||||
|
|
||||||
|
if image.shape[1] == 4:
|
||||||
|
init_latents = image
|
||||||
|
else:
|
||||||
|
init_latents = self.vae_encoder(sample=image)[
|
||||||
|
0
|
||||||
|
] * self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
|
|
||||||
|
if (
|
||||||
|
batch_size > init_latents.shape[0]
|
||||||
|
and batch_size % init_latents.shape[0] == 0
|
||||||
|
):
|
||||||
|
# expand init_latents for batch_size
|
||||||
|
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||||
|
init_latents = np.concatenate(
|
||||||
|
[init_latents] * additional_image_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
batch_size > init_latents.shape[0]
|
||||||
|
and batch_size % init_latents.shape[0] != 0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
init_latents = np.concatenate([init_latents], axis=0)
|
||||||
|
|
||||||
|
# add noise to latents using the timesteps
|
||||||
|
noise = generator.randn(*init_latents.shape).astype(dtype)
|
||||||
|
init_latents = self.scheduler.add_noise(
|
||||||
|
torch.from_numpy(init_latents),
|
||||||
|
torch.from_numpy(noise),
|
||||||
|
torch.from_numpy(timestep),
|
||||||
|
)
|
||||||
|
return init_latents.numpy()
|
||||||
|
|
||||||
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
|
def prepare_latents_text2img(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
shape = (
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height // self.vae_scale_factor,
|
||||||
|
width // self.vae_scale_factor,
|
||||||
|
)
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = generator.randn(*shape).astype(dtype)
|
||||||
|
elif latents.shape != shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||||
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
|
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||||
|
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||||
|
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||||
|
# and should be between [0, 1]
|
||||||
|
|
||||||
|
extra_step_kwargs = {}
|
||||||
|
|
||||||
|
accepts_eta = "eta" in set(
|
||||||
|
inspect.signature(self.scheduler.step).parameters.keys()
|
||||||
|
)
|
||||||
|
if accepts_eta:
|
||||||
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
return extra_step_kwargs
|
||||||
|
|
||||||
|
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
|
||||||
|
def text2img(
|
||||||
|
self,
|
||||||
|
prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_scale: float = 5.0,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
eta: float = 0.0,
|
||||||
|
generator: Optional[np.random.RandomState] = None,
|
||||||
|
latents: Optional[np.ndarray] = None,
|
||||||
|
prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
output_type: str = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||||
|
callback_steps: int = 1,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
original_size: Optional[Tuple[int, int]] = None,
|
||||||
|
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||||
|
target_size: Optional[Tuple[int, int]] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
height (`Optional[int]`, defaults to None):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`Optional[int]`, defaults to None):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_inference_steps (`int`, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
guidance_scale (`float`, defaults to 5):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
negative_prompt (`Optional[Union[str, list]]`):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
|
||||||
|
is less than `1`).
|
||||||
|
num_images_per_prompt (`int`, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
eta (`float`, defaults to 0.0):
|
||||||
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||||
|
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||||
|
generator (`Optional[np.random.RandomState]`, defaults to `None`)::
|
||||||
|
A np.random.RandomState to make generation deterministic.
|
||||||
|
latents (`Optional[np.ndarray]`, defaults to `None`):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will ge generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
output_type (`str`, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
|
||||||
|
plain tuple.
|
||||||
|
callback (Optional[Callable], defaults to `None`):
|
||||||
|
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||||
|
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||||
|
callback_steps (`int`, defaults to 1):
|
||||||
|
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||||
|
called at every step.
|
||||||
|
guidance_rescale (`float`, defaults to 0.7):
|
||||||
|
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||||
|
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||||
|
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||||
|
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||||
|
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||||
|
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
|
(nsfw) content, according to the `safety_checker`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 0. Default height and width to unet
|
||||||
|
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
|
||||||
|
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
|
||||||
|
|
||||||
|
original_size = original_size or (height, width)
|
||||||
|
target_size = target_size or (height, width)
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(
|
||||||
|
prompt,
|
||||||
|
1.0,
|
||||||
|
callback_steps,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if generator is None:
|
||||||
|
generator = np.random
|
||||||
|
|
||||||
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
|
# corresponds to doing no classifier free guidance.
|
||||||
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
prompt, regions = parse_regions(prompt)
|
||||||
|
|
||||||
|
# 3. Encode input prompt
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
pooled_prompt_embeds,
|
||||||
|
negative_pooled_prompt_embeds,
|
||||||
|
) = self._encode_prompt(
|
||||||
|
prompt,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||||
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3.b. Encode region prompts
|
||||||
|
region_embeds: List[np.ndarray] = []
|
||||||
|
add_region_embeds: List[np.ndarray] = []
|
||||||
|
|
||||||
|
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
|
||||||
|
if region_prompt.endswith("+"):
|
||||||
|
region_prompt = region_prompt[:-1] + " " + prompt
|
||||||
|
|
||||||
|
(
|
||||||
|
region_prompt_embeds,
|
||||||
|
region_negative_prompt_embeds,
|
||||||
|
region_pooled_prompt_embeds,
|
||||||
|
region_negative_pooled_prompt_embeds,
|
||||||
|
) = self._encode_prompt(
|
||||||
|
region_prompt,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
region_prompt_embeds = np.concatenate(
|
||||||
|
(region_negative_prompt_embeds, region_prompt_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_region_embeds.append(
|
||||||
|
np.concatenate(
|
||||||
|
(
|
||||||
|
region_negative_pooled_prompt_embeds,
|
||||||
|
region_pooled_prompt_embeds,
|
||||||
|
),
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
region_embeds.append(region_prompt_embeds)
|
||||||
|
|
||||||
|
# 4. Prepare timesteps
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
latents = self.prepare_latents_text2img(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
self.unet.config.get("in_channels", 4),
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Prepare extra step kwargs
|
||||||
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
|
# 7. Prepare added time ids & embeddings
|
||||||
|
add_text_embeds = pooled_prompt_embeds
|
||||||
|
add_time_ids = (original_size + crops_coords_top_left + target_size,)
|
||||||
|
add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
prompt_embeds = np.concatenate(
|
||||||
|
(negative_prompt_embeds, prompt_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_text_embeds = np.concatenate(
|
||||||
|
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
||||||
|
|
||||||
|
add_time_ids = np.repeat(
|
||||||
|
add_time_ids, batch_size * num_images_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adapted from diffusers to extend it for other runtimes than ORT
|
||||||
|
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
||||||
|
|
||||||
|
# 8. Panorama additions
|
||||||
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
# adjust latents
|
||||||
|
latents = expand_latents(
|
||||||
|
latents,
|
||||||
|
random_seed(generator),
|
||||||
|
Size(resize[1], resize[0]),
|
||||||
|
sigma=self.scheduler.init_noise_sigma,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Denoising loop
|
||||||
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
last = i == (len(timesteps) - 1)
|
||||||
|
count.fill(0)
|
||||||
|
value.fill(0)
|
||||||
|
|
||||||
|
for h_start, h_end, w_start, w_end in views:
|
||||||
|
# get the latents corresponding to the current view coordinates
|
||||||
|
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = (
|
||||||
|
np.concatenate([latents_for_view] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_view
|
||||||
|
)
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_model_input), t
|
||||||
|
)
|
||||||
|
latent_model_input = latent_model_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
noise_pred = self.unet(
|
||||||
|
sample=latent_model_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
text_embeds=add_text_embeds,
|
||||||
|
time_ids=add_time_ids,
|
||||||
|
)
|
||||||
|
noise_pred = noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
|
)
|
||||||
|
if guidance_rescale > 0.0:
|
||||||
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
noise_pred = rescale_noise_cfg(
|
||||||
|
noise_pred,
|
||||||
|
noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
scheduler_output = self.scheduler.step(
|
||||||
|
torch.from_numpy(noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_view),
|
||||||
|
**extra_step_kwargs,
|
||||||
|
)
|
||||||
|
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
|
if not last:
|
||||||
|
for r, region in enumerate(regions):
|
||||||
|
top, left, bottom, right, weight, feather, prompt = region
|
||||||
|
logger.debug(
|
||||||
|
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||||
|
top,
|
||||||
|
left,
|
||||||
|
bottom,
|
||||||
|
right,
|
||||||
|
weight,
|
||||||
|
feather,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert coordinates to latent space
|
||||||
|
h_start = top // LATENT_FACTOR
|
||||||
|
h_end = bottom // LATENT_FACTOR
|
||||||
|
w_start = left // LATENT_FACTOR
|
||||||
|
w_end = right // LATENT_FACTOR
|
||||||
|
|
||||||
|
# get the latents corresponding to the current view coordinates
|
||||||
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
logger.trace(
|
||||||
|
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
||||||
|
h_start,
|
||||||
|
h_end,
|
||||||
|
w_start,
|
||||||
|
w_end,
|
||||||
|
latents_for_region.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_region_input = (
|
||||||
|
np.concatenate([latents_for_region] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_region
|
||||||
|
)
|
||||||
|
latent_region_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_region_input), t
|
||||||
|
)
|
||||||
|
latent_region_input = latent_region_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
region_noise_pred = self.unet(
|
||||||
|
sample=latent_region_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=region_embeds[r],
|
||||||
|
text_embeds=add_region_embeds[r],
|
||||||
|
time_ids=add_time_ids,
|
||||||
|
)
|
||||||
|
region_noise_pred = region_noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||||
|
region_noise_pred, 2
|
||||||
|
)
|
||||||
|
region_noise_pred = (
|
||||||
|
region_noise_pred_uncond
|
||||||
|
+ guidance_scale
|
||||||
|
* (region_noise_pred_text - region_noise_pred_uncond)
|
||||||
|
)
|
||||||
|
if guidance_rescale > 0.0:
|
||||||
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
region_noise_pred = rescale_noise_cfg(
|
||||||
|
region_noise_pred,
|
||||||
|
region_noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
scheduler_output = self.scheduler.step(
|
||||||
|
torch.from_numpy(region_noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_region),
|
||||||
|
**extra_step_kwargs,
|
||||||
|
)
|
||||||
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
if feather[0] > 0.0:
|
||||||
|
mask = make_tile_mask(
|
||||||
|
(h_end - h_start, w_end - w_start),
|
||||||
|
(h_end - h_start, w_end - w_start),
|
||||||
|
feather[0],
|
||||||
|
feather[1],
|
||||||
|
)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
mask = np.repeat(mask, 4, axis=0)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
else:
|
||||||
|
mask = 1
|
||||||
|
|
||||||
|
if weight >= 100.0:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] = (
|
||||||
|
latents_region_denoised * mask
|
||||||
|
)
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||||
|
else:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||||
|
latents_region_denoised * weight * mask
|
||||||
|
)
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||||
|
|
||||||
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
latents = repair_nan(latents)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or (
|
||||||
|
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||||
|
):
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
# remove extra margins
|
||||||
|
latents = latents[
|
||||||
|
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
|
||||||
|
]
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
else:
|
||||||
|
latents = np.clip(latents, -4, +4)
|
||||||
|
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
image = np.concatenate(
|
||||||
|
[
|
||||||
|
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
|
||||||
|
for i in range(latents.shape[0])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image = self.watermark.apply_watermark(image)
|
||||||
|
|
||||||
|
# TODO: add image_processor
|
||||||
|
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))
|
||||||
|
|
||||||
|
if output_type == "pil":
|
||||||
|
image = self.numpy_to_pil(image)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return StableDiffusionXLPipelineOutput(images=image)
|
||||||
|
|
||||||
|
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
|
||||||
|
def img2img(
|
||||||
|
self,
|
||||||
|
prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
image: Union[np.ndarray, PIL.Image.Image] = None,
|
||||||
|
strength: float = 0.3,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_scale: float = 5.0,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
eta: float = 0.0,
|
||||||
|
generator: Optional[np.random.RandomState] = None,
|
||||||
|
latents: Optional[np.ndarray] = None,
|
||||||
|
prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
|
output_type: str = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||||
|
callback_steps: int = 1,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
original_size: Optional[Tuple[int, int]] = None,
|
||||||
|
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||||
|
target_size: Optional[Tuple[int, int]] = None,
|
||||||
|
aesthetic_score: float = 6.0,
|
||||||
|
negative_aesthetic_score: float = 2.5,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`Optional[Union[str, List[str]]]`, defaults to None):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
image (`Union[np.ndarray, PIL.Image.Image]`):
|
||||||
|
`Image`, or tensor representing an image batch which will be upscaled.
|
||||||
|
strength (`float`, defaults to 0.8):
|
||||||
|
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||||
|
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||||
|
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||||
|
be maximum and the denoising process will run for the full number of iterations specified in
|
||||||
|
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||||
|
num_inference_steps (`int`, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
guidance_scale (`float`, defaults to 5):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
negative_prompt (`Optional[Union[str, list]]`):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
|
||||||
|
is less than `1`).
|
||||||
|
num_images_per_prompt (`int`, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
eta (`float`, defaults to 0.0):
|
||||||
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||||
|
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||||
|
generator (`Optional[np.random.RandomState]`, defaults to `None`)::
|
||||||
|
A np.random.RandomState to make generation deterministic.
|
||||||
|
latents (`Optional[np.ndarray]`, defaults to `None`):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will ge generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
output_type (`str`, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
|
||||||
|
plain tuple.
|
||||||
|
callback (Optional[Callable], defaults to `None`):
|
||||||
|
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||||
|
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||||
|
callback_steps (`int`, defaults to 1):
|
||||||
|
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||||
|
called at every step.
|
||||||
|
guidance_rescale (`float`, defaults to 0.7):
|
||||||
|
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||||
|
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||||
|
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||||
|
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||||
|
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||||
|
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
|
(nsfw) content, according to the `safety_checker`.
|
||||||
|
"""
|
||||||
|
# 0. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(
|
||||||
|
prompt,
|
||||||
|
strength,
|
||||||
|
callback_steps,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Define call parameters
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if generator is None:
|
||||||
|
generator = np.random
|
||||||
|
|
||||||
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
|
# corresponds to doing no classifier free guidance.
|
||||||
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
# 2. Encode input prompt
|
||||||
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
pooled_prompt_embeds,
|
||||||
|
negative_pooled_prompt_embeds,
|
||||||
|
) = self._encode_prompt(
|
||||||
|
prompt,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||||
|
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Preprocess image
|
||||||
|
processor = VaeImageProcessor()
|
||||||
|
image = processor.preprocess(image).cpu().numpy()
|
||||||
|
|
||||||
|
# 4. Prepare timesteps
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
timesteps, num_inference_steps = self.get_timesteps(
|
||||||
|
num_inference_steps, strength
|
||||||
|
)
|
||||||
|
latent_timestep = np.repeat(
|
||||||
|
timesteps[:1], batch_size * num_images_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)
|
||||||
|
|
||||||
|
latents_dtype = prompt_embeds.dtype
|
||||||
|
image = image.astype(latents_dtype)
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
latents = self.prepare_latents_img2img(
|
||||||
|
image,
|
||||||
|
latent_timestep,
|
||||||
|
batch_size,
|
||||||
|
num_images_per_prompt,
|
||||||
|
latents_dtype,
|
||||||
|
generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Prepare extra step kwargs
|
||||||
|
extra_step_kwargs = {}
|
||||||
|
accepts_eta = "eta" in set(
|
||||||
|
inspect.signature(self.scheduler.step).parameters.keys()
|
||||||
|
)
|
||||||
|
if accepts_eta:
|
||||||
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
height, width = latents.shape[-2:]
|
||||||
|
height = height * self.vae_scale_factor
|
||||||
|
width = width * self.vae_scale_factor
|
||||||
|
original_size = original_size or (height, width)
|
||||||
|
target_size = target_size or (height, width)
|
||||||
|
|
||||||
|
# 8. Prepare added time ids & embeddings
|
||||||
|
add_text_embeds = pooled_prompt_embeds
|
||||||
|
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||||
|
original_size,
|
||||||
|
crops_coords_top_left,
|
||||||
|
target_size,
|
||||||
|
aesthetic_score,
|
||||||
|
negative_aesthetic_score,
|
||||||
|
dtype=prompt_embeds.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
prompt_embeds = np.concatenate(
|
||||||
|
(negative_prompt_embeds, prompt_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_text_embeds = np.concatenate(
|
||||||
|
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
|
||||||
|
)
|
||||||
|
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
||||||
|
add_time_ids = np.repeat(
|
||||||
|
add_time_ids, batch_size * num_images_per_prompt, axis=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Panorama additions
|
||||||
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
latents = expand_latents(
|
||||||
|
latents,
|
||||||
|
random_seed(generator),
|
||||||
|
Size(resize[1], resize[0]),
|
||||||
|
sigma=self.scheduler.init_noise_sigma,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Denoising loop
|
||||||
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
count.fill(0)
|
||||||
|
value.fill(0)
|
||||||
|
|
||||||
|
for h_start, h_end, w_start, w_end in views:
|
||||||
|
# get the latents corresponding to the current view coordinates
|
||||||
|
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = (
|
||||||
|
np.concatenate([latents_for_view] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_view
|
||||||
|
)
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_model_input), t
|
||||||
|
)
|
||||||
|
latent_model_input = latent_model_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
noise_pred = self.unet(
|
||||||
|
sample=latent_model_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
text_embeds=add_text_embeds,
|
||||||
|
time_ids=add_time_ids,
|
||||||
|
)
|
||||||
|
noise_pred = noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
|
)
|
||||||
|
if guidance_rescale > 0.0:
|
||||||
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
noise_pred = rescale_noise_cfg(
|
||||||
|
noise_pred,
|
||||||
|
noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
scheduler_output = self.scheduler.step(
|
||||||
|
torch.from_numpy(noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_view),
|
||||||
|
**extra_step_kwargs,
|
||||||
|
)
|
||||||
|
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or (
|
||||||
|
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||||
|
):
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
# remove extra margins
|
||||||
|
latents = latents[
|
||||||
|
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
|
||||||
|
]
|
||||||
|
|
||||||
|
if output_type == "latent":
|
||||||
|
image = latents
|
||||||
|
else:
|
||||||
|
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
|
||||||
|
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||||
|
image = np.concatenate(
|
||||||
|
[
|
||||||
|
self.vae_decoder(latent_sample=latents[i : i + 1])[0]
|
||||||
|
for i in range(latents.shape[0])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image = self.watermark.apply_watermark(image)
|
||||||
|
|
||||||
|
# TODO: add image_processor
|
||||||
|
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))
|
||||||
|
|
||||||
|
if output_type == "pil":
|
||||||
|
image = self.numpy_to_pil(image)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
return StableDiffusionXLPipelineOutput(images=image)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if "image" in kwargs or (
|
||||||
|
len(args) > 1
|
||||||
|
and (
|
||||||
|
isinstance(args[1], np.ndarray) or isinstance(args[1], PIL.Image.Image)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
logger.debug("running img2img panorama XL pipeline")
|
||||||
|
return self.img2img(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
logger.debug("running txt2img panorama XL pipeline")
|
||||||
|
return self.text2img(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ORTStableDiffusionXLPanoramaPipeline(
|
||||||
|
ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin
|
||||||
|
):
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
|
|
@ -32,7 +32,7 @@ except ImportError:
|
||||||
}
|
}
|
||||||
|
|
||||||
from diffusers import OnnxRuntimeModel
|
from diffusers import OnnxRuntimeModel
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.schedulers import (
|
from diffusers.schedulers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
|
|
|
@ -1,63 +1,25 @@
|
||||||
###
|
|
||||||
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
|
|
||||||
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
|
|
||||||
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
|
|
||||||
# See also: https://github.com/huggingface/diffusers/pull/2158
|
|
||||||
###
|
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, List
|
||||||
|
|
||||||
import numpy as np
|
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
|
||||||
import PIL
|
from diffusers.pipelines.stable_diffusion import (
|
||||||
import torch
|
OnnxStableDiffusionUpscalePipeline as BasePipeline,
|
||||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
)
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
|
|
||||||
from diffusers.schedulers import DDPMScheduler
|
from diffusers.schedulers import DDPMScheduler
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
NUM_LATENT_CHANNELS = 4
|
|
||||||
NUM_UNET_INPUT_CHANNELS = 7
|
|
||||||
|
|
||||||
ORT_TO_PT_TYPE = {
|
|
||||||
"float16": torch.float16,
|
|
||||||
"float32": torch.float32,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(image):
|
|
||||||
if isinstance(image, torch.Tensor):
|
|
||||||
return image
|
|
||||||
elif isinstance(image, PIL.Image.Image):
|
|
||||||
image = [image]
|
|
||||||
|
|
||||||
if isinstance(image[0], PIL.Image.Image):
|
|
||||||
w, h = image[0].size
|
|
||||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
|
|
||||||
|
|
||||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
|
||||||
image = np.concatenate(image, axis=0)
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
image = image.transpose(0, 3, 1, 2)
|
|
||||||
image = 2.0 * image - 1.0
|
|
||||||
image = torch.from_numpy(image)
|
|
||||||
elif isinstance(image[0], torch.Tensor):
|
|
||||||
image = torch.cat(image, dim=0)
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
class FakeConfig:
|
class FakeConfig:
|
||||||
|
block_out_channels: List[int]
|
||||||
scaling_factor: float
|
scaling_factor: float
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
self.block_out_channels = [128, 256, 512]
|
||||||
self.scaling_factor = 0.08333
|
self.scaling_factor = 0.08333
|
||||||
|
|
||||||
|
|
||||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
class OnnxStableDiffusionUpscalePipeline(BasePipeline):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vae: OnnxRuntimeModel,
|
vae: OnnxRuntimeModel,
|
||||||
|
@ -80,260 +42,3 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||||
scheduler,
|
scheduler,
|
||||||
max_noise_level=max_noise_level,
|
max_noise_level=max_noise_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
prompt: Union[str, List[str]],
|
|
||||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
|
|
||||||
num_inference_steps: int = 75,
|
|
||||||
guidance_scale: float = 9.0,
|
|
||||||
noise_level: int = 20,
|
|
||||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
|
||||||
num_images_per_prompt: Optional[int] = 1,
|
|
||||||
eta: float = 0.0,
|
|
||||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
|
||||||
latents: Optional[torch.FloatTensor] = None,
|
|
||||||
output_type: Optional[str] = "pil",
|
|
||||||
return_dict: bool = True,
|
|
||||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
|
||||||
callback_steps: Optional[int] = 1,
|
|
||||||
):
|
|
||||||
# 1. Check inputs
|
|
||||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
|
||||||
|
|
||||||
# 2. Define call parameters
|
|
||||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
|
||||||
device = self._execution_device
|
|
||||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
|
||||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
|
||||||
# corresponds to doing no classifier free guidance.
|
|
||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
|
||||||
|
|
||||||
# 3. Encode input prompt
|
|
||||||
text_embeddings = self._encode_prompt(
|
|
||||||
prompt,
|
|
||||||
# device, device only needed for Torch pipelines
|
|
||||||
num_images_per_prompt,
|
|
||||||
do_classifier_free_guidance,
|
|
||||||
negative_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
|
||||||
|
|
||||||
# 4. Preprocess image
|
|
||||||
image = preprocess(image)
|
|
||||||
image = image.cpu()
|
|
||||||
|
|
||||||
# 5. set timesteps
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
|
||||||
timesteps = self.scheduler.timesteps
|
|
||||||
|
|
||||||
# 5. Add noise to image
|
|
||||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
|
||||||
noise = torch.randn(
|
|
||||||
image.shape, generator=generator, device=device, dtype=latents_dtype
|
|
||||||
)
|
|
||||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
|
||||||
|
|
||||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
|
||||||
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
|
|
||||||
noise_level = np.concatenate([noise_level] * image.shape[0])
|
|
||||||
|
|
||||||
# 6. Prepare latent variables
|
|
||||||
height, width = image.shape[2:]
|
|
||||||
latents = self.prepare_latents(
|
|
||||||
batch_size * num_images_per_prompt,
|
|
||||||
NUM_LATENT_CHANNELS,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
latents_dtype,
|
|
||||||
device,
|
|
||||||
generator,
|
|
||||||
latents,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 7. Check that sizes of image and latents match
|
|
||||||
num_channels_image = image.shape[1]
|
|
||||||
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
|
|
||||||
raise ValueError(
|
|
||||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
|
||||||
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
|
|
||||||
f" `num_channels_image`: {num_channels_image} "
|
|
||||||
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
|
|
||||||
" `pipeline.unet` or your `image` input."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
|
||||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
|
||||||
|
|
||||||
timestep_dtype = next(
|
|
||||||
(
|
|
||||||
input.type
|
|
||||||
for input in self.unet.model.get_inputs()
|
|
||||||
if input.name == "timestep"
|
|
||||||
),
|
|
||||||
"tensor(float)",
|
|
||||||
)
|
|
||||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
|
||||||
|
|
||||||
# 9. Denoising loop
|
|
||||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
|
||||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
|
||||||
for i, t in enumerate(timesteps):
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_model_input = (
|
|
||||||
np.concatenate([latents] * 2)
|
|
||||||
if do_classifier_free_guidance
|
|
||||||
else latents
|
|
||||||
)
|
|
||||||
|
|
||||||
# concat latents, mask, masked_image_latents in the channel dimension
|
|
||||||
latent_model_input = self.scheduler.scale_model_input(
|
|
||||||
latent_model_input, t
|
|
||||||
)
|
|
||||||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
|
||||||
|
|
||||||
# timestep to tensor
|
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
noise_pred = self.unet(
|
|
||||||
sample=latent_model_input,
|
|
||||||
timestep=timestep,
|
|
||||||
encoder_hidden_states=text_embeddings,
|
|
||||||
class_labels=noise_level.astype(np.int64),
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
|
||||||
noise_pred_text - noise_pred_uncond
|
|
||||||
)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
latents = self.scheduler.step(
|
|
||||||
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
|
|
||||||
).prev_sample
|
|
||||||
|
|
||||||
# call the callback, if provided
|
|
||||||
if i == len(timesteps) - 1 or (
|
|
||||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
|
||||||
):
|
|
||||||
progress_bar.update()
|
|
||||||
if callback is not None and i % callback_steps == 0:
|
|
||||||
callback(i, t, latents)
|
|
||||||
|
|
||||||
# 10. Post-processing
|
|
||||||
image = self.decode_latents(latents.float())
|
|
||||||
|
|
||||||
# 11. Convert to PIL
|
|
||||||
if output_type == "pil":
|
|
||||||
image = self.numpy_to_pil(image)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (image,)
|
|
||||||
|
|
||||||
return ImagePipelineOutput(images=image)
|
|
||||||
|
|
||||||
def decode_latents(self, latents):
|
|
||||||
latents = 1 / 0.08333 * latents
|
|
||||||
image = self.vae(latent_sample=latents)[0]
|
|
||||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
|
||||||
image = image.transpose((0, 2, 3, 1))
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _encode_prompt(
|
|
||||||
self,
|
|
||||||
prompt,
|
|
||||||
device,
|
|
||||||
num_images_per_prompt,
|
|
||||||
do_classifier_free_guidance,
|
|
||||||
negative_prompt,
|
|
||||||
):
|
|
||||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
|
||||||
|
|
||||||
text_inputs = self.tokenizer(
|
|
||||||
prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=self.tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
untruncated_ids = self.tokenizer(
|
|
||||||
prompt, padding="longest", return_tensors="pt"
|
|
||||||
).input_ids
|
|
||||||
|
|
||||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
|
||||||
text_input_ids, untruncated_ids
|
|
||||||
):
|
|
||||||
removed_text = self.tokenizer.batch_decode(
|
|
||||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
|
||||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# no positional arguments to text_encoder
|
|
||||||
text_embeddings = self.text_encoder(
|
|
||||||
input_ids=text_input_ids.int().to(device),
|
|
||||||
)
|
|
||||||
text_embeddings = text_embeddings[0]
|
|
||||||
|
|
||||||
bs_embed, seq_len, _ = text_embeddings.shape
|
|
||||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
|
||||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
|
|
||||||
text_embeddings = text_embeddings.reshape(
|
|
||||||
bs_embed * num_images_per_prompt, seq_len, -1
|
|
||||||
)
|
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
uncond_tokens: List[str]
|
|
||||||
if negative_prompt is None:
|
|
||||||
uncond_tokens = [""] * batch_size
|
|
||||||
elif type(prompt) is not type(negative_prompt):
|
|
||||||
raise TypeError(
|
|
||||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
|
||||||
f" {type(prompt)}."
|
|
||||||
)
|
|
||||||
elif isinstance(negative_prompt, str):
|
|
||||||
uncond_tokens = [negative_prompt]
|
|
||||||
elif batch_size != len(negative_prompt):
|
|
||||||
raise ValueError(
|
|
||||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
|
||||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
|
||||||
" the batch size of `prompt`."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
uncond_tokens = negative_prompt
|
|
||||||
|
|
||||||
max_length = text_input_ids.shape[-1]
|
|
||||||
uncond_input = self.tokenizer(
|
|
||||||
uncond_tokens,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
uncond_embeddings = self.text_encoder(
|
|
||||||
input_ids=uncond_input.input_ids.int().to(device),
|
|
||||||
)
|
|
||||||
uncond_embeddings = uncond_embeddings[0]
|
|
||||||
|
|
||||||
seq_len = uncond_embeddings.shape[1]
|
|
||||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
|
||||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
|
||||||
uncond_embeddings = uncond_embeddings.reshape(
|
|
||||||
batch_size * num_images_per_prompt, seq_len, -1
|
|
||||||
)
|
|
||||||
|
|
||||||
# For classifier free guidance, we need to do two forward passes.
|
|
||||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
|
||||||
# to avoid doing two forward passes
|
|
||||||
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
|
||||||
|
|
||||||
return text_embeddings
|
|
||||||
|
|
|
@ -4,15 +4,16 @@ from typing import Any, List, Optional
|
||||||
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from onnx_web.chain.highres import stage_highres
|
|
||||||
|
|
||||||
from ..chain import (
|
from ..chain import (
|
||||||
|
BlendDenoiseStage,
|
||||||
BlendImg2ImgStage,
|
BlendImg2ImgStage,
|
||||||
BlendMaskStage,
|
BlendMaskStage,
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
SourceTxt2ImgStage,
|
SourceTxt2ImgStage,
|
||||||
UpscaleOutpaintStage,
|
UpscaleOutpaintStage,
|
||||||
)
|
)
|
||||||
|
from ..chain.highres import stage_highres
|
||||||
|
from ..chain.result import StageResult
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..image import expand_image
|
from ..image import expand_image
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
|
@ -33,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_tile(params: ImageParams, size: Size) -> int:
|
||||||
|
if params.is_panorama():
|
||||||
|
tile = max(params.unet_tile, size.width, size.height)
|
||||||
|
logger.debug("adjusting tile size for panorama to %s", tile)
|
||||||
|
return tile
|
||||||
|
|
||||||
|
return params.unet_tile
|
||||||
|
|
||||||
|
|
||||||
|
def get_highres_tile(
|
||||||
|
server: ServerContext, params: ImageParams, highres: HighresParams, tile: int
|
||||||
|
) -> int:
|
||||||
|
if params.is_panorama() and server.has_feature("panorama-highres"):
|
||||||
|
return tile * highres.scale
|
||||||
|
|
||||||
|
return params.unet_tile
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -43,10 +62,7 @@ def run_txt2img_pipeline(
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
# if using panorama, the pipeline will tile itself (views)
|
# if using panorama, the pipeline will tile itself (views)
|
||||||
if params.is_panorama() or params.is_xl():
|
tile_size = get_base_tile(params, size)
|
||||||
tile_size = max(params.tiles, size.width, size.height)
|
|
||||||
else:
|
|
||||||
tile_size = params.tiles
|
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
|
@ -57,15 +73,21 @@ def run_txt2img_pipeline(
|
||||||
),
|
),
|
||||||
size=size,
|
size=size,
|
||||||
prompt_index=0,
|
prompt_index=0,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
stage = StageParams(tile_size=params.tiles)
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
|
if params.is_panorama():
|
||||||
|
chain.stage(
|
||||||
|
BlendDenoiseStage(),
|
||||||
|
StageParams(tile_size=highres_size),
|
||||||
|
)
|
||||||
|
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
|
@ -73,7 +95,7 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(outscale=highres.scale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -83,7 +105,7 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
|
@ -92,11 +114,14 @@ def run_txt2img_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||||
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
logger.trace("saving output image %s: %s", output, image.size)
|
||||||
dest = save_image(
|
dest = save_image(
|
||||||
server,
|
server,
|
||||||
output,
|
output,
|
||||||
|
@ -136,23 +161,26 @@ def run_img2img_pipeline(
|
||||||
source = f(server, source)
|
source = f(server, source)
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
|
tile_size = get_base_tile(params, Size(*source.size))
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(
|
|
||||||
tile_size=params.tiles,
|
|
||||||
)
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendImg2ImgStage(),
|
BlendImg2ImgStage(),
|
||||||
stage,
|
StageParams(
|
||||||
|
tile_size=tile_size,
|
||||||
|
),
|
||||||
prompt_index=0,
|
prompt_index=0,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(
|
||||||
|
outscale=first_upscale.outscale,
|
||||||
|
tile_size=tile_size,
|
||||||
|
),
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -162,13 +190,16 @@ def run_img2img_pipeline(
|
||||||
for _i in range(params.loopback):
|
for _i in range(params.loopback):
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendImg2ImgStage(),
|
BlendImg2ImgStage(),
|
||||||
stage,
|
StageParams(
|
||||||
|
tile_size=tile_size,
|
||||||
|
),
|
||||||
strength=strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
# highres, if selected
|
# highres, if selected
|
||||||
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(tile_size=highres_size, outscale=highres.scale),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -178,7 +209,7 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(tile_size=tile_size, outscale=after_upscale.scale),
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -186,7 +217,9 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, [source], callback=progress)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult(images=[source]), callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
images.append(source)
|
images.append(source)
|
||||||
|
@ -235,7 +268,7 @@ def run_inpaint_pipeline(
|
||||||
full_res_inpaint_padding: float,
|
full_res_inpaint_padding: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("building inpaint pipeline")
|
logger.debug("building inpaint pipeline")
|
||||||
tile_size = params.tiles
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
|
@ -264,8 +297,12 @@ def run_inpaint_pipeline(
|
||||||
logger.debug("border zero: %s", border.isZero())
|
logger.debug("border zero: %s", border.isZero())
|
||||||
full_res_inpaint = full_res_inpaint and border.isZero()
|
full_res_inpaint = full_res_inpaint and border.isZero()
|
||||||
if full_res_inpaint:
|
if full_res_inpaint:
|
||||||
mask_left, mask_top, mask_right, mask_bottom = mask.getbbox()
|
bbox = mask.getbbox()
|
||||||
logger.debug("mask bbox: %s", mask.getbbox())
|
if bbox is None:
|
||||||
|
bbox = (0, 0, source.width, source.height)
|
||||||
|
|
||||||
|
logger.debug("mask bounding box: %s", bbox)
|
||||||
|
mask_left, mask_top, mask_right, mask_bottom = bbox
|
||||||
mask_width = mask_right - mask_left
|
mask_width = mask_right - mask_left
|
||||||
mask_height = mask_bottom - mask_top
|
mask_height = mask_bottom - mask_top
|
||||||
# ensure we have some padding around the mask when we do the inpaint (and that the region size is even)
|
# ensure we have some padding around the mask when we do the inpaint (and that the region size is even)
|
||||||
|
@ -322,16 +359,15 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order, tile_size=tile_size)
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
UpscaleOutpaintStage(),
|
UpscaleOutpaintStage(),
|
||||||
stage,
|
StageParams(tile_order=tile_order, tile_size=tile_size),
|
||||||
border=border,
|
border=border,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
fill_color=fill_color,
|
fill_color=fill_color,
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
overlap=params.overlap,
|
overlap=params.vae_overlap,
|
||||||
prompt_index=0,
|
prompt_index=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -339,15 +375,16 @@ def run_inpaint_pipeline(
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(outscale=highres.scale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -357,7 +394,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=after_upscale.outscale),
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -366,7 +403,14 @@ def run_inpaint_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, [source], callback=progress, latents=latents)
|
images = chain.run(
|
||||||
|
worker,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
StageResult(images=[source]),
|
||||||
|
callback=progress,
|
||||||
|
latents=latents,
|
||||||
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -409,21 +453,22 @@ def run_upscale_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline, no base stage for upscaling
|
# set up the chain pipeline, no base stage for upscaling
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_size=params.tiles)
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(outscale=highres.scale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -433,7 +478,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=after_upscale.outscale, tile_size=tile_size),
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -441,7 +486,9 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, [source], callback=progress)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult(images=[source]), callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -478,12 +525,18 @@ def run_blend_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
tile_size = get_base_tile(params, size)
|
||||||
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
|
|
||||||
|
chain.stage(
|
||||||
|
BlendMaskStage(),
|
||||||
|
StageParams(tile_size=tile_size),
|
||||||
|
stage_source=sources[1],
|
||||||
|
stage_mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=upscale.outscale),
|
||||||
params,
|
params,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -491,7 +544,9 @@ def run_blend_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, sources, callback=progress)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult(images=sources), callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||||
|
|
|
@ -3,23 +3,27 @@ from copy import deepcopy
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from re import Pattern, compile
|
from re import Pattern, compile
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
|
||||||
|
from ..constants import LATENT_CHANNELS, LATENT_FACTOR
|
||||||
from ..params import ImageParams, Size
|
from ..params import ImageParams, Size
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
LATENT_CHANNELS = 4
|
|
||||||
LATENT_FACTOR = 8
|
|
||||||
MAX_TOKENS_PER_GROUP = 77
|
MAX_TOKENS_PER_GROUP = 77
|
||||||
|
|
||||||
|
ANY_TOKEN = compile(r"\<([^\>]*)\>")
|
||||||
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
|
REGION_TOKEN = compile(
|
||||||
|
r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
|
||||||
|
)
|
||||||
|
RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
|
||||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
|
@ -84,8 +88,8 @@ def expand_prompt(
|
||||||
negative_prompt: Optional[str] = None,
|
negative_prompt: Optional[str] = None,
|
||||||
prompt_embeds: Optional[np.ndarray] = None,
|
prompt_embeds: Optional[np.ndarray] = None,
|
||||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||||
skip_clip_states: Optional[int] = 0,
|
skip_clip_states: int = 0,
|
||||||
) -> "np.NDArray":
|
) -> np.ndarray:
|
||||||
# self provides:
|
# self provides:
|
||||||
# tokenizer: CLIPTokenizer
|
# tokenizer: CLIPTokenizer
|
||||||
# encoder: OnnxRuntimeModel
|
# encoder: OnnxRuntimeModel
|
||||||
|
@ -140,6 +144,7 @@ def expand_prompt(
|
||||||
|
|
||||||
last_state, _pooled_output, *hidden_states = text_result
|
last_state, _pooled_output, *hidden_states = text_result
|
||||||
if skip_clip_states > 0:
|
if skip_clip_states > 0:
|
||||||
|
# TODO: why is this normalized?
|
||||||
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||||
norm_state = layer_norm(
|
norm_state = layer_norm(
|
||||||
torch.from_numpy(
|
torch.from_numpy(
|
||||||
|
@ -219,20 +224,25 @@ def expand_prompt(
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
|
||||||
|
name, weight = group
|
||||||
|
return (name, float(weight))
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_from_prompt(
|
def get_tokens_from_prompt(
|
||||||
prompt: str, pattern: Pattern
|
prompt: str,
|
||||||
|
pattern: Pattern,
|
||||||
|
parser=parse_float_group,
|
||||||
) -> Tuple[str, List[Tuple[str, float]]]:
|
) -> Tuple[str, List[Tuple[str, float]]]:
|
||||||
"""
|
|
||||||
TODO: replace with Arpeggio
|
|
||||||
"""
|
|
||||||
remaining_prompt = prompt
|
remaining_prompt = prompt
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
next_match = pattern.search(remaining_prompt)
|
next_match = pattern.search(remaining_prompt)
|
||||||
while next_match is not None:
|
while next_match is not None:
|
||||||
logger.debug("found token in prompt: %s", next_match)
|
logger.debug("found token in prompt: %s", next_match)
|
||||||
name, weight = next_match.groups()
|
group = next_match.groups()
|
||||||
tokens.append((name, float(weight)))
|
tokens.append(parser(group))
|
||||||
|
|
||||||
# remove this match and look for another
|
# remove this match and look for another
|
||||||
remaining_prompt = (
|
remaining_prompt = (
|
||||||
remaining_prompt[: next_match.start()]
|
remaining_prompt[: next_match.start()]
|
||||||
|
@ -251,6 +261,13 @@ def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]
|
||||||
return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
|
return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
|
def random_seed(generator=None) -> int:
|
||||||
|
if generator is None:
|
||||||
|
generator = np.random
|
||||||
|
|
||||||
|
return generator.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
|
|
||||||
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
From https://www.travelneil.com/stable-diffusion-updates.html.
|
From https://www.travelneil.com/stable-diffusion-updates.html.
|
||||||
|
@ -266,6 +283,25 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
|
def expand_latents(
|
||||||
|
latents: np.ndarray,
|
||||||
|
seed: int,
|
||||||
|
size: Size,
|
||||||
|
sigma: float = 1.0,
|
||||||
|
) -> np.ndarray:
|
||||||
|
batch, _channels, height, width = latents.shape
|
||||||
|
extra_latents = get_latents_from_seed(seed, size, batch=batch)
|
||||||
|
extra_latents[:, :, 0:height, 0:width] = latents
|
||||||
|
return extra_latents * np.float64(sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_latent_shape(
|
||||||
|
latents: np.ndarray,
|
||||||
|
size: Tuple[int, int],
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
return (latents.shape[0], latents.shape[1], *size)
|
||||||
|
|
||||||
|
|
||||||
def get_tile_latents(
|
def get_tile_latents(
|
||||||
full_latents: np.ndarray,
|
full_latents: np.ndarray,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
@ -290,14 +326,8 @@ def get_tile_latents(
|
||||||
|
|
||||||
tile_latents = full_latents[:, :, y:yt, x:xt]
|
tile_latents = full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
if tile_latents.shape != full_latents.shape and (
|
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
|
||||||
tile_latents.shape[2] < t or tile_latents.shape[3] < t
|
tile_latents = expand_latents(tile_latents, seed, size)
|
||||||
):
|
|
||||||
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
|
|
||||||
extra_latents[
|
|
||||||
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
|
|
||||||
] = tile_latents
|
|
||||||
tile_latents = extra_latents
|
|
||||||
|
|
||||||
return tile_latents
|
return tile_latents
|
||||||
|
|
||||||
|
@ -369,12 +399,15 @@ def encode_prompt(
|
||||||
num_images_per_prompt: int = 1,
|
num_images_per_prompt: int = 1,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[np.ndarray]:
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
TODO: does not work with SDXL, fix or turn into a pipeline patch
|
||||||
|
"""
|
||||||
return [
|
return [
|
||||||
pipe._encode_prompt(
|
pipe._encode_prompt(
|
||||||
prompt,
|
remove_tokens(prompt),
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
negative_prompt=neg_prompt,
|
negative_prompt=remove_tokens(neg_prompt),
|
||||||
)
|
)
|
||||||
for prompt, neg_prompt in prompt_pairs
|
for prompt, neg_prompt in prompt_pairs
|
||||||
]
|
]
|
||||||
|
@ -444,3 +477,71 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
||||||
return parts[min(slice, len(parts) - 1)]
|
return parts[min(slice, len(parts) - 1)]
|
||||||
else:
|
else:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
Region = Tuple[
|
||||||
|
int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_region_group(group: Tuple[str, ...]) -> Region:
|
||||||
|
top, left, bottom, right, weight, feather, prompt = group
|
||||||
|
|
||||||
|
# break down the feather section
|
||||||
|
feather_radius, *feather_edges = feather.split("_")
|
||||||
|
if len(feather_edges) == 0:
|
||||||
|
feather_edges = "TLBR"
|
||||||
|
else:
|
||||||
|
feather_edges = "".join(feather_edges)
|
||||||
|
|
||||||
|
return (
|
||||||
|
int(top),
|
||||||
|
int(left),
|
||||||
|
int(bottom),
|
||||||
|
int(right),
|
||||||
|
float(weight),
|
||||||
|
(
|
||||||
|
float(feather_radius),
|
||||||
|
(
|
||||||
|
"T" in feather_edges,
|
||||||
|
"L" in feather_edges,
|
||||||
|
"B" in feather_edges,
|
||||||
|
"R" in feather_edges,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
||||||
|
return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
|
||||||
|
|
||||||
|
|
||||||
|
Reseed = Tuple[int, int, int, int, int]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_reseed_group(group) -> Region:
|
||||||
|
top, left, bottom, right, seed = group
|
||||||
|
return (
|
||||||
|
int(top),
|
||||||
|
int(left),
|
||||||
|
int(bottom),
|
||||||
|
int(right),
|
||||||
|
int(seed),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
|
||||||
|
return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
|
||||||
|
|
||||||
|
|
||||||
|
def skip_group(group) -> Any:
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tokens(prompt: Optional[str]) -> Optional[str]:
|
||||||
|
if prompt is None:
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN, parser=skip_group)
|
||||||
|
return remainder
|
||||||
|
|
|
@ -12,6 +12,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
|
||||||
|
|
||||||
|
try:
|
||||||
|
from diffusers import LCMScheduler
|
||||||
|
except ImportError:
|
||||||
|
from ..diffusers.stub_scheduler import StubScheduler as LCMScheduler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from diffusers import UniPCMultistepScheduler
|
from diffusers import UniPCMultistepScheduler
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -8,7 +8,7 @@ def mask_filter_none(
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = dims
|
width, height = dims
|
||||||
|
|
||||||
noise = Image.new("RGB", (width, height), fill)
|
noise = Image.new(mask.mode, (width, height), fill)
|
||||||
noise.paste(mask, origin)
|
noise.paste(mask, origin)
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
|
@ -17,21 +17,21 @@ def noise_source_fill_edge(
|
||||||
"""
|
"""
|
||||||
width, height = dims
|
width, height = dims
|
||||||
|
|
||||||
noise = Image.new("RGB", (width, height), fill)
|
noise = Image.new(source.mode, (width, height), fill)
|
||||||
noise.paste(source, origin)
|
noise.paste(source, origin)
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
def noise_source_fill_mask(
|
def noise_source_fill_mask(
|
||||||
_source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
|
source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
Fill the whole canvas, no source or noise.
|
Fill the whole canvas, no source or noise.
|
||||||
"""
|
"""
|
||||||
width, height = dims
|
width, height = dims
|
||||||
|
|
||||||
noise = Image.new("RGB", (width, height), fill)
|
noise = Image.new(source.mode, (width, height), fill)
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ def noise_source_gaussian(
|
||||||
|
|
||||||
|
|
||||||
def noise_source_uniform(
|
def noise_source_uniform(
|
||||||
_source: Image.Image, dims: Point, _origin: Point, **kw
|
source: Image.Image, dims: Point, _origin: Point, **kw
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = dims
|
width, height = dims
|
||||||
size = width * height
|
size = width * height
|
||||||
|
@ -61,6 +61,7 @@ def noise_source_uniform(
|
||||||
noise_g = random.uniform(0, 256, size=size)
|
noise_g = random.uniform(0, 256, size=size)
|
||||||
noise_b = random.uniform(0, 256, size=size)
|
noise_b = random.uniform(0, 256, size=size)
|
||||||
|
|
||||||
|
# needs to be RGB for pixel manipulation
|
||||||
noise = Image.new("RGB", (width, height))
|
noise = Image.new("RGB", (width, height))
|
||||||
|
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
|
@ -68,11 +69,11 @@ def noise_source_uniform(
|
||||||
i = get_pixel_index(x, y, width)
|
i = get_pixel_index(x, y, width)
|
||||||
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||||
|
|
||||||
return noise
|
return noise.convert(source.mode)
|
||||||
|
|
||||||
|
|
||||||
def noise_source_normal(
|
def noise_source_normal(
|
||||||
_source: Image.Image, dims: Point, _origin: Point, **kw
|
source: Image.Image, dims: Point, _origin: Point, **kw
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = dims
|
width, height = dims
|
||||||
size = width * height
|
size = width * height
|
||||||
|
@ -81,6 +82,7 @@ def noise_source_normal(
|
||||||
noise_g = random.normal(128, 32, size=size)
|
noise_g = random.normal(128, 32, size=size)
|
||||||
noise_b = random.normal(128, 32, size=size)
|
noise_b = random.normal(128, 32, size=size)
|
||||||
|
|
||||||
|
# needs to be RGB for pixel manipulation
|
||||||
noise = Image.new("RGB", (width, height))
|
noise = Image.new("RGB", (width, height))
|
||||||
|
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
|
@ -88,13 +90,13 @@ def noise_source_normal(
|
||||||
i = get_pixel_index(x, y, width)
|
i = get_pixel_index(x, y, width)
|
||||||
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||||
|
|
||||||
return noise
|
return noise.convert(source.mode)
|
||||||
|
|
||||||
|
|
||||||
def noise_source_histogram(
|
def noise_source_histogram(
|
||||||
source: Image.Image, dims: Point, _origin: Point, **kw
|
source: Image.Image, dims: Point, _origin: Point, **kw
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
r, g, b = source.split()
|
r, g, b, *_a = source.split()
|
||||||
width, height = dims
|
width, height = dims
|
||||||
size = width * height
|
size = width * height
|
||||||
|
|
||||||
|
@ -112,6 +114,7 @@ def noise_source_histogram(
|
||||||
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
|
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# needs to be RGB for pixel manipulation
|
||||||
noise = Image.new("RGB", (width, height))
|
noise = Image.new("RGB", (width, height))
|
||||||
|
|
||||||
for x in range(width):
|
for x in range(width):
|
||||||
|
@ -119,4 +122,4 @@ def noise_source_histogram(
|
||||||
i = get_pixel_index(x, y, width)
|
i = get_pixel_index(x, y, width)
|
||||||
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
|
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
|
||||||
|
|
||||||
return noise
|
return noise.convert(source.mode)
|
||||||
|
|
|
@ -47,7 +47,7 @@ def source_filter_noise(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
strength: float = 0.5,
|
strength: float = 0.5,
|
||||||
):
|
):
|
||||||
noise = noise_source_histogram(source, source.size)
|
noise = noise_source_histogram(source, source.size, (0, 0))
|
||||||
return ImageChops.blend(source, noise, strength)
|
return ImageChops.blend(source, noise, strength)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
from ..params import Border, Size
|
from ..params import Border, Size
|
||||||
|
@ -13,12 +15,12 @@ def expand_image(
|
||||||
fill="white",
|
fill="white",
|
||||||
noise_source=noise_source_histogram,
|
noise_source=noise_source_histogram,
|
||||||
mask_filter=mask_filter_none,
|
mask_filter=mask_filter_none,
|
||||||
):
|
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
|
||||||
size = Size(*source.size).add_border(expand)
|
size = Size(*source.size).add_border(expand)
|
||||||
size = tuple(size)
|
size = tuple(size)
|
||||||
origin = (expand.left, expand.top)
|
origin = (expand.left, expand.top)
|
||||||
|
|
||||||
full_source = Image.new("RGB", size, fill)
|
full_source = Image.new(source.mode, size, fill)
|
||||||
full_source.paste(source, origin)
|
full_source.paste(source, origin)
|
||||||
|
|
||||||
# new mask pixels need to be filled with white so they will be replaced
|
# new mask pixels need to be filled with white so they will be replaced
|
||||||
|
|
|
@ -23,6 +23,7 @@ from .server.load import (
|
||||||
load_platforms,
|
load_platforms,
|
||||||
load_wildcards,
|
load_wildcards,
|
||||||
)
|
)
|
||||||
|
from .server.plugin import load_plugins, register_plugins
|
||||||
from .server.static import register_static_routes
|
from .server.static import register_static_routes
|
||||||
from .server.utils import check_paths
|
from .server.utils import check_paths
|
||||||
from .utils import is_debug
|
from .utils import is_debug
|
||||||
|
@ -43,15 +44,32 @@ def main():
|
||||||
server = ServerContext.from_environ()
|
server = ServerContext.from_environ()
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
check_paths(server)
|
check_paths(server)
|
||||||
|
|
||||||
|
# debug options
|
||||||
|
if server.debug:
|
||||||
|
import debugpy
|
||||||
|
|
||||||
|
debugpy.listen(5678)
|
||||||
|
logger.warning("waiting for debugger")
|
||||||
|
debugpy.wait_for_client()
|
||||||
|
gc.set_debug(gc.DEBUG_STATS)
|
||||||
|
|
||||||
|
# register plugins
|
||||||
|
exports = load_plugins(server)
|
||||||
|
success = register_plugins(exports)
|
||||||
|
if success:
|
||||||
|
logger.info("all plugins loaded successfully")
|
||||||
|
else:
|
||||||
|
logger.warning("error loading plugins")
|
||||||
|
|
||||||
|
# load additional resources
|
||||||
load_extras(server)
|
load_extras(server)
|
||||||
load_models(server)
|
load_models(server)
|
||||||
load_params(server)
|
load_params(server)
|
||||||
load_platforms(server)
|
load_platforms(server)
|
||||||
load_wildcards(server)
|
load_wildcards(server)
|
||||||
|
|
||||||
if is_debug():
|
# misc server options
|
||||||
gc.set_debug(gc.DEBUG_STATS)
|
|
||||||
|
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
disable_progress_bars()
|
disable_progress_bars()
|
||||||
|
|
|
@ -1,18 +1,21 @@
|
||||||
from typing import Literal
|
from typing import List, Literal
|
||||||
|
|
||||||
NetworkType = Literal["inversion", "lora"]
|
NetworkType = Literal["control", "inversion", "lora"]
|
||||||
|
|
||||||
|
|
||||||
class NetworkModel:
|
class NetworkModel:
|
||||||
name: str
|
name: str
|
||||||
|
tokens: List[str]
|
||||||
type: NetworkType
|
type: NetworkType
|
||||||
|
|
||||||
def __init__(self, name: str, type: NetworkType) -> None:
|
def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.tokens = tokens or []
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
|
"tokens": self.tokens,
|
||||||
"type": self.type,
|
"type": self.type,
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,7 +57,7 @@ def json_params(
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
parent: Dict = None,
|
parent: Optional[Dict] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
json = {
|
json = {
|
||||||
"input_size": size.tojson(),
|
"input_size": size.tojson(),
|
||||||
|
@ -158,6 +158,7 @@ def make_output_name(
|
||||||
size: Size,
|
size: Size,
|
||||||
extras: Optional[List[Optional[Param]]] = None,
|
extras: Optional[List[Optional[Param]]] = None,
|
||||||
count: Optional[int] = None,
|
count: Optional[int] = None,
|
||||||
|
offset: int = 0,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
count = count or params.batch
|
count = count or params.batch
|
||||||
now = int(time())
|
now = int(time())
|
||||||
|
@ -183,7 +184,7 @@ def make_output_name(
|
||||||
|
|
||||||
return [
|
return [
|
||||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
||||||
for i in range(count)
|
for i in range(offset, count + offset)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ Point = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
unlimited = 0
|
micro = 64
|
||||||
mini = 128 # small tile for very expensive models
|
mini = 128 # small tile for very expensive models
|
||||||
half = 256 # half tile for outpainting
|
half = 256 # half tile for outpainting
|
||||||
auto = 512 # auto tile size
|
auto = 512 # auto tile size
|
||||||
|
@ -25,6 +25,7 @@ class SizeChart(IntEnum):
|
||||||
hd16k = 2**14
|
hd16k = 2**14
|
||||||
hd32k = 2**15
|
hd32k = 2**15
|
||||||
hd64k = 2**16
|
hd64k = 2**16
|
||||||
|
max = 2**32 # should be a reasonable upper limit for now
|
||||||
|
|
||||||
|
|
||||||
class TileOrder:
|
class TileOrder:
|
||||||
|
@ -140,7 +141,7 @@ class DeviceParams:
|
||||||
if self.options is None:
|
if self.options is None:
|
||||||
return self.provider
|
return self.provider
|
||||||
else:
|
else:
|
||||||
return self.provider # (self.provider, self.options)
|
return (self.provider, self.options)
|
||||||
|
|
||||||
def sess_options(self, cache=True) -> SessionOptions:
|
def sess_options(self, cache=True) -> SessionOptions:
|
||||||
if cache and self.sess_options_cache is not None:
|
if cache and self.sess_options_cache is not None:
|
||||||
|
@ -201,11 +202,14 @@ class ImageParams:
|
||||||
batch: int
|
batch: int
|
||||||
control: Optional[NetworkModel]
|
control: Optional[NetworkModel]
|
||||||
input_prompt: str
|
input_prompt: str
|
||||||
input_negative_prompt: str
|
input_negative_prompt: Optional[str]
|
||||||
loopback: int
|
loopback: int
|
||||||
tiled_vae: bool
|
tiled_vae: bool
|
||||||
tiles: int
|
unet_tile: int
|
||||||
overlap: float
|
unet_overlap: float
|
||||||
|
vae_tile: int
|
||||||
|
vae_overlap: float
|
||||||
|
denoise: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -224,9 +228,11 @@ class ImageParams:
|
||||||
input_negative_prompt: Optional[str] = None,
|
input_negative_prompt: Optional[str] = None,
|
||||||
loopback: int = 0,
|
loopback: int = 0,
|
||||||
tiled_vae: bool = False,
|
tiled_vae: bool = False,
|
||||||
tiles: int = 512,
|
unet_overlap: float = 0.25,
|
||||||
overlap: float = 0.25,
|
unet_tile: int = 512,
|
||||||
stride: int = 64,
|
vae_overlap: float = 0.25,
|
||||||
|
vae_tile: int = 512,
|
||||||
|
denoise: int = 3,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
@ -243,14 +249,16 @@ class ImageParams:
|
||||||
self.input_negative_prompt = input_negative_prompt or negative_prompt
|
self.input_negative_prompt = input_negative_prompt or negative_prompt
|
||||||
self.loopback = loopback
|
self.loopback = loopback
|
||||||
self.tiled_vae = tiled_vae
|
self.tiled_vae = tiled_vae
|
||||||
self.tiles = tiles
|
self.unet_overlap = unet_overlap
|
||||||
self.overlap = overlap
|
self.unet_tile = unet_tile
|
||||||
self.stride = stride
|
self.vae_overlap = vae_overlap
|
||||||
|
self.vae_tile = vae_tile
|
||||||
|
self.denoise = denoise
|
||||||
|
|
||||||
def do_cfg(self):
|
def do_cfg(self):
|
||||||
return self.cfg > 1.0
|
return self.cfg > 1.0
|
||||||
|
|
||||||
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
|
def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
|
||||||
pipeline = pipeline or self.pipeline
|
pipeline = pipeline or self.pipeline
|
||||||
|
|
||||||
# if the correct pipeline was already requested, simply use that
|
# if the correct pipeline was already requested, simply use that
|
||||||
|
@ -259,7 +267,14 @@ class ImageParams:
|
||||||
|
|
||||||
# otherwise, check for additional allowed pipelines
|
# otherwise, check for additional allowed pipelines
|
||||||
if group == "img2img":
|
if group == "img2img":
|
||||||
if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "pix2pix"]:
|
if pipeline in [
|
||||||
|
"controlnet",
|
||||||
|
"img2img-sdxl",
|
||||||
|
"lpw",
|
||||||
|
"panorama",
|
||||||
|
"panorama-sdxl",
|
||||||
|
"pix2pix",
|
||||||
|
]:
|
||||||
return pipeline
|
return pipeline
|
||||||
elif pipeline == "txt2img-sdxl":
|
elif pipeline == "txt2img-sdxl":
|
||||||
return "img2img-sdxl"
|
return "img2img-sdxl"
|
||||||
|
@ -267,7 +282,7 @@ class ImageParams:
|
||||||
if pipeline in ["controlnet", "lpw", "panorama"]:
|
if pipeline in ["controlnet", "lpw", "panorama"]:
|
||||||
return pipeline
|
return pipeline
|
||||||
elif group == "txt2img":
|
elif group == "txt2img":
|
||||||
if pipeline in ["lpw", "panorama", "txt2img-sdxl"]:
|
if pipeline in ["lpw", "panorama", "panorama-sdxl", "txt2img-sdxl"]:
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
logger.debug("pipeline %s is not valid for %s", pipeline, group)
|
logger.debug("pipeline %s is not valid for %s", pipeline, group)
|
||||||
|
@ -280,7 +295,7 @@ class ImageParams:
|
||||||
return self.pipeline == "lpw"
|
return self.pipeline == "lpw"
|
||||||
|
|
||||||
def is_panorama(self):
|
def is_panorama(self):
|
||||||
return self.pipeline == "panorama"
|
return self.pipeline in ["panorama", "panorama-sdxl"]
|
||||||
|
|
||||||
def is_pix2pix(self):
|
def is_pix2pix(self):
|
||||||
return self.pipeline == "pix2pix"
|
return self.pipeline == "pix2pix"
|
||||||
|
@ -305,9 +320,11 @@ class ImageParams:
|
||||||
"input_negative_prompt": self.input_negative_prompt,
|
"input_negative_prompt": self.input_negative_prompt,
|
||||||
"loopback": self.loopback,
|
"loopback": self.loopback,
|
||||||
"tiled_vae": self.tiled_vae,
|
"tiled_vae": self.tiled_vae,
|
||||||
"tiles": self.tiles,
|
"unet_overlap": self.unet_overlap,
|
||||||
"overlap": self.overlap,
|
"unet_tile": self.unet_tile,
|
||||||
"stride": self.stride,
|
"vae_overlap": self.vae_overlap,
|
||||||
|
"vae_tile": self.vae_tile,
|
||||||
|
"denoise": self.denoise,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_args(self, **kwargs):
|
def with_args(self, **kwargs):
|
||||||
|
@ -327,9 +344,11 @@ class ImageParams:
|
||||||
kwargs.get("input_negative_prompt", self.input_negative_prompt),
|
kwargs.get("input_negative_prompt", self.input_negative_prompt),
|
||||||
kwargs.get("loopback", self.loopback),
|
kwargs.get("loopback", self.loopback),
|
||||||
kwargs.get("tiled_vae", self.tiled_vae),
|
kwargs.get("tiled_vae", self.tiled_vae),
|
||||||
kwargs.get("tiles", self.tiles),
|
kwargs.get("unet_overlap", self.unet_overlap),
|
||||||
kwargs.get("overlap", self.overlap),
|
kwargs.get("unet_tile", self.unet_tile),
|
||||||
kwargs.get("stride", self.stride),
|
kwargs.get("vae_overlap", self.vae_overlap),
|
||||||
|
kwargs.get("vae_tile", self.vae_tile),
|
||||||
|
kwargs.get("denoise", self.denoise),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -351,6 +370,17 @@ class StageParams:
|
||||||
self.tile_order = tile_order
|
self.tile_order = tile_order
|
||||||
self.tile_size = tile_size
|
self.tile_size = tile_size
|
||||||
|
|
||||||
|
def with_args(
|
||||||
|
self,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return StageParams(
|
||||||
|
name=kwargs.get("name", self.name),
|
||||||
|
outscale=kwargs.get("outscale", self.outscale),
|
||||||
|
tile_order=kwargs.get("tile_order", self.tile_order),
|
||||||
|
tile_size=kwargs.get("tile_size", self.tile_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleParams:
|
class UpscaleParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -459,10 +489,14 @@ class HighresParams:
|
||||||
self.method = method
|
self.method = method
|
||||||
self.iterations = iterations
|
self.iterations = iterations
|
||||||
|
|
||||||
|
def outscale(self) -> int:
|
||||||
|
return self.scale**self.iterations
|
||||||
|
|
||||||
def resize(self, size: Size) -> Size:
|
def resize(self, size: Size) -> Size:
|
||||||
|
outscale = self.outscale()
|
||||||
return Size(
|
return Size(
|
||||||
size.width * (self.scale**self.iterations),
|
size.width * outscale,
|
||||||
size.height * (self.scale**self.iterations),
|
size.height * outscale,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from flask import Flask, jsonify, make_response, request, url_for
|
from flask import Flask, jsonify, make_response, request, url_for
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..chain import CHAIN_STAGES, ChainPipeline
|
from ..chain import CHAIN_STAGES, ChainPipeline
|
||||||
|
from ..chain.result import StageResult
|
||||||
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
||||||
from ..diffusers.run import (
|
from ..diffusers.run import (
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
|
@ -17,7 +19,7 @@ from ..diffusers.run import (
|
||||||
)
|
)
|
||||||
from ..diffusers.utils import replace_wildcards
|
from ..diffusers.utils import replace_wildcards
|
||||||
from ..output import json_params, make_output_name
|
from ..output import json_params, make_output_name
|
||||||
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
|
from ..params import Size, StageParams, TileOrder
|
||||||
from ..transformers.run import run_txt2txt_pipeline
|
from ..transformers.run import run_txt2txt_pipeline
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -49,10 +51,11 @@ from .load import (
|
||||||
get_wildcard_data,
|
get_wildcard_data,
|
||||||
)
|
)
|
||||||
from .params import (
|
from .params import (
|
||||||
border_from_request,
|
build_border,
|
||||||
highres_from_request,
|
build_highres,
|
||||||
|
build_upscale,
|
||||||
|
pipeline_from_json,
|
||||||
pipeline_from_request,
|
pipeline_from_request,
|
||||||
upscale_from_request,
|
|
||||||
)
|
)
|
||||||
from .utils import wrap_route
|
from .utils import wrap_route
|
||||||
|
|
||||||
|
@ -167,8 +170,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
size = Size(source.width, source.height)
|
size = Size(source.width, source.height)
|
||||||
|
|
||||||
device, params, _size = pipeline_from_request(server, "img2img")
|
device, params, _size = pipeline_from_request(server, "img2img")
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
source_filter = get_from_list(
|
source_filter = get_from_list(
|
||||||
request.args, "sourceFilter", list(get_source_filters().keys())
|
request.args, "sourceFilter", list(get_source_filters().keys())
|
||||||
)
|
)
|
||||||
|
@ -216,12 +219,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(server, "txt2img")
|
device, params, size = pipeline_from_request(server, "txt2img")
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
output = make_output_name(server, "txt2img", params, size)
|
output = make_output_name(server, "txt2img", params, size, count=params.batch)
|
||||||
|
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
pool.submit(
|
pool.submit(
|
||||||
|
@ -250,7 +253,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
if mask_file is None:
|
if mask_file is None:
|
||||||
return error_reply("mask image is required")
|
return error_reply("mask image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
|
||||||
size = Size(source.width, source.height)
|
size = Size(source.width, source.height)
|
||||||
|
|
||||||
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||||
|
@ -270,9 +273,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
device, params, _size = pipeline_from_request(server, "inpaint")
|
device, params, _size = pipeline_from_request(server, "inpaint")
|
||||||
expand = border_from_request()
|
expand = build_border()
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
|
|
||||||
fill_color = get_not_empty(request.args, "fillColor", "white")
|
fill_color = get_not_empty(request.args, "fillColor", "white")
|
||||||
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
||||||
|
@ -340,8 +343,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
|
@ -366,47 +369,70 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||||
|
|
||||||
|
|
||||||
|
# keys that are specially parsed by params and should not show up in with_args
|
||||||
|
CHAIN_POP_KEYS = ["model", "control"]
|
||||||
|
|
||||||
|
|
||||||
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
logger.debug(
|
if request.is_json:
|
||||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
logger.debug("chain pipeline request with JSON body")
|
||||||
)
|
data = request.get_json()
|
||||||
body = request.form.get("chain") or request.files.get("chain")
|
else:
|
||||||
if body is None:
|
logger.debug(
|
||||||
return error_reply("chain pipeline must have a body")
|
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
body = request.form.get("chain") or request.files.get("chain")
|
||||||
|
if body is None:
|
||||||
|
return error_reply("chain pipeline must have a body")
|
||||||
|
|
||||||
|
data = load_config_str(body)
|
||||||
|
|
||||||
data = load_config_str(body)
|
|
||||||
schema = load_config("./schemas/chain.yaml")
|
schema = load_config("./schemas/chain.yaml")
|
||||||
|
|
||||||
logger.debug("validating chain request: %s against %s", data, schema)
|
logger.debug("validating chain request: %s against %s", data, schema)
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
device, base_params, base_size = pipeline_from_json(
|
||||||
device, params, size = pipeline_from_request(server)
|
server, data=data.get("defaults")
|
||||||
output = make_output_name(server, "chain", params, size)
|
)
|
||||||
job_name = output[0]
|
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
|
||||||
|
|
||||||
|
# start building the pipeline
|
||||||
pipeline = ChainPipeline()
|
pipeline = ChainPipeline()
|
||||||
for stage_data in data.get("stages", []):
|
for stage_data in data.get("stages", []):
|
||||||
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||||
kwargs = stage_data.get("params", {})
|
kwargs: Dict[str, Any] = stage_data.get("params", {})
|
||||||
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
||||||
|
|
||||||
|
# TODO: combine base params with stage params
|
||||||
|
_device, params, size = pipeline_from_json(server, data=kwargs)
|
||||||
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
|
# remove parsed keys, like model names (which become paths)
|
||||||
|
for pop_key in CHAIN_POP_KEYS:
|
||||||
|
if pop_key in kwargs:
|
||||||
|
kwargs.pop(pop_key)
|
||||||
|
|
||||||
|
if "seed" in kwargs and kwargs["seed"] == -1:
|
||||||
|
kwargs.pop("seed")
|
||||||
|
|
||||||
|
# replace kwargs with parsed versions
|
||||||
|
kwargs["params"] = params
|
||||||
|
kwargs["size"] = size
|
||||||
|
|
||||||
|
border = build_border(kwargs)
|
||||||
|
kwargs["border"] = border
|
||||||
|
|
||||||
|
upscale = build_upscale(kwargs)
|
||||||
|
kwargs["upscale"] = upscale
|
||||||
|
|
||||||
|
# prepare the stage metadata
|
||||||
stage = StageParams(
|
stage = StageParams(
|
||||||
stage_data.get("name", stage_class.__name__),
|
stage_data.get("name", stage_class.__name__),
|
||||||
tile_size=get_size(kwargs.get("tile_size")),
|
tile_size=get_size(kwargs.get("tiles")),
|
||||||
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "border" in kwargs:
|
# load any images related to this stage
|
||||||
border = Border.even(int(kwargs.get("border")))
|
|
||||||
kwargs["border"] = border
|
|
||||||
|
|
||||||
if "upscale" in kwargs:
|
|
||||||
upscale = UpscaleParams(kwargs.get("upscale"))
|
|
||||||
kwargs["upscale"] = upscale
|
|
||||||
|
|
||||||
stage_source_name = "source:%s" % (stage.name)
|
stage_source_name = "source:%s" % (stage.name)
|
||||||
stage_mask_name = "mask:%s" % (stage.name)
|
stage_mask_name = "mask:%s" % (stage.name)
|
||||||
|
|
||||||
|
@ -436,20 +462,25 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||||
|
|
||||||
|
output = make_output_name(
|
||||||
|
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
|
||||||
|
)
|
||||||
|
job_name = output[0]
|
||||||
|
|
||||||
# build and run chain pipeline
|
# build and run chain pipeline
|
||||||
empty_source = Image.new("RGB", (size.width, size.height))
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
pipeline,
|
pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
base_params,
|
||||||
empty_source,
|
StageResult.empty(),
|
||||||
output=output[0],
|
output=output,
|
||||||
size=size,
|
size=base_size,
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size))
|
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||||
|
return jsonify(json_params(output, step_params, base_size))
|
||||||
|
|
||||||
|
|
||||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -471,7 +502,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
sources.append(source)
|
sources.append(source)
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
|
|
||||||
output = make_output_name(server, "upscale", params, size)
|
output = make_output_name(server, "upscale", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
|
|
|
@ -5,18 +5,44 @@ from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..utils import get_boolean
|
from ..utils import get_boolean, get_list
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_ANY_PLATFORM = True
|
||||||
DEFAULT_CACHE_LIMIT = 5
|
DEFAULT_CACHE_LIMIT = 5
|
||||||
DEFAULT_JOB_LIMIT = 10
|
DEFAULT_JOB_LIMIT = 10
|
||||||
DEFAULT_IMAGE_FORMAT = "png"
|
DEFAULT_IMAGE_FORMAT = "png"
|
||||||
DEFAULT_SERVER_VERSION = "v0.10.0"
|
DEFAULT_SERVER_VERSION = "v0.10.0"
|
||||||
|
DEFAULT_SHOW_PROGRESS = True
|
||||||
|
DEFAULT_WORKER_RETRIES = 3
|
||||||
|
|
||||||
|
|
||||||
class ServerContext:
|
class ServerContext:
|
||||||
|
bundle_path: str
|
||||||
|
model_path: str
|
||||||
|
output_path: str
|
||||||
|
params_path: str
|
||||||
|
cors_origin: str
|
||||||
|
any_platform: bool
|
||||||
|
block_platforms: List[str]
|
||||||
|
default_platform: str
|
||||||
|
image_format: str
|
||||||
|
cache_limit: int
|
||||||
|
cache_path: str
|
||||||
|
show_progress: bool
|
||||||
|
optimizations: List[str]
|
||||||
|
extra_models: List[str]
|
||||||
|
job_limit: int
|
||||||
|
memory_limit: int
|
||||||
|
admin_token: str
|
||||||
|
server_version: str
|
||||||
|
worker_retries: int
|
||||||
|
feature_flags: List[str]
|
||||||
|
plugins: List[str]
|
||||||
|
debug: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bundle_path: str = ".",
|
bundle_path: str = ".",
|
||||||
|
@ -24,19 +50,23 @@ class ServerContext:
|
||||||
output_path: str = ".",
|
output_path: str = ".",
|
||||||
params_path: str = ".",
|
params_path: str = ".",
|
||||||
cors_origin: str = "*",
|
cors_origin: str = "*",
|
||||||
any_platform: bool = True,
|
any_platform: bool = DEFAULT_ANY_PLATFORM,
|
||||||
block_platforms: Optional[List[str]] = None,
|
block_platforms: Optional[List[str]] = None,
|
||||||
default_platform: Optional[str] = None,
|
default_platform: Optional[str] = None,
|
||||||
image_format: str = DEFAULT_IMAGE_FORMAT,
|
image_format: str = DEFAULT_IMAGE_FORMAT,
|
||||||
cache_limit: int = DEFAULT_CACHE_LIMIT,
|
cache_limit: int = DEFAULT_CACHE_LIMIT,
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
show_progress: bool = True,
|
show_progress: bool = DEFAULT_SHOW_PROGRESS,
|
||||||
optimizations: Optional[List[str]] = None,
|
optimizations: Optional[List[str]] = None,
|
||||||
extra_models: Optional[List[str]] = None,
|
extra_models: Optional[List[str]] = None,
|
||||||
job_limit: int = DEFAULT_JOB_LIMIT,
|
job_limit: int = DEFAULT_JOB_LIMIT,
|
||||||
memory_limit: Optional[int] = None,
|
memory_limit: Optional[int] = None,
|
||||||
admin_token: Optional[str] = None,
|
admin_token: Optional[str] = None,
|
||||||
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
||||||
|
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
|
||||||
|
feature_flags: Optional[List[str]] = None,
|
||||||
|
plugins: Optional[List[str]] = None,
|
||||||
|
debug: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -56,6 +86,10 @@ class ServerContext:
|
||||||
self.memory_limit = memory_limit
|
self.memory_limit = memory_limit
|
||||||
self.admin_token = admin_token or token_urlsafe()
|
self.admin_token = admin_token or token_urlsafe()
|
||||||
self.server_version = server_version
|
self.server_version = server_version
|
||||||
|
self.worker_retries = worker_retries
|
||||||
|
self.feature_flags = feature_flags or []
|
||||||
|
self.plugins = plugins or []
|
||||||
|
self.debug = debug
|
||||||
|
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
|
@ -72,26 +106,41 @@ class ServerContext:
|
||||||
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||||
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||||
# others
|
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
||||||
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
any_platform=get_boolean(
|
||||||
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
|
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
),
|
||||||
|
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
||||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||||
show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
|
show_progress=get_boolean(
|
||||||
optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
|
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||||
extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
|
),
|
||||||
|
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
|
||||||
|
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
|
||||||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||||
memory_limit=memory_limit,
|
memory_limit=memory_limit,
|
||||||
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
|
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
|
||||||
server_version=environ.get(
|
server_version=environ.get(
|
||||||
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
||||||
),
|
),
|
||||||
|
worker_retries=int(
|
||||||
|
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||||
|
),
|
||||||
|
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
|
||||||
|
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
|
||||||
|
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def has_feature(self, flag: str) -> bool:
|
||||||
|
return flag in self.feature_flags
|
||||||
|
|
||||||
|
def has_optimization(self, opt: str) -> bool:
|
||||||
|
return opt in self.optimizations
|
||||||
|
|
||||||
def torch_dtype(self):
|
def torch_dtype(self):
|
||||||
if "torch-fp16" in self.optimizations:
|
if self.has_optimization("torch-fp16"):
|
||||||
return torch.float16
|
return torch.float16
|
||||||
else:
|
else:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
|
@ -134,25 +134,44 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
|
||||||
|
|
||||||
def apply_patch_basicsr(server: ServerContext):
|
def apply_patch_basicsr(server: ServerContext):
|
||||||
logger.debug("patching BasicSR module")
|
logger.debug("patching BasicSR module")
|
||||||
import basicsr.utils.download_util
|
try:
|
||||||
|
import basicsr.utils.download_util
|
||||||
|
|
||||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
|
basicsr.utils.download_util.load_file_from_url = partial(
|
||||||
|
patch_cache_path, server
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logger.info("unable to import basicsr utils for patching")
|
||||||
|
except AttributeError:
|
||||||
|
logger.warning("unable to patch basicsr utils")
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_codeformer(server: ServerContext):
|
def apply_patch_codeformer(server: ServerContext):
|
||||||
logger.debug("patching CodeFormer module")
|
logger.debug("patching CodeFormer module")
|
||||||
import codeformer.facelib.utils.misc
|
try:
|
||||||
|
import codeformer.facelib.utils.misc
|
||||||
|
|
||||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
|
codeformer.facelib.utils.misc.load_file_from_url = partial(
|
||||||
|
patch_cache_path, server
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logger.info("unable to import codeformer utils for patching")
|
||||||
|
except AttributeError:
|
||||||
|
logger.warning("unable to patch codeformer utils")
|
||||||
|
|
||||||
|
|
||||||
def apply_patch_facexlib(server: ServerContext):
|
def apply_patch_facexlib(server: ServerContext):
|
||||||
logger.debug("patching Facexlib module")
|
logger.debug("patching Facexlib module")
|
||||||
import facexlib.utils
|
try:
|
||||||
|
import facexlib.utils
|
||||||
|
|
||||||
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
|
facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
|
||||||
|
except ImportError:
|
||||||
|
logger.info("unable to import facexlib for patching")
|
||||||
|
except AttributeError:
|
||||||
|
logger.warning("unable to patch facexlib utils")
|
||||||
|
|
||||||
|
|
||||||
def apply_patches(server: ServerContext):
|
def apply_patches(server: ServerContext):
|
||||||
|
|
|
@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)
|
||||||
# Loaded from extra_models
|
# Loaded from extra_models
|
||||||
extra_hashes: Dict[str, str] = {}
|
extra_hashes: Dict[str, str] = {}
|
||||||
extra_strings: Dict[str, Any] = {}
|
extra_strings: Dict[str, Any] = {}
|
||||||
|
extra_tokens: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_config_params():
|
def get_config_params():
|
||||||
|
@ -160,9 +161,10 @@ def load_extras(server: ServerContext):
|
||||||
"""
|
"""
|
||||||
global extra_hashes
|
global extra_hashes
|
||||||
global extra_strings
|
global extra_strings
|
||||||
|
global extra_tokens
|
||||||
|
|
||||||
labels = {}
|
labels: Dict[str, str] = {}
|
||||||
strings = {}
|
strings: Dict[str, Any] = {}
|
||||||
|
|
||||||
extra_schema = load_config("./schemas/extras.yaml")
|
extra_schema = load_config("./schemas/extras.yaml")
|
||||||
|
|
||||||
|
@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
|
||||||
else:
|
else:
|
||||||
labels[model_name] = model["label"]
|
labels[model_name] = model["label"]
|
||||||
|
|
||||||
|
if "tokens" in model:
|
||||||
|
logger.debug(
|
||||||
|
"collecting tokens for model %s from %s",
|
||||||
|
model_name,
|
||||||
|
file,
|
||||||
|
)
|
||||||
|
extra_tokens[model_name] = model["tokens"]
|
||||||
|
|
||||||
if "inversions" in model:
|
if "inversions" in model:
|
||||||
for inversion in model["inversions"]:
|
for inversion in model["inversions"]:
|
||||||
if "label" in inversion:
|
if "label" in inversion:
|
||||||
|
@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
|
||||||
)
|
)
|
||||||
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
|
||||||
network_models.extend(
|
network_models.extend(
|
||||||
[NetworkModel(model, "inversion") for model in inversion_models]
|
[
|
||||||
|
NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
|
||||||
|
for model in inversion_models
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
lora_models = list_model_globs(
|
lora_models = list_model_globs(
|
||||||
|
@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
|
||||||
base_path=path.join(server.model_path, "lora"),
|
base_path=path.join(server.model_path, "lora"),
|
||||||
)
|
)
|
||||||
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
logger.debug("loaded LoRA models from disk: %s", lora_models)
|
||||||
network_models.extend([NetworkModel(model, "lora") for model in lora_models])
|
network_models.extend(
|
||||||
|
[
|
||||||
|
NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
|
||||||
|
for model in lora_models
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_params(server: ServerContext) -> None:
|
def load_params(server: ServerContext) -> None:
|
||||||
|
@ -397,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
|
||||||
):
|
):
|
||||||
if potential == "cuda" or potential == "rocm":
|
if potential == "cuda" or potential == "rocm":
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
options = {
|
options: Dict[str, Union[int, str]] = {
|
||||||
"device_id": i,
|
"device_id": i,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ class ModelCache:
|
||||||
return
|
return
|
||||||
|
|
||||||
for i in range(len(cache)):
|
for i in range(len(cache)):
|
||||||
t, k, v = cache[i]
|
t, k, _v = cache[i]
|
||||||
if tag == t and key != k:
|
if tag == t and key != k:
|
||||||
logger.debug("updating model cache: %s %s", tag, key)
|
logger.debug("updating model cache: %s %s", tag, key)
|
||||||
cache[i] = (tag, key, value)
|
cache[i] = (tag, key, value)
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
||||||
|
from ..diffusers.utils import random_seed
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
DeviceParams,
|
DeviceParams,
|
||||||
|
@ -34,143 +34,122 @@ from .utils import get_model_path
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(
|
def build_device(
|
||||||
server: ServerContext,
|
_server: ServerContext,
|
||||||
default_pipeline: str = "txt2img",
|
data: Dict[str, str],
|
||||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
) -> Optional[DeviceParams]:
|
||||||
user = request.remote_addr
|
|
||||||
|
|
||||||
# platform stuff
|
# platform stuff
|
||||||
device = None
|
device = None
|
||||||
device_name = request.args.get("platform")
|
device_name = data.get("platform")
|
||||||
|
|
||||||
if device_name is not None and device_name != "any":
|
if device_name is not None and device_name != "any":
|
||||||
for platform in get_available_platforms():
|
for platform in get_available_platforms():
|
||||||
if platform.device == device_name:
|
if platform.device == device_name:
|
||||||
device = platform
|
device = platform
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def build_params(
|
||||||
|
server: ServerContext,
|
||||||
|
default_pipeline: str,
|
||||||
|
data: Dict[str, str],
|
||||||
|
) -> ImageParams:
|
||||||
# diffusion model
|
# diffusion model
|
||||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
model = get_not_empty(data, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(server, model)
|
model_path = get_model_path(server, model)
|
||||||
|
|
||||||
control = None
|
control = None
|
||||||
control_name = request.args.get("control")
|
control_name = data.get("control")
|
||||||
for network in get_network_models():
|
for network in get_network_models():
|
||||||
if network.name == control_name:
|
if network.name == control_name:
|
||||||
control = network
|
control = network
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
pipeline = get_from_list(
|
pipeline = get_from_list(
|
||||||
request.args, "pipeline", get_available_pipelines(), default_pipeline
|
data, "pipeline", get_available_pipelines(), default_pipeline
|
||||||
)
|
)
|
||||||
scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
|
scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
|
||||||
|
|
||||||
if scheduler is None:
|
if scheduler is None:
|
||||||
scheduler = get_config_value("scheduler")
|
scheduler = get_config_value("scheduler")
|
||||||
|
|
||||||
# prompt does not come from config
|
# prompt does not come from config
|
||||||
prompt = request.args.get("prompt", "")
|
prompt = data.get("prompt", "")
|
||||||
negative_prompt = request.args.get("negativePrompt", None)
|
negative_prompt = data.get("negativePrompt", None)
|
||||||
|
|
||||||
if negative_prompt is not None and negative_prompt.strip() == "":
|
if negative_prompt is not None and negative_prompt.strip() == "":
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
|
|
||||||
# image params
|
# image params
|
||||||
batch = get_and_clamp_int(
|
batch = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"batch",
|
"batch",
|
||||||
get_config_value("batch"),
|
get_config_value("batch"),
|
||||||
get_config_value("batch", "max"),
|
get_config_value("batch", "max"),
|
||||||
get_config_value("batch", "min"),
|
get_config_value("batch", "min"),
|
||||||
)
|
)
|
||||||
cfg = get_and_clamp_float(
|
cfg = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"cfg",
|
"cfg",
|
||||||
get_config_value("cfg"),
|
get_config_value("cfg"),
|
||||||
get_config_value("cfg", "max"),
|
get_config_value("cfg", "max"),
|
||||||
get_config_value("cfg", "min"),
|
get_config_value("cfg", "min"),
|
||||||
)
|
)
|
||||||
eta = get_and_clamp_float(
|
eta = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"eta",
|
"eta",
|
||||||
get_config_value("eta"),
|
get_config_value("eta"),
|
||||||
get_config_value("eta", "max"),
|
get_config_value("eta", "max"),
|
||||||
get_config_value("eta", "min"),
|
get_config_value("eta", "min"),
|
||||||
)
|
)
|
||||||
loopback = get_and_clamp_int(
|
loopback = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"loopback",
|
"loopback",
|
||||||
get_config_value("loopback"),
|
get_config_value("loopback"),
|
||||||
get_config_value("loopback", "max"),
|
get_config_value("loopback", "max"),
|
||||||
get_config_value("loopback", "min"),
|
get_config_value("loopback", "min"),
|
||||||
)
|
)
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"steps",
|
"steps",
|
||||||
get_config_value("steps"),
|
get_config_value("steps"),
|
||||||
get_config_value("steps", "max"),
|
get_config_value("steps", "max"),
|
||||||
get_config_value("steps", "min"),
|
get_config_value("steps", "min"),
|
||||||
)
|
)
|
||||||
height = get_and_clamp_int(
|
tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae"))
|
||||||
request.args,
|
unet_overlap = get_and_clamp_float(
|
||||||
"height",
|
data,
|
||||||
get_config_value("height"),
|
"unet_overlap",
|
||||||
get_config_value("height", "max"),
|
get_config_value("unet_overlap"),
|
||||||
get_config_value("height", "min"),
|
get_config_value("unet_overlap", "max"),
|
||||||
|
get_config_value("unet_overlap", "min"),
|
||||||
)
|
)
|
||||||
width = get_and_clamp_int(
|
unet_tile = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"width",
|
"unet_tile",
|
||||||
get_config_value("width"),
|
get_config_value("unet_tile"),
|
||||||
get_config_value("width", "max"),
|
get_config_value("unet_tile", "max"),
|
||||||
get_config_value("width", "min"),
|
get_config_value("unet_tile", "min"),
|
||||||
)
|
)
|
||||||
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
|
vae_overlap = get_and_clamp_float(
|
||||||
tiles = get_and_clamp_int(
|
data,
|
||||||
request.args,
|
"vae_overlap",
|
||||||
"tiles",
|
get_config_value("vae_overlap"),
|
||||||
get_config_value("tiles"),
|
get_config_value("vae_overlap", "max"),
|
||||||
get_config_value("tiles", "max"),
|
get_config_value("vae_overlap", "min"),
|
||||||
get_config_value("tiles", "min"),
|
|
||||||
)
|
)
|
||||||
overlap = get_and_clamp_float(
|
vae_tile = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"overlap",
|
"vae_tile",
|
||||||
get_config_value("overlap"),
|
get_config_value("vae_tile"),
|
||||||
get_config_value("overlap", "max"),
|
get_config_value("vae_tile", "max"),
|
||||||
get_config_value("overlap", "min"),
|
get_config_value("vae_tile", "min"),
|
||||||
)
|
|
||||||
stride = get_and_clamp_int(
|
|
||||||
request.args,
|
|
||||||
"stride",
|
|
||||||
get_config_value("stride"),
|
|
||||||
get_config_value("stride", "max"),
|
|
||||||
get_config_value("stride", "min"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if stride > tiles:
|
seed = int(data.get("seed", -1))
|
||||||
logger.info("limiting stride to tile size, %s > %s", stride, tiles)
|
|
||||||
stride = tiles
|
|
||||||
|
|
||||||
seed = int(request.args.get("seed", -1))
|
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
# this one can safely use np.random because it produces a single value
|
seed = random_seed()
|
||||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
|
||||||
user,
|
|
||||||
steps,
|
|
||||||
scheduler,
|
|
||||||
model_path,
|
|
||||||
pipeline,
|
|
||||||
device or "any device",
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
cfg,
|
|
||||||
seed,
|
|
||||||
prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
params = ImageParams(
|
params = ImageParams(
|
||||||
model_path,
|
model_path,
|
||||||
|
@ -186,38 +165,65 @@ def pipeline_from_request(
|
||||||
control=control,
|
control=control,
|
||||||
loopback=loopback,
|
loopback=loopback,
|
||||||
tiled_vae=tiled_vae,
|
tiled_vae=tiled_vae,
|
||||||
tiles=tiles,
|
unet_overlap=unet_overlap,
|
||||||
overlap=overlap,
|
unet_tile=unet_tile,
|
||||||
stride=stride,
|
vae_overlap=vae_overlap,
|
||||||
|
vae_tile=vae_tile,
|
||||||
)
|
)
|
||||||
size = Size(width, height)
|
|
||||||
return (device, params, size)
|
return params
|
||||||
|
|
||||||
|
|
||||||
def border_from_request() -> Border:
|
def build_size(
|
||||||
|
_server: ServerContext,
|
||||||
|
data: Dict[str, str],
|
||||||
|
) -> Size:
|
||||||
|
height = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"height",
|
||||||
|
get_config_value("height"),
|
||||||
|
get_config_value("height", "max"),
|
||||||
|
get_config_value("height", "min"),
|
||||||
|
)
|
||||||
|
width = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"width",
|
||||||
|
get_config_value("width"),
|
||||||
|
get_config_value("width", "max"),
|
||||||
|
get_config_value("width", "min"),
|
||||||
|
)
|
||||||
|
return Size(width, height)
|
||||||
|
|
||||||
|
|
||||||
|
def build_border(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> Border:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
left = get_and_clamp_int(
|
left = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"left",
|
"left",
|
||||||
get_config_value("left"),
|
get_config_value("left"),
|
||||||
get_config_value("left", "max"),
|
get_config_value("left", "max"),
|
||||||
get_config_value("left", "min"),
|
get_config_value("left", "min"),
|
||||||
)
|
)
|
||||||
right = get_and_clamp_int(
|
right = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"right",
|
"right",
|
||||||
get_config_value("right"),
|
get_config_value("right"),
|
||||||
get_config_value("right", "max"),
|
get_config_value("right", "max"),
|
||||||
get_config_value("right", "min"),
|
get_config_value("right", "min"),
|
||||||
)
|
)
|
||||||
top = get_and_clamp_int(
|
top = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"top",
|
"top",
|
||||||
get_config_value("top"),
|
get_config_value("top"),
|
||||||
get_config_value("top", "max"),
|
get_config_value("top", "max"),
|
||||||
get_config_value("top", "min"),
|
get_config_value("top", "min"),
|
||||||
)
|
)
|
||||||
bottom = get_and_clamp_int(
|
bottom = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"bottom",
|
"bottom",
|
||||||
get_config_value("bottom"),
|
get_config_value("bottom"),
|
||||||
get_config_value("bottom", "max"),
|
get_config_value("bottom", "max"),
|
||||||
|
@ -227,46 +233,51 @@ def border_from_request() -> Border:
|
||||||
return Border(left, right, top, bottom)
|
return Border(left, right, top, bottom)
|
||||||
|
|
||||||
|
|
||||||
def upscale_from_request() -> UpscaleParams:
|
def build_upscale(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> UpscaleParams:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
denoise = get_and_clamp_float(
|
denoise = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"denoise",
|
"denoise",
|
||||||
get_config_value("denoise"),
|
get_config_value("denoise"),
|
||||||
get_config_value("denoise", "max"),
|
get_config_value("denoise", "max"),
|
||||||
get_config_value("denoise", "min"),
|
get_config_value("denoise", "min"),
|
||||||
)
|
)
|
||||||
scale = get_and_clamp_int(
|
scale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"scale",
|
"scale",
|
||||||
get_config_value("scale"),
|
get_config_value("scale"),
|
||||||
get_config_value("scale", "max"),
|
get_config_value("scale", "max"),
|
||||||
get_config_value("scale", "min"),
|
get_config_value("scale", "min"),
|
||||||
)
|
)
|
||||||
outscale = get_and_clamp_int(
|
outscale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"outscale",
|
"outscale",
|
||||||
get_config_value("outscale"),
|
get_config_value("outscale"),
|
||||||
get_config_value("outscale", "max"),
|
get_config_value("outscale", "max"),
|
||||||
get_config_value("outscale", "min"),
|
get_config_value("outscale", "min"),
|
||||||
)
|
)
|
||||||
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
|
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
|
||||||
correction = get_from_list(request.args, "correction", get_correction_models())
|
correction = get_from_list(data, "correction", get_correction_models())
|
||||||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
faces = get_not_empty(data, "faces", "false") == "true"
|
||||||
face_outscale = get_and_clamp_int(
|
face_outscale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"faceOutscale",
|
"faceOutscale",
|
||||||
get_config_value("faceOutscale"),
|
get_config_value("faceOutscale"),
|
||||||
get_config_value("faceOutscale", "max"),
|
get_config_value("faceOutscale", "max"),
|
||||||
get_config_value("faceOutscale", "min"),
|
get_config_value("faceOutscale", "min"),
|
||||||
)
|
)
|
||||||
face_strength = get_and_clamp_float(
|
face_strength = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"faceStrength",
|
"faceStrength",
|
||||||
get_config_value("faceStrength"),
|
get_config_value("faceStrength"),
|
||||||
get_config_value("faceStrength", "max"),
|
get_config_value("faceStrength", "max"),
|
||||||
get_config_value("faceStrength", "min"),
|
get_config_value("faceStrength", "min"),
|
||||||
)
|
)
|
||||||
upscale_order = request.args.get("upscaleOrder", "correction-first")
|
upscale_order = data.get("upscaleOrder", "correction-first")
|
||||||
|
|
||||||
return UpscaleParams(
|
return UpscaleParams(
|
||||||
upscaling,
|
upscaling,
|
||||||
|
@ -282,37 +293,43 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def highres_from_request() -> HighresParams:
|
def build_highres(
|
||||||
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
|
data: Dict[str, str] = None,
|
||||||
|
) -> HighresParams:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
|
enabled = get_boolean(data, "highres", get_config_value("highres"))
|
||||||
iterations = get_and_clamp_int(
|
iterations = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"highresIterations",
|
"highresIterations",
|
||||||
get_config_value("highresIterations"),
|
get_config_value("highresIterations"),
|
||||||
get_config_value("highresIterations", "max"),
|
get_config_value("highresIterations", "max"),
|
||||||
get_config_value("highresIterations", "min"),
|
get_config_value("highresIterations", "min"),
|
||||||
)
|
)
|
||||||
method = get_from_list(request.args, "highresMethod", get_highres_methods())
|
method = get_from_list(data, "highresMethod", get_highres_methods())
|
||||||
scale = get_and_clamp_int(
|
scale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"highresScale",
|
"highresScale",
|
||||||
get_config_value("highresScale"),
|
get_config_value("highresScale"),
|
||||||
get_config_value("highresScale", "max"),
|
get_config_value("highresScale", "max"),
|
||||||
get_config_value("highresScale", "min"),
|
get_config_value("highresScale", "min"),
|
||||||
)
|
)
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"highresSteps",
|
"highresSteps",
|
||||||
get_config_value("highresSteps"),
|
get_config_value("highresSteps"),
|
||||||
get_config_value("highresSteps", "max"),
|
get_config_value("highresSteps", "max"),
|
||||||
get_config_value("highresSteps", "min"),
|
get_config_value("highresSteps", "min"),
|
||||||
)
|
)
|
||||||
strength = get_and_clamp_float(
|
strength = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"highresStrength",
|
"highresStrength",
|
||||||
get_config_value("highresStrength"),
|
get_config_value("highresStrength"),
|
||||||
get_config_value("highresStrength", "max"),
|
get_config_value("highresStrength", "max"),
|
||||||
get_config_value("highresStrength", "min"),
|
get_config_value("highresStrength", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return HighresParams(
|
return HighresParams(
|
||||||
enabled,
|
enabled,
|
||||||
scale,
|
scale,
|
||||||
|
@ -321,3 +338,50 @@ def highres_from_request() -> HighresParams:
|
||||||
method=method,
|
method=method,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_from_json(
|
||||||
|
server: ServerContext,
|
||||||
|
data: Dict[str, str],
|
||||||
|
default_pipeline: str = "txt2img",
|
||||||
|
) -> PipelineParams:
|
||||||
|
"""
|
||||||
|
Like pipeline_from_request but expects a nested structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = build_device(server, data.get("device", data))
|
||||||
|
params = build_params(server, default_pipeline, data.get("params", data))
|
||||||
|
size = build_size(server, data.get("params", data))
|
||||||
|
|
||||||
|
return (device, params, size)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_from_request(
|
||||||
|
server: ServerContext,
|
||||||
|
default_pipeline: str = "txt2img",
|
||||||
|
) -> PipelineParams:
|
||||||
|
user = request.remote_addr
|
||||||
|
|
||||||
|
device = build_device(server, request.args)
|
||||||
|
params = build_params(server, default_pipeline, request.args)
|
||||||
|
size = build_size(server, request.args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||||
|
user,
|
||||||
|
params.steps,
|
||||||
|
params.scheduler,
|
||||||
|
params.model,
|
||||||
|
params.pipeline,
|
||||||
|
device or "any device",
|
||||||
|
size.width,
|
||||||
|
size.height,
|
||||||
|
params.cfg,
|
||||||
|
params.seed,
|
||||||
|
params.prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (device, params, size)
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
from importlib import import_module
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
from onnx_web.chain.stages import add_stage
|
||||||
|
from onnx_web.diffusers.load import add_pipeline
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginExports:
|
||||||
|
pipelines: Dict[str, Any]
|
||||||
|
stages: Dict[str, Any]
|
||||||
|
|
||||||
|
def __init__(self, pipelines=None, stages=None) -> None:
|
||||||
|
self.pipelines = pipelines or {}
|
||||||
|
self.stages = stages or {}
|
||||||
|
|
||||||
|
|
||||||
|
PluginModule = Callable[[ServerContext], PluginExports]
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugins(server: ServerContext) -> PluginExports:
|
||||||
|
combined_exports = PluginExports()
|
||||||
|
|
||||||
|
for plugin in server.plugins:
|
||||||
|
logger.info("loading plugin module: %s", plugin)
|
||||||
|
try:
|
||||||
|
module: PluginModule = import_module(plugin)
|
||||||
|
exports = module(server)
|
||||||
|
|
||||||
|
for name, pipeline in exports.pipelines.items():
|
||||||
|
if name in combined_exports.pipelines:
|
||||||
|
logger.warning(
|
||||||
|
"multiple plugins exported a pipeline named %s", name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
combined_exports.pipelines[name] = pipeline
|
||||||
|
|
||||||
|
for name, stage in exports.stages.items():
|
||||||
|
if name in combined_exports.stages:
|
||||||
|
logger.warning("multiple plugins exported a stage named %s", name)
|
||||||
|
else:
|
||||||
|
combined_exports.stages[name] = stage
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error importing plugin")
|
||||||
|
|
||||||
|
return combined_exports
|
||||||
|
|
||||||
|
|
||||||
|
def register_plugins(exports: PluginExports) -> bool:
|
||||||
|
success = True
|
||||||
|
|
||||||
|
for name, pipeline in exports.pipelines.items():
|
||||||
|
success = success and add_pipeline(name, pipeline)
|
||||||
|
|
||||||
|
for name, stage in exports.stages.items():
|
||||||
|
success = success and add_stage(name, stage)
|
||||||
|
|
||||||
|
return success
|
|
@ -18,6 +18,11 @@ logger = getLogger(__name__)
|
||||||
SAFE_CHARS = "._-"
|
SAFE_CHARS = "._-"
|
||||||
|
|
||||||
|
|
||||||
|
def split_list(val: str) -> List[str]:
|
||||||
|
parts = [part.strip() for part in val.split(",")]
|
||||||
|
return [part for part in parts if len(part) > 0]
|
||||||
|
|
||||||
|
|
||||||
def base_join(base: str, tail: str) -> str:
|
def base_join(base: str, tail: str) -> str:
|
||||||
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
||||||
return path.join(base, tail_path)
|
return path.join(base, tail_path)
|
||||||
|
@ -28,7 +33,16 @@ def is_debug() -> bool:
|
||||||
|
|
||||||
|
|
||||||
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
|
||||||
return args.get(key, str(default_value)).lower() in ("1", "t", "true", "y", "yes")
|
val = args.get(key, str(default_value))
|
||||||
|
|
||||||
|
if isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
|
||||||
|
return val.lower() in ("1", "t", "true", "y", "yes")
|
||||||
|
|
||||||
|
|
||||||
|
def get_list(args: Any, key: str, default="") -> List[str]:
|
||||||
|
return split_list(args.get(key, default))
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_float(
|
def get_and_clamp_float(
|
||||||
|
@ -61,13 +75,13 @@ def get_from_list(
|
||||||
|
|
||||||
|
|
||||||
def get_from_map(
|
def get_from_map(
|
||||||
args: Any, key: str, values: Dict[str, TElem], default: TElem
|
args: Any, key: str, values: Dict[str, TElem], default_key: str
|
||||||
) -> TElem:
|
) -> TElem:
|
||||||
selected = args.get(key, default)
|
selected = args.get(key, default_key)
|
||||||
if selected in values:
|
if selected in values:
|
||||||
return values[selected]
|
return values[selected]
|
||||||
else:
|
else:
|
||||||
return values[default]
|
return values[default_key]
|
||||||
|
|
||||||
|
|
||||||
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
|
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
|
||||||
|
@ -195,6 +209,8 @@ def load_config(file: str) -> Dict:
|
||||||
return load_yaml(file)
|
return load_yaml(file)
|
||||||
elif ext in [".json"]:
|
elif ext in [".json"]:
|
||||||
return load_json(file)
|
return load_json(file)
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown config file extension")
|
||||||
|
|
||||||
|
|
||||||
def load_config_str(raw: str) -> Dict:
|
def load_config_str(raw: str) -> Dict:
|
||||||
|
|
|
@ -25,6 +25,7 @@ class WorkerContext:
|
||||||
idle: "Value[bool]"
|
idle: "Value[bool]"
|
||||||
timeout: float
|
timeout: float
|
||||||
retries: int
|
retries: int
|
||||||
|
initial_retries: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -36,6 +37,8 @@ class WorkerContext:
|
||||||
progress: "Queue[ProgressCommand]",
|
progress: "Queue[ProgressCommand]",
|
||||||
active_pid: "Value[int]",
|
active_pid: "Value[int]",
|
||||||
idle: "Value[bool]",
|
idle: "Value[bool]",
|
||||||
|
retries: int,
|
||||||
|
timeout: float,
|
||||||
):
|
):
|
||||||
self.job = None
|
self.job = None
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -47,12 +50,13 @@ class WorkerContext:
|
||||||
self.active_pid = active_pid
|
self.active_pid = active_pid
|
||||||
self.last_progress = None
|
self.last_progress = None
|
||||||
self.idle = idle
|
self.idle = idle
|
||||||
self.timeout = 1.0
|
self.initial_retries = retries
|
||||||
self.retries = 3 # TODO: get from env
|
self.retries = retries
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
def start(self, job: str) -> None:
|
def start(self, job: str) -> None:
|
||||||
self.job = job
|
self.job = job
|
||||||
self.retries = 3
|
self.retries = self.initial_retries
|
||||||
self.set_cancel(cancel=False)
|
self.set_cancel(cancel=False)
|
||||||
self.set_idle(idle=False)
|
self.set_idle(idle=False)
|
||||||
|
|
||||||
|
@ -82,7 +86,7 @@ class WorkerContext:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_progress_callback(self) -> ProgressCallback:
|
def get_progress_callback(self) -> ProgressCallback:
|
||||||
from ..chain.base import ChainProgress
|
from ..chain.pipeline import ChainProgress
|
||||||
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
on_progress.step = step
|
on_progress.step = step
|
||||||
|
|
|
@ -86,15 +86,15 @@ class DevicePoolExecutor:
|
||||||
self.logs = Queue(self.max_pending_per_worker)
|
self.logs = Queue(self.max_pending_per_worker)
|
||||||
self.rlock = Lock()
|
self.rlock = Lock()
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self, *args) -> None:
|
||||||
self.create_health_worker()
|
self.create_health_worker()
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
self.create_progress_worker()
|
self.create_progress_worker()
|
||||||
|
|
||||||
for device in self.devices:
|
for device in self.devices:
|
||||||
self.create_device_worker(device)
|
self.create_device_worker(device, *args)
|
||||||
|
|
||||||
def create_device_worker(self, device: DeviceParams) -> None:
|
def create_device_worker(self, device: DeviceParams, *args) -> None:
|
||||||
name = device.device
|
name = device.device
|
||||||
|
|
||||||
# always recreate queues
|
# always recreate queues
|
||||||
|
@ -124,15 +124,17 @@ class DevicePoolExecutor:
|
||||||
pending=self.pending[name],
|
pending=self.pending[name],
|
||||||
active_pid=current,
|
active_pid=current,
|
||||||
idle=self.worker_idle[name],
|
idle=self.worker_idle[name],
|
||||||
|
retries=self.server.worker_retries,
|
||||||
|
timeout=self.progress_interval,
|
||||||
)
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
|
|
||||||
worker = Process(
|
worker = Process(
|
||||||
name=f"onnx-web worker: {name}",
|
name=f"onnx-web worker: {name}",
|
||||||
target=worker_main,
|
target=worker_main,
|
||||||
args=(context, self.server),
|
args=(context, self.server, *args),
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
worker.daemon = True
|
|
||||||
self.workers[name] = worker
|
self.workers[name] = worker
|
||||||
|
|
||||||
logger.debug("starting worker for device %s", device)
|
logger.debug("starting worker for device %s", device)
|
||||||
|
|
|
@ -27,10 +27,14 @@ MEMORY_ERRORS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def worker_main(worker: WorkerContext, server: ServerContext):
|
def worker_main(
|
||||||
apply_patches(server)
|
worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
|
||||||
|
):
|
||||||
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
||||||
|
|
||||||
|
if patch:
|
||||||
|
apply_patches(server)
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"checking in from worker with providers: %s", get_available_providers()
|
"checking in from worker with providers: %s", get_available_providers()
|
||||||
)
|
)
|
||||||
|
@ -46,7 +50,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
|
||||||
getpid(),
|
getpid(),
|
||||||
worker.get_active(),
|
worker.get_active(),
|
||||||
)
|
)
|
||||||
exit(EXIT_REPLACED)
|
return exit(EXIT_REPLACED)
|
||||||
|
|
||||||
# wait briefly for the next job
|
# wait briefly for the next job
|
||||||
job = worker.pending.get(timeout=worker.timeout)
|
job = worker.pending.get(timeout=worker.timeout)
|
||||||
|
@ -69,15 +73,15 @@ def worker_main(worker: WorkerContext, server: ServerContext):
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.debug("worker got keyboard interrupt")
|
logger.debug("worker got keyboard interrupt")
|
||||||
worker.fail()
|
worker.fail()
|
||||||
exit(EXIT_INTERRUPT)
|
return exit(EXIT_INTERRUPT)
|
||||||
except RetryException:
|
except RetryException:
|
||||||
logger.exception("retry error in worker, exiting")
|
logger.exception("retry error in worker, exiting")
|
||||||
worker.fail()
|
worker.fail()
|
||||||
exit(EXIT_ERROR)
|
return exit(EXIT_ERROR)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.exception("value error in worker, exiting")
|
logger.exception("value error in worker, exiting")
|
||||||
worker.fail()
|
worker.fail()
|
||||||
exit(EXIT_ERROR)
|
return exit(EXIT_ERROR)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e_str = str(e)
|
e_str = str(e)
|
||||||
# restart the worker on memory errors
|
# restart the worker on memory errors
|
||||||
|
@ -85,7 +89,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
|
||||||
if e_mem in e_str:
|
if e_mem in e_str:
|
||||||
logger.error("detected out-of-memory error, exiting: %s", e)
|
logger.error("detected out-of-memory error, exiting: %s", e)
|
||||||
worker.fail()
|
worker.fail()
|
||||||
exit(EXIT_MEMORY)
|
return exit(EXIT_MEMORY)
|
||||||
|
|
||||||
# carry on for other errors
|
# carry on for other errors
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
|
|
@ -98,7 +98,7 @@
|
||||||
"highresSteps": {
|
"highresSteps": {
|
||||||
"default": 0,
|
"default": 0,
|
||||||
"min": 1,
|
"min": 1,
|
||||||
"max": 200,
|
"max": 500,
|
||||||
"step": 1
|
"step": 1
|
||||||
},
|
},
|
||||||
"highresStrength": {
|
"highresStrength": {
|
||||||
|
@ -141,12 +141,6 @@
|
||||||
"max": 4,
|
"max": 4,
|
||||||
"step": 1
|
"step": 1
|
||||||
},
|
},
|
||||||
"overlap": {
|
|
||||||
"default": 0.25,
|
|
||||||
"min": 0.0,
|
|
||||||
"max": 0.9,
|
|
||||||
"step": 0.01
|
|
||||||
},
|
|
||||||
"pipeline": {
|
"pipeline": {
|
||||||
"default": "",
|
"default": "",
|
||||||
"keys": [
|
"keys": [
|
||||||
|
@ -188,7 +182,7 @@
|
||||||
"steps": {
|
"steps": {
|
||||||
"default": 25,
|
"default": 25,
|
||||||
"min": 1,
|
"min": 1,
|
||||||
"max": 200,
|
"max": 300,
|
||||||
"step": 1
|
"step": 1
|
||||||
},
|
},
|
||||||
"strength": {
|
"strength": {
|
||||||
|
@ -197,21 +191,9 @@
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"step": 0.01
|
"step": 0.01
|
||||||
},
|
},
|
||||||
"stride": {
|
"tiled_vae": {
|
||||||
"default": 128,
|
|
||||||
"min": 64,
|
|
||||||
"max": 512,
|
|
||||||
"step": 64
|
|
||||||
},
|
|
||||||
"tiledVAE": {
|
|
||||||
"default": false
|
"default": false
|
||||||
},
|
},
|
||||||
"tiles": {
|
|
||||||
"default": 512,
|
|
||||||
"min": 128,
|
|
||||||
"max": 2048,
|
|
||||||
"step": 128
|
|
||||||
},
|
|
||||||
"tileOrder": {
|
"tileOrder": {
|
||||||
"default": "spiral",
|
"default": "spiral",
|
||||||
"keys": [
|
"keys": [
|
||||||
|
@ -225,6 +207,18 @@
|
||||||
"max": 1024,
|
"max": 1024,
|
||||||
"step": 8
|
"step": 8
|
||||||
},
|
},
|
||||||
|
"unet_overlap": {
|
||||||
|
"default": 0.25,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 0.9,
|
||||||
|
"step": 0.01
|
||||||
|
},
|
||||||
|
"unet_tile": {
|
||||||
|
"default": 512,
|
||||||
|
"min": 128,
|
||||||
|
"max": 2048,
|
||||||
|
"step": 128
|
||||||
|
},
|
||||||
"upscaleOrder": {
|
"upscaleOrder": {
|
||||||
"default": "correction-first",
|
"default": "correction-first",
|
||||||
"keys": [
|
"keys": [
|
||||||
|
@ -237,6 +231,18 @@
|
||||||
"default": "",
|
"default": "",
|
||||||
"keys": []
|
"keys": []
|
||||||
},
|
},
|
||||||
|
"vae_overlap": {
|
||||||
|
"default": 0.25,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 0.9,
|
||||||
|
"step": 0.01
|
||||||
|
},
|
||||||
|
"vae_tile": {
|
||||||
|
"default": 512,
|
||||||
|
"min": 256,
|
||||||
|
"max": 1024,
|
||||||
|
"step": 128
|
||||||
|
},
|
||||||
"width": {
|
"width": {
|
||||||
"default": 512,
|
"default": 512,
|
||||||
"min": 128,
|
"min": 128,
|
||||||
|
|
|
@ -9,12 +9,14 @@ skip_glob = ["*/lpw.py"]
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
# ignore_missing_imports = true
|
# ignore_missing_imports = true
|
||||||
exclude = [
|
exclude = [
|
||||||
"onnx_web.diffusers.lpw_stable_diffusion_onnx"
|
"onnx_web.diffusers.pipelines.controlnet",
|
||||||
|
"onnx_web.diffusers.pipelines.lpw",
|
||||||
|
"onnx_web.diffusers.pipelines.pix2pix"
|
||||||
]
|
]
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = [
|
module = [
|
||||||
"arpeggio",
|
"arpeggio",
|
||||||
"basicsr.archs.rrdbnet_arch",
|
"basicsr.archs.rrdbnet_arch",
|
||||||
"basicsr.utils.download_util",
|
"basicsr.utils.download_util",
|
||||||
"basicsr.utils",
|
"basicsr.utils",
|
||||||
|
@ -27,8 +29,10 @@ module = [
|
||||||
"compel",
|
"compel",
|
||||||
"controlnet_aux",
|
"controlnet_aux",
|
||||||
"cv2",
|
"cv2",
|
||||||
|
"debugpy",
|
||||||
"diffusers",
|
"diffusers",
|
||||||
"diffusers.configuration_utils",
|
"diffusers.configuration_utils",
|
||||||
|
"diffusers.image_processor",
|
||||||
"diffusers.loaders",
|
"diffusers.loaders",
|
||||||
"diffusers.models.attention_processor",
|
"diffusers.models.attention_processor",
|
||||||
"diffusers.models.autoencoder_kl",
|
"diffusers.models.autoencoder_kl",
|
||||||
|
@ -41,9 +45,10 @@ module = [
|
||||||
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
|
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
|
||||||
"diffusers.pipelines.onnx_utils",
|
"diffusers.pipelines.onnx_utils",
|
||||||
"diffusers.pipelines.paint_by_example",
|
"diffusers.pipelines.paint_by_example",
|
||||||
|
"diffusers.pipelines.pipeline_utils",
|
||||||
"diffusers.pipelines.stable_diffusion",
|
"diffusers.pipelines.stable_diffusion",
|
||||||
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
|
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
|
||||||
"diffusers.pipeline_utils",
|
"diffusers.pipelines.stable_diffusion_xl",
|
||||||
"diffusers.schedulers",
|
"diffusers.schedulers",
|
||||||
"diffusers.utils.logging",
|
"diffusers.utils.logging",
|
||||||
"facexlib.utils",
|
"facexlib.utils",
|
||||||
|
@ -56,11 +61,17 @@ module = [
|
||||||
"mediapipe",
|
"mediapipe",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
"onnxruntime.transformers.float16",
|
"onnxruntime.transformers.float16",
|
||||||
|
"optimum.exporters.onnx",
|
||||||
|
"optimum.onnxruntime",
|
||||||
|
"optimum.onnxruntime.modeling_diffusion",
|
||||||
|
"optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img",
|
||||||
|
"optimum.pipelines.diffusers.pipeline_utils",
|
||||||
"piexif",
|
"piexif",
|
||||||
"piexif.helper",
|
"piexif.helper",
|
||||||
"realesrgan",
|
"realesrgan",
|
||||||
"realesrgan.archs.srvgg_arch",
|
"realesrgan.archs.srvgg_arch",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
|
"scipy",
|
||||||
"timm.models.layers",
|
"timm.models.layers",
|
||||||
"transformers",
|
"transformers",
|
||||||
"win10toast"
|
"win10toast"
|
||||||
|
|
|
@ -46,17 +46,31 @@ $defs:
|
||||||
patternProperties:
|
patternProperties:
|
||||||
"^[-_A-Za-z]+$":
|
"^[-_A-Za-z]+$":
|
||||||
oneOf:
|
oneOf:
|
||||||
|
- type: boolean
|
||||||
- type: number
|
- type: number
|
||||||
- type: string
|
- type: string
|
||||||
|
- type: "null"
|
||||||
|
|
||||||
request_chain:
|
request_chain:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: "#/$defs/request_stage"
|
$ref: "#/$defs/request_stage"
|
||||||
|
|
||||||
|
request_defaults:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
txt2img:
|
||||||
|
$ref: "#/$defs/image_params"
|
||||||
|
img2img:
|
||||||
|
$ref: "#/$defs/image_params"
|
||||||
|
|
||||||
type: object
|
type: object
|
||||||
additionalProperties: False
|
additionalProperties: False
|
||||||
required: [stages]
|
required: [stages]
|
||||||
properties:
|
properties:
|
||||||
|
defaults:
|
||||||
|
$ref: "#/$defs/request_defaults"
|
||||||
|
platform:
|
||||||
|
type: string
|
||||||
stages:
|
stages:
|
||||||
$ref: "#/$defs/request_chain"
|
$ref: "#/$defs/request_chain"
|
||||||
|
|
|
@ -10,34 +10,53 @@ $defs:
|
||||||
- type: number
|
- type: number
|
||||||
- type: string
|
- type: string
|
||||||
|
|
||||||
lora_network:
|
tensor_format:
|
||||||
|
type: string
|
||||||
|
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
||||||
|
|
||||||
|
embedding_network:
|
||||||
type: object
|
type: object
|
||||||
required: [name, source]
|
required: [name, source]
|
||||||
properties:
|
properties:
|
||||||
name:
|
format:
|
||||||
type: string
|
$ref: "#/$defs/tensor_format"
|
||||||
source:
|
|
||||||
type: string
|
|
||||||
label:
|
label:
|
||||||
type: string
|
type: string
|
||||||
weight:
|
model:
|
||||||
type: number
|
|
||||||
|
|
||||||
textual_inversion_network:
|
|
||||||
type: object
|
|
||||||
required: [name, source]
|
|
||||||
properties:
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
source:
|
|
||||||
type: string
|
|
||||||
format:
|
|
||||||
type: string
|
type: string
|
||||||
enum: [concept, embeddings]
|
enum: [concept, embeddings]
|
||||||
label:
|
name:
|
||||||
|
type: string
|
||||||
|
source:
|
||||||
type: string
|
type: string
|
||||||
token:
|
token:
|
||||||
type: string
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: inversion # TODO: add embedding
|
||||||
|
weight:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
lora_network:
|
||||||
|
type: object
|
||||||
|
required: [name, source, type]
|
||||||
|
properties:
|
||||||
|
label:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
enum: [cloneofsimo, sd-scripts]
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
source:
|
||||||
|
type: string
|
||||||
|
tokens:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: lora
|
||||||
weight:
|
weight:
|
||||||
type: number
|
type: number
|
||||||
|
|
||||||
|
@ -46,8 +65,7 @@ $defs:
|
||||||
required: [name, source]
|
required: [name, source]
|
||||||
properties:
|
properties:
|
||||||
format:
|
format:
|
||||||
type: string
|
$ref: "#/$defs/tensor_format"
|
||||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
|
||||||
half:
|
half:
|
||||||
type: boolean
|
type: boolean
|
||||||
label:
|
label:
|
||||||
|
@ -85,7 +103,7 @@ $defs:
|
||||||
inversions:
|
inversions:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: "#/$defs/textual_inversion_network"
|
$ref: "#/$defs/embedding_network"
|
||||||
loras:
|
loras:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
|
@ -100,6 +118,7 @@ $defs:
|
||||||
panorama,
|
panorama,
|
||||||
pix2pix,
|
pix2pix,
|
||||||
txt2img,
|
txt2img,
|
||||||
|
txt2img-sdxl,
|
||||||
upscale,
|
upscale,
|
||||||
]
|
]
|
||||||
vae:
|
vae:
|
||||||
|
@ -141,31 +160,6 @@ $defs:
|
||||||
source:
|
source:
|
||||||
type: string
|
type: string
|
||||||
|
|
||||||
source_network:
|
|
||||||
type: object
|
|
||||||
required: [name, source, type]
|
|
||||||
properties:
|
|
||||||
format:
|
|
||||||
type: string
|
|
||||||
enum: [bin, ckpt, onnx, pt, pth, safetensors]
|
|
||||||
model:
|
|
||||||
type: string
|
|
||||||
enum: [
|
|
||||||
# inversion
|
|
||||||
concept,
|
|
||||||
embeddings,
|
|
||||||
# lora
|
|
||||||
cloneofsimo,
|
|
||||||
sd-scripts
|
|
||||||
]
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
source:
|
|
||||||
type: string
|
|
||||||
type:
|
|
||||||
type: string
|
|
||||||
enum: [inversion, lora]
|
|
||||||
|
|
||||||
translation:
|
translation:
|
||||||
type: object
|
type: object
|
||||||
additionalProperties: False
|
additionalProperties: False
|
||||||
|
@ -193,7 +187,9 @@ properties:
|
||||||
networks:
|
networks:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: "#/$defs/source_network"
|
oneOf:
|
||||||
|
- $ref: "#/$defs/lora_network"
|
||||||
|
- $ref: "#/$defs/embedding_network"
|
||||||
sources:
|
sources:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||||
|
from os import path
|
||||||
|
from onnx.checker import check_model
|
||||||
|
from onnx.external_data_helper import (
|
||||||
|
convert_model_to_external_data,
|
||||||
|
write_external_data_tensors,
|
||||||
|
)
|
||||||
|
from onnxruntime import InferenceSession, SessionOptions
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from onnx_web.convert.utils import ConversionContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
context = ConversionContext.from_environ()
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--base", type=str)
|
||||||
|
parser.add_argument("--dest", type=str)
|
||||||
|
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
||||||
|
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
|
||||||
|
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
logger.info(
|
||||||
|
"merging %s with %s with weights: %s",
|
||||||
|
args.lora_models,
|
||||||
|
args.base,
|
||||||
|
args.lora_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
default_weight = 1.0 / len(args.lora_models)
|
||||||
|
while len(args.lora_weights) < len(args.lora_models):
|
||||||
|
args.lora_weights.append(default_weight)
|
||||||
|
|
||||||
|
blend_model = blend_loras(
|
||||||
|
context,
|
||||||
|
args.base,
|
||||||
|
list(zip(args.lora_models, args.lora_weights)),
|
||||||
|
args.type,
|
||||||
|
)
|
||||||
|
if args.dest is None or args.dest == "" or args.dest == ":load":
|
||||||
|
# convert to external data and save to memory
|
||||||
|
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||||
|
logger.info("saved external data for %s nodes", len(external_data))
|
||||||
|
|
||||||
|
external_names, external_values = zip(*external_data)
|
||||||
|
opts = SessionOptions()
|
||||||
|
opts.add_external_initializers(list(external_names), list(external_values))
|
||||||
|
sess = InferenceSession(
|
||||||
|
bare_model.SerializeToString(),
|
||||||
|
sess_options=opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
convert_model_to_external_data(
|
||||||
|
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
|
||||||
|
)
|
||||||
|
bare_model = write_external_data_tensors(blend_model, args.dest)
|
||||||
|
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
||||||
|
|
||||||
|
with open(dest_file, "w+b") as model_file:
|
||||||
|
model_file.write(bare_model.SerializeToString())
|
||||||
|
|
||||||
|
logger.info("successfully saved blended model: %s", dest_file)
|
||||||
|
|
||||||
|
check_model(dest_file)
|
||||||
|
|
||||||
|
logger.info("checked blended model")
|
BIN
api/scripts/test-refs/blend-512-muffin-white-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/blend-512-muffin-white-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-even-256-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-even-256-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-horizontal-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-horizontal-512-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/outpaint-vertical-512-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/outpaint-vertical-512-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png (Stored with Git LFS)
Binary file not shown.
BIN
api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png (Stored with Git LFS)
BIN
api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png (Stored with Git LFS)
Binary file not shown.
|
@ -30,6 +30,10 @@ FAST_TEST = 10
|
||||||
SLOW_TEST = 25
|
SLOW_TEST = 25
|
||||||
VERY_SLOW_TEST = 75
|
VERY_SLOW_TEST = 75
|
||||||
|
|
||||||
|
STRICT_TEST = 1e-4
|
||||||
|
LOOSE_TEST = 1e-2
|
||||||
|
VERY_LOOSE_TEST = 0.025
|
||||||
|
|
||||||
|
|
||||||
def test_path(relpath: str) -> str:
|
def test_path(relpath: str) -> str:
|
||||||
return path.join(path.dirname(__file__), relpath)
|
return path.join(path.dirname(__file__), relpath)
|
||||||
|
@ -41,7 +45,7 @@ class TestCase:
|
||||||
name: str,
|
name: str,
|
||||||
query: str,
|
query: str,
|
||||||
max_attempts: int = FAST_TEST,
|
max_attempts: int = FAST_TEST,
|
||||||
mse_threshold: float = 1e-4,
|
mse_threshold: float = STRICT_TEST,
|
||||||
source: Union[Image.Image, List[Image.Image]] = None,
|
source: Union[Image.Image, List[Image.Image]] = None,
|
||||||
mask: Image.Image = None,
|
mask: Image.Image = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -65,6 +69,7 @@ TEST_DATA = [
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin-deis",
|
"txt2img-sd-v1-5-512-muffin-deis",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
|
||||||
|
mse_threshold=LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin-dpm",
|
"txt2img-sd-v1-5-512-muffin-dpm",
|
||||||
|
@ -73,10 +78,12 @@ TEST_DATA = [
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin-heun",
|
"txt2img-sd-v1-5-512-muffin-heun",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
|
||||||
|
mse_threshold=LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-512-muffin-unipc",
|
"txt2img-sd-v1-5-512-muffin-unipc",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi",
|
||||||
|
mse_threshold=LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v2-1-512-muffin",
|
"txt2img-sd-v2-1-512-muffin",
|
||||||
|
@ -84,7 +91,7 @@ TEST_DATA = [
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v2-1-768-muffin",
|
"txt2img-sd-v2-1-768-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768",
|
||||||
max_attempts=SLOW_TEST,
|
max_attempts=SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
|
@ -106,7 +113,7 @@ TEST_DATA = [
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"img2img-sd-v1-5-256-pumpkin",
|
"img2img-sd-v1-5-256-pumpkin",
|
||||||
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none",
|
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256",
|
||||||
source="txt2img-sd-v1-5-256-muffin-0",
|
source="txt2img-sd-v1-5-256-muffin-0",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
|
@ -130,7 +137,7 @@ TEST_DATA = [
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
max_attempts=SLOW_TEST,
|
max_attempts=SLOW_TEST,
|
||||||
mse_threshold=0.025,
|
mse_threshold=VERY_LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"outpaint-vertical-512",
|
"outpaint-vertical-512",
|
||||||
|
@ -141,7 +148,7 @@ TEST_DATA = [
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
max_attempts=SLOW_TEST,
|
max_attempts=SLOW_TEST,
|
||||||
mse_threshold=0.010,
|
mse_threshold=LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"outpaint-horizontal-512",
|
"outpaint-horizontal-512",
|
||||||
|
@ -152,7 +159,7 @@ TEST_DATA = [
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
max_attempts=SLOW_TEST,
|
max_attempts=SLOW_TEST,
|
||||||
mse_threshold=0.010,
|
mse_threshold=LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-resrgan-x2-1024-muffin",
|
"upscale-resrgan-x2-1024-muffin",
|
||||||
|
@ -229,7 +236,7 @@ TEST_DATA = [
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
max_attempts=VERY_SLOW_TEST,
|
max_attempts=VERY_SLOW_TEST,
|
||||||
mse_threshold=0.025,
|
mse_threshold=VERY_LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"outpaint-panorama-vertical-512",
|
"outpaint-panorama-vertical-512",
|
||||||
|
@ -240,7 +247,7 @@ TEST_DATA = [
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
max_attempts=VERY_SLOW_TEST,
|
max_attempts=VERY_SLOW_TEST,
|
||||||
mse_threshold=0.025,
|
mse_threshold=VERY_LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"outpaint-panorama-horizontal-512",
|
"outpaint-panorama-horizontal-512",
|
||||||
|
@ -251,7 +258,7 @@ TEST_DATA = [
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
mask="mask-black",
|
mask="mask-black",
|
||||||
max_attempts=VERY_SLOW_TEST,
|
max_attempts=VERY_SLOW_TEST,
|
||||||
mse_threshold=0.025,
|
mse_threshold=VERY_LOOSE_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-resrgan-x4-codeformer-2048-muffin",
|
"upscale-resrgan-x4-codeformer-2048-muffin",
|
||||||
|
@ -260,6 +267,7 @@ TEST_DATA = [
|
||||||
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
|
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-resrgan-x4-gfpgan-2048-muffin",
|
"upscale-resrgan-x4-gfpgan-2048-muffin",
|
||||||
|
@ -268,6 +276,7 @@ TEST_DATA = [
|
||||||
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
|
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-swinir-x4-codeformer-2048-muffin",
|
"upscale-swinir-x4-codeformer-2048-muffin",
|
||||||
|
@ -276,6 +285,7 @@ TEST_DATA = [
|
||||||
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
|
"&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-swinir-x4-gfpgan-2048-muffin",
|
"upscale-swinir-x4-gfpgan-2048-muffin",
|
||||||
|
@ -284,6 +294,7 @@ TEST_DATA = [
|
||||||
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
|
"&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-512-muffin-0",
|
source="txt2img-sd-v1-5-512-muffin-0",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-sd-x4-codeformer-2048-muffin",
|
"upscale-sd-x4-codeformer-2048-muffin",
|
||||||
|
@ -305,18 +316,18 @@ TEST_DATA = [
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-panorama-1024x768-muffin",
|
"txt2img-panorama-1024x768-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true",
|
||||||
max_attempts=VERY_SLOW_TEST,
|
max_attempts=VERY_SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"img2img-panorama-1024x768-pumpkin",
|
"img2img-panorama-1024x768-pumpkin",
|
||||||
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true",
|
"img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true",
|
||||||
source="txt2img-panorama-1024x768-muffin-0",
|
source="txt2img-panorama-1024x768-muffin-0",
|
||||||
max_attempts=VERY_SLOW_TEST,
|
max_attempts=VERY_SLOW_TEST,
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"txt2img-sd-v1-5-tall-muffin",
|
"txt2img-sd-v1-5-tall-muffin",
|
||||||
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768",
|
"txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768",
|
||||||
),
|
),
|
||||||
TestCase(
|
TestCase(
|
||||||
"upscale-resrgan-x4-tall-muffin",
|
"upscale-resrgan-x4-tall-muffin",
|
||||||
|
@ -325,6 +336,7 @@ TEST_DATA = [
|
||||||
"&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0"
|
"&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0"
|
||||||
),
|
),
|
||||||
source="txt2img-sd-v1-5-tall-muffin-0",
|
source="txt2img-sd-v1-5-tall-muffin-0",
|
||||||
|
max_attempts=SLOW_TEST,
|
||||||
),
|
),
|
||||||
# TODO: non-square controlnet
|
# TODO: non-square controlnet
|
||||||
]
|
]
|
||||||
|
@ -335,6 +347,39 @@ class TestError(Exception):
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
|
|
||||||
|
|
||||||
|
class TestResult:
|
||||||
|
error: Optional[str]
|
||||||
|
mse: Optional[float]
|
||||||
|
name: str
|
||||||
|
passed: bool
|
||||||
|
|
||||||
|
def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
|
||||||
|
self.error = error
|
||||||
|
self.mse = mse
|
||||||
|
self.name = name
|
||||||
|
self.passed = passed
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if self.passed:
|
||||||
|
if self.mse is not None:
|
||||||
|
return f"{self.name} ({self.mse})"
|
||||||
|
else:
|
||||||
|
return self.name
|
||||||
|
else:
|
||||||
|
if self.mse is not None:
|
||||||
|
return f"{self.name}: {self.error} ({self.mse})"
|
||||||
|
else:
|
||||||
|
return f"{self.name}: {self.error}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def passed(self, name: str, mse = None):
|
||||||
|
return TestResult(name, mse=mse)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def failed(self, name: str, error: str, mse = None):
|
||||||
|
return TestResult(name, error=error, mse=mse, passed=False)
|
||||||
|
|
||||||
|
|
||||||
def parse_args(args: List[str]):
|
def parse_args(args: List[str]):
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
prog="onnx-web release tests",
|
prog="onnx-web release tests",
|
||||||
|
@ -441,14 +486,14 @@ def run_test(
|
||||||
host: str,
|
host: str,
|
||||||
test: TestCase,
|
test: TestCase,
|
||||||
mse_mult: float = 1.0,
|
mse_mult: float = 1.0,
|
||||||
) -> bool:
|
) -> TestResult:
|
||||||
"""
|
"""
|
||||||
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
Generate an image, wait for it to be ready, and calculate the MSE from the reference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
keys = generate_images(host, test)
|
keys = generate_images(host, test)
|
||||||
if keys is None:
|
if keys is None:
|
||||||
raise ValueError("could not generate image")
|
return TestResult.failed(test.name, "could not generate image")
|
||||||
|
|
||||||
ready = False
|
ready = False
|
||||||
for attempt in tqdm(range(test.max_attempts)):
|
for attempt in tqdm(range(test.max_attempts)):
|
||||||
|
@ -461,13 +506,13 @@ def run_test(
|
||||||
sleep(6)
|
sleep(6)
|
||||||
|
|
||||||
if not ready:
|
if not ready:
|
||||||
raise ValueError("image was not ready in time")
|
return TestResult.failed(test.name, "image was not ready in time")
|
||||||
|
|
||||||
results = download_images(host, keys)
|
results = download_images(host, keys)
|
||||||
if results is None:
|
if results is None or len(results) == 0:
|
||||||
raise ValueError("could not download image")
|
return TestResult.failed(test.name, "could not download image")
|
||||||
|
|
||||||
passed = True
|
passed = False
|
||||||
for i in range(len(results)):
|
for i in range(len(results)):
|
||||||
result = results[i]
|
result = results[i]
|
||||||
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
|
||||||
|
@ -476,14 +521,19 @@ def run_test(
|
||||||
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
ref = Image.open(ref_name) if path.exists(ref_name) else None
|
||||||
|
|
||||||
mse = find_mse(result, ref)
|
mse = find_mse(result, ref)
|
||||||
|
threshold = test.mse_threshold * mse_mult
|
||||||
|
|
||||||
if mse < (test.mse_threshold * mse_mult):
|
if mse < threshold:
|
||||||
logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold)
|
logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
|
||||||
|
passed = True
|
||||||
else:
|
else:
|
||||||
logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold)
|
logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
|
||||||
passed = False
|
return TestResult.failed(test.name, error="MSE above threshold", mse=mse)
|
||||||
|
|
||||||
return passed
|
if passed:
|
||||||
|
return TestResult.passed(test.name)
|
||||||
|
else:
|
||||||
|
return TestResult.failed(test.name, "no images tested")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -504,24 +554,26 @@ def main():
|
||||||
passed = []
|
passed = []
|
||||||
failed = []
|
failed = []
|
||||||
for test in tests:
|
for test in tests:
|
||||||
test_passed = False
|
result = None
|
||||||
|
|
||||||
for _i in range(3):
|
for _i in range(3):
|
||||||
try:
|
try:
|
||||||
logger.info("starting test: %s", test.name)
|
logger.info("starting test: %s", test.name)
|
||||||
if run_test(args.host, test, mse_mult=args.mse):
|
result = run_test(args.host, test, mse_mult=args.mse)
|
||||||
|
if result.passed:
|
||||||
logger.info("test passed: %s", test.name)
|
logger.info("test passed: %s", test.name)
|
||||||
test_passed = True
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("test failed: %s", test.name)
|
logger.warning("test failed: %s", test.name)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error running test for %s", test.name)
|
logger.exception("error running test for %s", test.name)
|
||||||
|
result = TestResult.failed(test.name, "TODO: exception message")
|
||||||
|
|
||||||
if test_passed:
|
if result is not None:
|
||||||
passed.append(test.name)
|
if result.passed:
|
||||||
else:
|
passed.append(result)
|
||||||
failed.append(test.name)
|
else:
|
||||||
|
failed.append(result)
|
||||||
|
|
||||||
logger.info("%s of %s tests passed", len(passed), len(tests))
|
logger.info("%s of %s tests passed", len(passed), len(tests))
|
||||||
failed = list(set(failed))
|
failed = list(set(failed))
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.pipeline import ChainProgress
|
||||||
|
|
||||||
|
|
||||||
|
class ChainProgressTests(unittest.TestCase):
|
||||||
|
def test_accumulate_with_reset(self):
|
||||||
|
def parent(step, timestep, latents):
|
||||||
|
pass
|
||||||
|
|
||||||
|
progress = ChainProgress(parent)
|
||||||
|
progress(5, 1, None)
|
||||||
|
progress(0, 1, None)
|
||||||
|
progress(5, 1, None)
|
||||||
|
|
||||||
|
self.assertEqual(progress.get_total(), 10)
|
||||||
|
|
||||||
|
def test_start_value(self):
|
||||||
|
def parent(step, timestep, latents):
|
||||||
|
pass
|
||||||
|
|
||||||
|
progress = ChainProgress(parent, 5)
|
||||||
|
self.assertEqual(progress.get_total(), 5)
|
||||||
|
|
||||||
|
progress(10, 1, None)
|
||||||
|
self.assertEqual(progress.get_total(), 10)
|
|
@ -0,0 +1,23 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.blend_grid import BlendGridStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
|
|
||||||
|
class BlendGridStageTests(unittest.TestCase):
|
||||||
|
def test_stage(self):
|
||||||
|
stage = BlendGridStage()
|
||||||
|
sources = StageResult(
|
||||||
|
images=[
|
||||||
|
Image.new("RGB", (64, 64), "black"),
|
||||||
|
Image.new("RGB", (64, 64), "white"),
|
||||||
|
Image.new("RGB", (64, 64), "black"),
|
||||||
|
Image.new("RGB", (64, 64), "white"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 5)
|
||||||
|
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
|
|
@ -0,0 +1,47 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import ImageParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
||||||
|
|
||||||
|
|
||||||
|
class BlendImg2ImgStageTests(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
def test_stage(self):
|
||||||
|
stage = BlendImg2ImgStage()
|
||||||
|
params = ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
"txt2img",
|
||||||
|
"euler-a",
|
||||||
|
"an astronaut eating a hamburger",
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
sources = StageResult(
|
||||||
|
images=[
|
||||||
|
Image.new("RGB", (64, 64), "black"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))
|
|
@ -0,0 +1,23 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
|
|
||||||
|
class BlendLinearStageTests(unittest.TestCase):
|
||||||
|
def test_stage(self):
|
||||||
|
stage = BlendLinearStage()
|
||||||
|
sources = StageResult(
|
||||||
|
images=[
|
||||||
|
Image.new("RGB", (64, 64), "black"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
stage_source = Image.new("RGB", (64, 64), "white")
|
||||||
|
result = stage.run(
|
||||||
|
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
|
|
@ -0,0 +1,25 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.blend_mask import BlendMaskStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
|
class BlendMaskStageTests(unittest.TestCase):
|
||||||
|
def test_empty(self):
|
||||||
|
stage = BlendMaskStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(""),
|
||||||
|
stage_mask=Image.new("RGBA", (64, 64)),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,46 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.server.hacks import apply_patches
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import (
|
||||||
|
TEST_MODEL_CORRECTION_CODEFORMER,
|
||||||
|
test_device,
|
||||||
|
test_needs_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectCodeformerStageTests(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER])
|
||||||
|
def test_empty(self):
|
||||||
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
|
apply_patches(server)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
stage = CorrectCodeformerStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
worker,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(""),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,44 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.server.hacks import apply_patches
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import test_device, test_needs_onnx_models
|
||||||
|
|
||||||
|
TEST_MODEL = "../models/correction-gfpgan-v1-3"
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectGFPGANStageTests(unittest.TestCase):
|
||||||
|
@test_needs_onnx_models([TEST_MODEL])
|
||||||
|
def test_empty(self):
|
||||||
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
|
apply_patches(server)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
stage = CorrectGFPGANStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
worker,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(TEST_MODEL),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,24 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.reduce_crop import ReduceCropStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceCropStageTests(unittest.TestCase):
|
||||||
|
def test_empty(self):
|
||||||
|
stage = ReduceCropStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(""),
|
||||||
|
origin=Size(0, 0),
|
||||||
|
size=Size(128, 128),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,28 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceThumbnailStageTests(unittest.TestCase):
|
||||||
|
def test_empty(self):
|
||||||
|
stage_source = Image.new("RGB", (64, 64))
|
||||||
|
stage = ReduceThumbnailStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(""),
|
||||||
|
origin=Size(0, 0),
|
||||||
|
size=Size(128, 128),
|
||||||
|
stage_source=stage_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue