Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards

This commit is contained in: commit 17f28aba62

README.md (19 lines changed)
@@ -17,12 +17,14 @@ with a CPU fallback capable of running on laptop-class machines.
Please check out [the setup guide to get started](docs/setup-guide.md) and [the user guide for more
details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).

![txt2img with detailed knollingcase renders of a soldier in a cloudy alien jungle](./docs/readme-preview.png)
![preview of txt2img tab using SDXL to generate ghostly astronauts eating weird hamburgers on an abandoned space station](./docs/readme-sdxl.png)

## Features

This is an incomplete list of new and interesting features, with links to the user guide:

- SDXL support
- LCM support
- hardware acceleration on both AMD and Nvidia
  - tested on CUDA, DirectML, and ROCm
  - [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
@@ -37,6 +39,7 @@ This is an incomplete list of new and interesting features, with links to the us
- [txt2img](docs/user-guide.md#txt2img-tab)
- [img2img](docs/user-guide.md#img2img-tab)
- [inpainting](docs/user-guide.md#inpaint-tab), with mask drawing and upload
- [panorama](docs/user-guide.md#panorama-pipeline), for both SD v1.5 and SDXL
- [upscaling](docs/user-guide.md#upscale-tab), with ONNX acceleration
- [add and use your own models](docs/user-guide.md#adding-your-own-models)
- [convert models from diffusers and SD checkpoints](docs/converting-models.md)

@@ -45,20 +48,24 @@ This is an incomplete list of new and interesting features, with links to the us
- [permanent and prompt-based blending](docs/user-guide.md#permanently-blending-additional-networks)
- [supports LoRA and LyCORIS weights](docs/user-guide.md#lora-tokens)
- [supports Textual Inversion concepts and embeddings](docs/user-guide.md#textual-inversion-tokens)
  - each layer of the embeddings can be controlled and used individually
- ControlNet
  - image filters for edge detection and other methods
  - with ONNX acceleration
- highres mode
  - runs img2img on the results of the other pipelines
  - multiple iterations can produce 8k images and larger
- [multi-stage](docs/user-guide.md#prompt-stages) and [region prompts](docs/user-guide.md#region-tokens)
  - seamlessly combine multiple prompts in the same image
  - provide prompts for different areas in the image and blend them together
  - change the prompt for highres mode and refine details without recursion
- infinite prompt length
  - [with long prompt weighting](docs/user-guide.md#long-prompt-weighting)
  - expand and control Textual Inversions per-layer
- [image blending mode](docs/user-guide.md#blend-tab)
  - combine images from history
- upscaling and face correction
  - upscaling with Real ESRGAN or Stable Diffusion
  - face correction with CodeFormer or GFPGAN
- upscaling and correction
  - upscaling with Real ESRGAN, SwinIR, and Stable Diffusion
  - face correction with CodeFormer and GFPGAN
- [API server can be run remotely](docs/server-admin.md)
  - REST API can be served over HTTPS or HTTP
  - background processing for all image pipelines

@@ -66,7 +73,7 @@ This is an incomplete list of new and interesting features, with links to the us
- OCI containers provided
  - for all supported hardware accelerators
  - includes both the API and GUI bundle in a single container
  - runs well on [RunPod](https://www.runpod.io/) and other GPU container hosting services
  - runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services

## Contents

@@ -0,0 +1,11 @@
{
    "python.testing.unittestArgs": [
        "-v",
        "-s",
        "./tests",
        "-p",
        "test_*.py"
    ],
    "python.testing.pytestEnabled": false,
    "python.testing.unittestEnabled": true
}
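
Editor's note: these settings enable unittest discovery in the editor. A minimal sketch of the equivalent discovery outside the editor, using only the standard library (the `./tests` path and `test_*.py` pattern simply mirror the settings above):

```python
# Discover and run the same tests the editor settings point at.
# Sketch only; assumes it is run from the repository's api/ directory.
import unittest

loader = unittest.TestLoader()
suite = loader.discover("./tests", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)  # verbosity=2 matches "-v"
```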

api/Makefile (17 lines changed)
@@ -1,4 +1,4 @@
.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload
.PHONY: ci check-venv pip pip-dev lint-check lint-fix test typecheck package package-dist package-upload style

onnx_env: ## create virtual env
	python -m venv onnx_env

@@ -18,9 +18,10 @@ pip-dev: check-venv
test:
	python -m coverage erase
	python -m coverage run -m unittest discover -s tests/
	python -m coverage run -m unittest discover -v -s tests/
	python -m coverage html -i
	python -m coverage xml -i
	python -m coverage report -i

package: package-dist package-upload

@@ -32,13 +33,21 @@ package-upload:
lint-check:
	black --check onnx_web/
	isort --check-only --skip __init__.py --filter-files onnx_web
	black --check tests/
	flake8 onnx_web
	flake8 tests
	isort --check-only --skip __init__.py --filter-files onnx_web
	isort --check-only --skip __init__.py --filter-files tests

lint-fix:
	black onnx_web/
	isort --skip __init__.py --filter-files onnx_web
	black tests/
	flake8 onnx_web
	flake8 tests
	isort --skip __init__.py --filter-files onnx_web
	isort --skip __init__.py --filter-files tests

style: lint-fix

typecheck:
	mypy onnx_web

@@ -1,45 +1,2 @@
from .base import ChainPipeline, PipelineStage, StageParams
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_simple import UpscaleSimpleStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage

CHAIN_STAGES = {
    "blend-img2img": BlendImg2ImgStage,
    "blend-inpaint": UpscaleOutpaintStage,
    "blend-linear": BlendLinearStage,
    "blend-mask": BlendMaskStage,
    "correct-codeformer": CorrectCodeformerStage,
    "correct-gfpgan": CorrectGFPGANStage,
    "persist-disk": PersistDiskStage,
    "persist-s3": PersistS3Stage,
    "reduce-crop": ReduceCropStage,
    "reduce-thumbnail": ReduceThumbnailStage,
    "source-noise": SourceNoiseStage,
    "source-s3": SourceS3Stage,
    "source-txt2img": SourceTxt2ImgStage,
    "source-url": SourceURLStage,
    "upscale-bsrgan": UpscaleBSRGANStage,
    "upscale-highres": UpscaleHighresStage,
    "upscale-outpaint": UpscaleOutpaintStage,
    "upscale-resrgan": UpscaleRealESRGANStage,
    "upscale-simple": UpscaleSimpleStage,
    "upscale-stable-diffusion": UpscaleStableDiffusionStage,
    "upscale-swinir": UpscaleSwinIRStage,
}
from .pipeline import ChainPipeline, PipelineStage, StageParams
from .stages import *  # NOQA

@@ -1,240 +1,39 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple
from typing import Optional

from PIL import Image

from ..errors import RetryException
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import needs_tile, process_tile_order

logger = getLogger(__name__)
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
from .result import StageResult


PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]


class ChainProgress:
    def __init__(self, parent: ProgressCallback, start=0) -> None:
        self.parent = parent
        self.step = start
        self.total = 0

    def __call__(self, step: int, timestep: int, latents: Any) -> None:
        if step < self.step:
            # accumulate on resets
            self.total += self.step

        self.step = step
        self.parent(self.get_total(), timestep, latents)

    def get_total(self) -> int:
        return self.step + self.total

    @classmethod
    def from_progress(cls, parent: ProgressCallback):
        start = parent.step if hasattr(parent, "step") else 0
        return ChainProgress(parent, start=start)


class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.
        """
        self.stages = list(stages or [])

    def append(self, stage: Optional[PipelineStage]):
        """
        Append an additional stage to this pipeline.

        This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
        assemble the stage from loose arguments.
        """
        if stage is not None:
            self.stages.append(stage)
class BaseStage:
    max_tile = SizeChart.auto

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback],
        **kwargs
    ) -> List[Image.Image]:
        return self(
            worker, server, params, sources=sources, callback=callback, **kwargs
        )
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        _sources: StageResult,
        *,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> StageResult:
        raise NotImplementedError()  # noqa

    def stage(self, callback: BaseStage, params: StageParams, **kwargs):
        self.stages.append((callback, params, kwargs))
        return self

    def __call__(
    def steps(
        self,
        worker: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs
    ) -> List[Image.Image]:
        """
        DEPRECATED: use `run` instead
        """
        if callback is not None:
            callback = ChainProgress.from_progress(callback)
        _params: ImageParams,
        _size: Size,
    ) -> int:
        return 1  # noqa

        start = monotonic()

        if len(sources) > 0:
            logger.info(
                "running pipeline on %s source images",
                len(sources),
            )
        else:
            sources = [None]
            logger.info("running pipeline without source images")

        stage_sources = sources
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            name = stage_params.name or stage_pipe.__class__.__name__
            kwargs = stage_kwargs or {}
            kwargs = {**pipeline_kwargs, **kwargs}
            logger.debug(
                "running stage %s with %s source images, parameters: %s",
                name,
                len(stage_sources) - stage_sources.count(None),
                kwargs.keys(),
            )

            # the stage must be split and tiled if any image is larger than the selected/max tile size
            must_tile = any(
                [
                    needs_tile(
                        stage_pipe.max_tile,
                        stage_params.tile_size,
                        size=kwargs.get("size", None),
                        source=source,
                    )
                    for source in stage_sources
                ]
            )

            tile = stage_params.tile_size
            if stage_pipe.max_tile > 0:
                tile = min(stage_pipe.max_tile, stage_params.tile_size)

            if must_tile:
                stage_outputs = []
                for source in stage_sources:
                    logger.info(
                        "image larger than tile size of %s, tiling stage",
                        tile,
                    )

                    def stage_tile(
                        source_tile: Image.Image,
                        tile_mask: Image.Image,
                        dims: Tuple[int, int, int],
                    ) -> Image.Image:
                        for i in range(worker.retries):
                            try:
                                output_tile = stage_pipe.run(
                                    worker,
                                    server,
                                    stage_params,
                                    params,
                                    [source_tile],
                                    tile_mask=tile_mask,
                                    callback=callback,
                                    dims=dims,
                                    **kwargs,
                                )[0]

                                if is_debug():
                                    save_image(server, "last-tile.png", output_tile)

                                return output_tile
                            except Exception:
                                logger.exception(
                                    "error while running stage pipeline for tile, retry %s of 3",
                                    i,
                                )
                                server.cache.clear()
                                run_gc([worker.get_device()])
                                worker.retries = worker.retries - (i + 1)

                        raise RetryException("exhausted retries on tile")

                    output = process_tile_order(
                        stage_params.tile_order,
                        source,
                        tile,
                        stage_params.outscale,
                        [stage_tile],
                        **kwargs,
                    )
                    stage_outputs.append(output)

                stage_sources = stage_outputs
            else:
                logger.debug("image within tile size of %s, running stage", tile)
                for i in range(worker.retries):
                    try:
                        stage_outputs = stage_pipe.run(
                            worker,
                            server,
                            stage_params,
                            params,
                            stage_sources,
                            callback=callback,
                            **kwargs,
                        )
                        # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
                        # does not like, so it throws
                        stage_sources = stage_outputs
                        break
                    except Exception:
                        logger.exception(
                            "error while running stage pipeline, retry %s of 3", i
                        )
                        server.cache.clear()
                        run_gc([worker.get_device()])
                        worker.retries = worker.retries - (i + 1)

                if worker.retries <= 0:
                    raise RetryException("exhausted retries on stage")

            logger.debug(
                "finished stage %s with %s results",
                name,
                len(stage_sources),
            )

            if is_debug():
                save_image(server, "last-stage.png", stage_sources[0])

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info(
            "finished pipeline in %s with %s results",
            duration,
            len(stage_sources),
        )
        return stage_sources
    def outputs(
        self,
        _params: ImageParams,
        sources: int,
    ) -> int:
        return sources
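
Editor's note: the new `BaseStage` contract takes a `StageResult` and returns one. A minimal sketch of a custom stage written against this interface, assuming the `onnx_web.chain.base` and `onnx_web.chain.result` import paths from this commit; the `InvertStage` name is hypothetical and not part of the change:

```python
# Hypothetical stage, for illustration only: inverts each source image.
import numpy as np

from onnx_web.chain.base import BaseStage
from onnx_web.chain.result import StageResult


class InvertStage(BaseStage):
    def run(self, _worker, _server, _stage, _params, sources: StageResult, **kwargs) -> StageResult:
        # work in numpy and return arrays; callers convert with as_image() as needed
        return StageResult(arrays=[255 - source for source in sources.as_numpy()])
```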

@@ -0,0 +1,40 @@
from logging import getLogger
from typing import Optional

import cv2
from PIL import Image

from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class BlendDenoiseStage(BaseStage):
    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        strength: int = 3,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        logger.info("denoising source images")

        results = []
        for source in sources.as_numpy():
            data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
            data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength)
            results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))

        return StageResult(arrays=results)
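
Editor's note: the stage round-trips through OpenCV's BGR channel order because `fastNlMeansDenoisingColored` expects BGR input. A small self-contained check of that conversion path, on synthetic data rather than anything from the commit:

```python
# Round-trip sketch: RGB -> BGR -> denoise -> RGB, on synthetic noise.
import cv2
import numpy as np

rgb = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
denoised = cv2.fastNlMeansDenoisingColored(bgr, None, 3, 3)  # strength=3, the stage default
back = cv2.cvtColor(denoised, cv2.COLOR_BGR2RGB)
assert back.shape == rgb.shape
```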

@@ -0,0 +1,62 @@
from logging import getLogger
from typing import List, Optional

from PIL import Image

from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class BlendGridStage(BaseStage):
    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        height: int,
        width: int,
        # rows: Optional[List[str]] = None,
        # columns: Optional[List[str]] = None,
        # title: Optional[str] = None,
        order: Optional[List[int]] = None,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        logger.info("combining source images using grid layout")

        images = sources.as_image()
        ref_image = images[0]
        size = Size(*ref_image.size)

        output = Image.new(ref_image.mode, (size.width * width, size.height * height))

        # TODO: labels
        if order is None:
            order = range(len(images))

        for i in range(len(order)):
            x = i % width
            y = i // width

            n = order[i]
            output.paste(images[n], (x * size.width, y * size.height))

        return StageResult(images=[*images, output])

    def outputs(
        self,
        _params: ImageParams,
        sources: int,
    ) -> int:
        return sources + 1

@@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional

import numpy as np
import torch

@@ -10,13 +10,14 @@ from ..diffusers.utils import encode_prompt, parse_prompt, slice_prompt
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class BlendImg2ImgStage(BaseStage):
    max_tile = SizeChart.unlimited
    max_tile = SizeChart.max

    def run(
        self,

@@ -24,14 +25,14 @@ class BlendImg2ImgStage(BaseStage):
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        strength: float,
        callback: Optional[ProgressCallback] = None,
        stage_source: Optional[Image.Image] = None,
        prompt_index: Optional[int] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        params = params.with_args(**kwargs)

        # multi-stage prompting

@@ -52,7 +53,7 @@ class BlendImg2ImgStage(BaseStage):
            params,
            pipe_type,
            worker.get_device(),
            inversions=inversions,
            embeddings=inversions,
            loras=loras,
        )

@@ -65,7 +66,7 @@ class BlendImg2ImgStage(BaseStage):
            pipe_params["strength"] = strength

        outputs = []
        for source in sources:
        for source in sources.as_image():
            if params.is_lpw():
                logger.debug("using LPW pipeline for img2img")
                rng = torch.manual_seed(params.seed)

@@ -81,11 +82,10 @@ class BlendImg2ImgStage(BaseStage):
                )
            else:
                # encode and record alternative prompts outside of LPW
                if not params.is_xl():
                    prompt_embeds = encode_prompt(
                        pipe, prompt_pairs, params.batch, params.do_cfg()
                    )

                if not params.is_xl():
                    pipe.unet.set_prompts(prompt_embeds)

                rng = np.random.RandomState(params.seed)

@@ -102,4 +102,18 @@ class BlendImg2ImgStage(BaseStage):

            outputs.extend(result.images)

        return outputs
        return StageResult(images=outputs)

    def steps(
        self,
        params: ImageParams,
        *args,
    ) -> int:
        return params.steps  # TODO: multiply by strength

    def outputs(
        self,
        params: ImageParams,
        sources: int,
    ) -> int:
        return sources + 1

@@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional

from PIL import Image

from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -18,13 +19,18 @@ class BlendLinearStage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        alpha: float,
        stage_source: Optional[Image.Image] = None,
        _callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        logger.info("blending source images using linear interpolation")

        return [Image.blend(source, stage_source, alpha) for source in sources]
        return StageResult(
            images=[
                Image.blend(source, stage_source, alpha)
                for source in sources.as_image()
            ]
        )

@@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional

from PIL import Image

@@ -8,7 +8,8 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -20,16 +21,17 @@ class BlendMaskStage(BaseStage):
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        stage_source: Optional[Image.Image] = None,
        stage_mask: Optional[Image.Image] = None,
        _callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        logger.info("blending image using mask")

        mult_mask = Image.new("RGBA", stage_mask.size, color="black")
        # TODO: does this need an alpha channel?
        mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
        mult_mask.alpha_composite(stage_mask)
        mult_mask = mult_mask.convert("L")

@@ -37,4 +39,9 @@ class BlendMaskStage(BaseStage):
            save_image(server, "last-mask.png", stage_mask)
            save_image(server, "last-mult-mask.png", mult_mask)

        return [Image.composite(stage_source, source, mult_mask) for source in sources]
        return StageResult(
            images=[
                Image.composite(stage_source, source, mult_mask)
                for source in sources.as_image()
            ]
        )

@@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional

from PIL import Image

from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -18,12 +19,12 @@ class CorrectCodeformerStage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        stage_source: Optional[Image.Image] = None,
        upscale: UpscaleParams,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        # must be within the load function for patch to take effect
        # TODO: rewrite and remove
        from codeformer import CodeFormer

@@ -32,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):

        device = worker.get_device()
        pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
        return [pipe(source) for source in sources]
        return StageResult(images=[pipe(source) for source in sources.as_image()])

@@ -1,15 +1,15 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional

import numpy as np
from PIL import Image

from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -57,12 +57,12 @@ class CorrectGFPGANStage(BaseStage):
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        upscale = upscale.with_args(**kwargs)

        if upscale.correction_model is None:

@@ -73,16 +73,15 @@ class CorrectGFPGANStage(BaseStage):
        device = worker.get_device()
        gfpgan = self.load(server, stage, upscale, device)

        outputs = []
        for source in sources:
            output = np.array(source)
            _, _, output = gfpgan.enhance(
                output,
        outputs = [
            gfpgan.enhance(
                source,
                has_aligned=False,
                only_center_face=False,
                paste_back=True,
                weight=upscale.face_strength,
            )
            outputs.append(Image.fromarray(output, "RGB"))
            for source in sources.as_numpy()
        ]

        return outputs
        return StageResult(images=outputs)

@@ -1,11 +1,11 @@
from logging import getLogger
from typing import Optional

from ..chain.base import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from .pipeline import ChainPipeline

logger = getLogger(__name__)

@@ -43,7 +43,7 @@ def stage_highres(
            outscale=highres.scale,
        ),
        chain=chain,
        overlap=params.overlap,
        overlap=params.vae_overlap,
    )
else:
    logger.debug("using simple upscaling for highres")

@@ -51,14 +51,14 @@ def stage_highres(
        UpscaleSimpleStage(),
        stage,
        method=highres.method,
        overlap=params.overlap,
        overlap=params.vae_overlap,
        upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale),
    )

chain.stage(
    BlendImg2ImgStage(),
    stage,
    overlap=params.overlap,
    stage.with_args(outscale=1),
    overlap=params.vae_overlap,
    prompt_index=prompt_index + i,
    strength=highres.strength,
)

@@ -1,33 +1,38 @@
from logging import getLogger
from typing import List
from typing import List, Optional

from PIL import Image

from ..output import save_image
from ..params import ImageParams, StageParams
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class PersistDiskStage(BaseStage):
    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        output: str,
        stage_source: Image.Image,
        output: List[str],
        size: Optional[Size] = None,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
        for source in sources:
            # TODO: append index to output name
            dest = save_image(server, output, source, params=params)
    ) -> StageResult:
        logger.info("persisting %s images to disk: %s", len(sources), output)

        for source, name in zip(sources.as_image(), output):
            dest = save_image(server, name, source, params=params, size=size)
            logger.info("saved image to %s", dest)

        return sources

@@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -20,26 +21,26 @@ class PersistS3Stage(BaseStage):
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        output: str,
        output: List[str],
        bucket: str,
        endpoint_url: Optional[str] = None,
        profile_name: Optional[str] = None,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        session = Session(profile_name=profile_name)
        s3 = session.client("s3", endpoint_url=endpoint_url)

        for source in sources:
        for source, name in zip(sources.as_image(), output):
            data = BytesIO()
            source.save(data, format=server.image_format)
            data.seek(0)

            try:
                s3.upload_fileobj(data, bucket, output)
                logger.info("saved image to s3://%s/%s", bucket, output)
                s3.upload_fileobj(data, bucket, name)
                logger.info("saved image to s3://%s/%s", bucket, name)
            except Exception:
                logger.exception("error saving image to S3")

@@ -0,0 +1,281 @@
from datetime import timedelta
from logging import getLogger
from time import monotonic
from typing import Any, List, Optional, Tuple

from PIL import Image

from ..errors import CancelledException, RetryException
from ..output import save_image
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
from .tile import needs_tile, process_tile_order

logger = getLogger(__name__)


PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]


class ChainProgress:
    def __init__(self, parent: ProgressCallback, start=0) -> None:
        self.parent = parent
        self.step = start
        self.total = 0

    def __call__(self, step: int, timestep: int, latents: Any) -> None:
        if step < self.step:
            # accumulate on resets
            self.total += self.step

        self.step = step
        self.parent(self.get_total(), timestep, latents)

    def get_total(self) -> int:
        return self.step + self.total

    @classmethod
    def from_progress(cls, parent: ProgressCallback):
        start = parent.step if hasattr(parent, "step") else 0
        return ChainProgress(parent, start=start)


class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.
        """
        self.stages = list(stages or [])

    def append(self, stage: Optional[PipelineStage]):
        """
        Append an additional stage to this pipeline.

        This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
        assemble the stage from loose arguments.
        """
        if stage is not None:
            self.stages.append(stage)

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: StageResult,
        callback: Optional[ProgressCallback],
        **kwargs,
    ) -> List[Image.Image]:
        result = self(
            worker, server, params, sources=sources, callback=callback, **kwargs
        )
        return result.as_image()

    def stage(self, callback: BaseStage, params: StageParams, **kwargs):
        self.stages.append((callback, params, kwargs))
        return self

    def steps(self, params: ImageParams, size: Size) -> int:
        steps = 0
        for callback, _params, kwargs in self.stages:
            steps += callback.steps(kwargs.get("params", params), size)

        return steps

    def outputs(self, params: ImageParams, sources: int) -> int:
        outputs = sources
        for callback, _params, kwargs in self.stages:
            outputs = callback.outputs(kwargs.get("params", params), outputs)

        return outputs

    def __call__(
        self,
        worker: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: StageResult,
        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs,
    ) -> StageResult:
        """
        DEPRECATED: use `.run()` instead
        """
        if callback is None:
            callback = worker.get_progress_callback()
        else:
            callback = ChainProgress.from_progress(callback)

        start = monotonic()

        if len(sources) > 0:
            logger.info(
                "running pipeline on %s source images",
                len(sources),
            )
        else:
            logger.info("running pipeline without source images")

        stage_sources = sources
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            name = stage_params.name or stage_pipe.__class__.__name__
            kwargs = stage_kwargs or {}
            kwargs = {**pipeline_kwargs, **kwargs}
            logger.debug(
                "running stage %s with %s source images, parameters: %s",
                name,
                len(stage_sources),
                kwargs.keys(),
            )

            per_stage_params = params
            if "params" in kwargs:
                per_stage_params = kwargs["params"]
                kwargs.pop("params")

            # the stage must be split and tiled if any image is larger than the selected/max tile size
            must_tile = has_mask(stage_kwargs) or any(
                [
                    needs_tile(
                        stage_pipe.max_tile,
                        stage_params.tile_size,
                        size=kwargs.get("size", None),
                        source=source,
                    )
                    for source in stage_sources.as_image()
                ]
            )

            tile = stage_params.tile_size
            if stage_pipe.max_tile > 0:
                tile = min(stage_pipe.max_tile, stage_params.tile_size)

            if must_tile:
                logger.info(
                    "image contains sources or is larger than tile size of %s, tiling stage",
                    tile,
                )

                def stage_tile(
                    source_tile: List[Image.Image],
                    tile_mask: Image.Image,
                    dims: Tuple[int, int, int],
                ) -> List[Image.Image]:
                    for _i in range(worker.retries):
                        try:
                            tile_result = stage_pipe.run(
                                worker,
                                server,
                                stage_params,
                                per_stage_params,
                                StageResult(images=source_tile),
                                tile_mask=tile_mask,
                                callback=callback,
                                dims=dims,
                                **kwargs,
                            )

                            if is_debug():
                                for j, image in enumerate(tile_result.as_image()):
                                    save_image(server, f"last-tile-{j}.png", image)

                            return tile_result
                        except CancelledException as err:
                            worker.retries = 0
                            logger.exception("job was cancelled while tiling")
                            raise err
                        except Exception:
                            worker.retries = worker.retries - 1
                            logger.exception(
                                "error while running stage pipeline for tile, %s retries left",
                                worker.retries,
                            )
                            server.cache.clear()
                            run_gc([worker.get_device()])

                    raise RetryException("exhausted retries on tile")

                stage_results = process_tile_order(
                    stage_params.tile_order,
                    stage_sources,
                    tile,
                    stage_params.outscale,
                    [stage_tile],
                    **kwargs,
                )

                stage_sources = StageResult(images=stage_results)
            else:
                logger.debug(
                    "image does not contain sources and is within tile size of %s, running stage",
                    tile,
                )
                for _i in range(worker.retries):
                    try:
                        stage_result = stage_pipe.run(
                            worker,
                            server,
                            stage_params,
                            per_stage_params,
                            stage_sources,
                            callback=callback,
                            dims=(0, 0, tile),
                            **kwargs,
                        )
                        # doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline
                        # does not like, so it throws
                        stage_sources = stage_result
                        break
                    except CancelledException as err:
                        worker.retries = 0
                        logger.exception("job was cancelled during stage")
                        raise err
                    except Exception:
                        worker.retries = worker.retries - 1
                        logger.exception(
                            "error while running stage pipeline, %s retries left",
                            worker.retries,
                        )
                        server.cache.clear()
                        run_gc([worker.get_device()])

                if worker.retries <= 0:
                    raise RetryException("exhausted retries on stage")

            logger.debug(
                "finished stage %s with %s results",
                name,
                len(stage_sources),
            )

            if is_debug():
                for j, image in enumerate(stage_sources.as_image()):
                    save_image(server, f"last-stage-{j}.png", image)

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info(
            "finished pipeline in %s with %s results",
            duration,
            len(stage_sources),
        )
        return stage_sources


MASK_KEYS = ["mask", "stage_mask", "tile_mask"]


def has_mask(args: List[str]) -> bool:
    return any([key in args for key in MASK_KEYS])
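
Editor's note: `ChainPipeline.stage()` assembles stages from loose arguments, and the pipeline can now report `steps` and `outputs` ahead of time. A rough sketch of assembling and inspecting a chain, assuming these import paths and that `worker`, `server`, and `params` come from the surrounding application:

```python
# Sketch: assemble a chain, then ask it for total steps/outputs before running.
from onnx_web.chain.blend_denoise import BlendDenoiseStage
from onnx_web.chain.pipeline import ChainPipeline
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_txt2img import SourceTxt2ImgStage
from onnx_web.params import Size, StageParams

chain = ChainPipeline()
chain.stage(SourceTxt2ImgStage(), StageParams(), size=Size(512, 512))
chain.stage(BlendDenoiseStage(), StageParams(), strength=3)

total_steps = chain.steps(params, Size(512, 512))  # params: ImageParams from the caller
total_outputs = chain.outputs(params, 0)           # starting from zero source images
images = chain.run(worker, server, params, StageResult.empty(), callback=None)
```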

@@ -1,12 +1,13 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional

from PIL import Image

from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -18,20 +19,20 @@ class ReduceCropStage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        origin: Size,
        size: Size,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        outputs = []

        for source in sources:
        for source in sources.as_image():
            image = source.crop((origin.width, origin.height, size.width, size.height))
            logger.info(
                "created thumbnail with dimensions: %sx%s", image.width, image.height
            )
            outputs.append(image)

        return outputs
        return StageResult(images=outputs)

@@ -1,12 +1,12 @@
from logging import getLogger
from typing import List

from PIL import Image

from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -18,15 +18,15 @@ class ReduceThumbnailStage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        size: Size,
        stage_source: Image.Image,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        outputs = []

        for source in sources:
        for source in sources.as_image():
            image = source.copy()

            # Image.thumbnail resizes in place and returns None, so do not reassign
            image.thumbnail((size.width, size.height))

@@ -37,4 +37,4 @@ class ReduceThumbnailStage(BaseStage):

            outputs.append(image)

        return outputs
        return StageResult(images=outputs)

@@ -0,0 +1,73 @@
from typing import List, Optional

import numpy as np
from PIL import Image


class StageResult:
    """
    Chain pipeline stage result.
    Can contain PIL images or numpy arrays, with helpers to convert between them.
    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they expect.
    """

    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]

    @staticmethod
    def empty():
        return StageResult(images=[])

    @staticmethod
    def from_arrays(arrays: List[np.ndarray]):
        return StageResult(arrays=arrays)

    @staticmethod
    def from_images(images: List[Image.Image]):
        return StageResult(images=images)

    def __init__(self, arrays=None, images=None) -> None:
        if arrays is not None and images is not None:
            raise ValueError("stages must only return one type of result")
        elif arrays is None and images is None:
            raise ValueError("stages must return results")

        self.arrays = arrays
        self.images = images

    def __len__(self) -> int:
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            return 0

    def as_numpy(self) -> List[np.ndarray]:
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_image(self) -> List[Image.Image]:
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []


def shape_mode(arr: np.ndarray) -> str:
    if len(arr.shape) != 3:
        raise ValueError("unknown array format")

    if arr.shape[-1] == 3:
        return "RGB"
    elif arr.shape[-1] == 4:
        return "RGBA"

    raise ValueError("unknown image format")
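
Editor's note: a short sketch of the conversion helpers, since `StageResult` holds exactly one representation and converts on demand (synthetic data, not from the commit):

```python
import numpy as np

from onnx_web.chain.result import StageResult

# one RGB array in, PIL images out
result = StageResult.from_arrays([np.zeros((8, 8, 3), dtype=np.uint8)])
images = result.as_image()  # converted via Image.fromarray with mode "RGB"
assert len(result) == 1

# constructing with both representations is rejected
try:
    StageResult(arrays=[], images=[])
except ValueError as err:
    print(err)  # stages must only return one type of result
```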

@@ -1,12 +1,13 @@
from logging import getLogger
from typing import Callable, List
from typing import Callable, Optional

from PIL import Image

from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -18,25 +19,34 @@ class SourceNoiseStage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        size: Size,
        noise_source: Callable,
        stage_source: Image.Image,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        logger.info("generating image from noise source")

        if len(sources) > 0:
            logger.warning(
                "source images were passed to a noise stage and will be discarded"
            logger.info(
                "source images were passed to a source stage, new images will be appended"
            )

        outputs = []
        for source in sources:

        # TODO: looping over sources and ignoring params does not make much sense for a source stage
        for source in sources.as_image():
            output = noise_source(source, (size.width, size.height), (0, 0))

            logger.info("final output image size: %sx%s", output.width, output.height)
            outputs.append(output)

        return outputs
        return StageResult(images=outputs)

    def outputs(
        self,
        params: ImageParams,
        sources: int,
    ) -> int:
        return sources + 1

@@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -20,18 +21,23 @@ class SourceS3Stage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        _sources: List[Image.Image],
        sources: StageResult,
        *,
        source_keys: List[str],
        bucket: str,
        endpoint_url: Optional[str] = None,
        profile_name: Optional[str] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        session = Session(profile_name=profile_name)
        s3 = session.client("s3", endpoint_url=endpoint_url)

        outputs = []
        if len(sources) > 0:
            logger.info(
                "source images were passed to a source stage, new images will be appended"
            )

        outputs = sources.as_image()
        for key in source_keys:
            try:
                logger.info("loading image from s3://%s/%s", bucket, key)

@@ -43,4 +49,11 @@ class SourceS3Stage(BaseStage):
            except Exception:
                logger.exception("error loading image from S3")

        return outputs
        return StageResult(images=outputs)

    def outputs(
        self,
        params: ImageParams,
        sources: int,
    ) -> int:
        return sources + 1  # TODO: len(source_keys)

@@ -3,26 +3,28 @@ from typing import Optional, Tuple

import numpy as np
import torch
from PIL import Image

from ..constants import LATENT_FACTOR
from ..diffusers.load import load_pipeline
from ..diffusers.utils import (
    encode_prompt,
    get_latents_from_seed,
    get_tile_latents,
    parse_prompt,
    parse_reseed,
    slice_prompt,
)
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class SourceTxt2ImgStage(BaseStage):
    max_tile = SizeChart.unlimited
    max_tile = SizeChart.max

    def run(
        self,

@@ -30,15 +32,15 @@ class SourceTxt2ImgStage(BaseStage):
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        _source: Image.Image,
        sources: StageResult,
        *,
        dims: Tuple[int, int, int],
        dims: Tuple[int, int, int] = None,
        size: Size,
        callback: Optional[ProgressCallback] = None,
        latents: Optional[np.ndarray] = None,
        prompt_index: Optional[int] = None,
        **kwargs,
    ) -> Image.Image:
    ) -> StageResult:
        params = params.with_args(**kwargs)
        size = size.with_args(**kwargs)

@@ -47,31 +49,58 @@ class SourceTxt2ImgStage(BaseStage):
        params = params.with_args(prompt=slice_prompt(params.prompt, prompt_index))

        logger.info(
            "generating image using txt2img, %s steps: %s", params.steps, params.prompt
            "generating image using txt2img, %s steps of %s: %s",
            params.steps,
            params.model,
            params.prompt,
        )

        if "stage_source" in kwargs:
            logger.warning(
                "a source image was passed to a txt2img stage, and will be discarded"
        if len(sources):
            logger.info(
                "source images were passed to a source stage, new images will be appended"
            )

        prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
            params
        )

        if params.is_xl():
            tile_size = max(stage.tile_size, params.tiles)
        if params.is_panorama() or params.is_xl():
            tile_size = max(stage.tile_size, params.unet_tile)
        else:
            tile_size = params.tiles
            tile_size = params.unet_tile

        # this works for panorama as well, because tile_size is already max(tile_size, *size)
        latent_size = size.min(tile_size, tile_size)

        # generate new latents or slice existing
        if latents is None:
            latents = get_latents_from_seed(params.seed, latent_size, params.batch)
            latents = get_latents_from_seed(int(params.seed), latent_size, params.batch)
        else:
            latents = get_tile_latents(latents, params.seed, latent_size, dims)
            latents = get_tile_latents(latents, int(params.seed), latent_size, dims)

        # reseed latents as needed
        reseed_rng = np.random.RandomState(params.seed)
        prompt, reseed = parse_reseed(prompt)
        for top, left, bottom, right, region_seed in reseed:
            if region_seed == -1:
                region_seed = reseed_rng.random_integers(2**32 - 1)

            logger.debug(
                "reseed latent region: [:, :, %s:%s, %s:%s] with %s",
                top,
                left,
                bottom,
                right,
                region_seed,
            )
            latents[
                :,
                :,
                top // LATENT_FACTOR : bottom // LATENT_FACTOR,
                left // LATENT_FACTOR : right // LATENT_FACTOR,
            ] = get_latents_from_seed(
                region_seed, Size(right - left, bottom - top), params.batch
            )

        pipe_type = params.get_valid_pipeline("txt2img")
        pipe = load_pipeline(

@@ -79,7 +108,7 @@ class SourceTxt2ImgStage(BaseStage):
            params,
            pipe_type,
            worker.get_device(),
            inversions=inversions,
            embeddings=inversions,
            loras=loras,
        )

@@ -101,11 +130,14 @@ class SourceTxt2ImgStage(BaseStage):
            )
        else:
            # encode and record alternative prompts outside of LPW
            if params.is_panorama() or params.is_xl():
                logger.debug(
                    "prompt alternatives are not supported for panorama or SDXL"
                )
            else:
                prompt_embeds = encode_prompt(
                    pipe, prompt_pairs, params.batch, params.do_cfg()
                )

            if not params.is_xl():
                pipe.unet.set_prompts(prompt_embeds)

            rng = np.random.RandomState(params.seed)

@@ -123,4 +155,21 @@ class SourceTxt2ImgStage(BaseStage):
            callback=callback,
        )

        return result.images
        outputs = sources.as_image()
        outputs.extend(result.images)
        logger.debug("produced %s outputs", len(outputs))
        return StageResult(images=outputs)

    def steps(
        self,
        params: ImageParams,
        size: Size,
    ) -> int:
        return params.steps

    def outputs(
        self,
        params: ImageParams,
        sources: int,
    ) -> int:
        return sources + 1

@@ -1,6 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import List
from typing import List, Optional

import requests
from PIL import Image

@@ -8,7 +8,8 @@ from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import WorkerContext
from .stage import BaseStage
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

@@ -20,20 +21,20 @@ class SourceURLStage(BaseStage):
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        sources: StageResult,
        *,
        source_urls: List[str],
        stage_source: Image.Image,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
    ) -> StageResult:
        logger.info("loading image from URL source")

        if len(sources) > 0:
            logger.warning(
                "a source image was passed to a source stage, and will be discarded"
            logger.info(
                "source images were passed to a source stage, new images will be appended"
            )

        outputs = []
        outputs = sources.as_image()
        for url in source_urls:
            response = requests.get(url)
            output = Image.open(BytesIO(response.content))

@@ -41,4 +42,11 @@ class SourceURLStage(BaseStage):
            logger.info("final output image size: %sx%s", output.width, output.height)
            outputs.append(output)

        return outputs
        return StageResult(images=outputs)

    def outputs(
        self,
        params: ImageParams,
        sources: int,
    ) -> int:
        return sources + 1

@@ -1,31 +0,0 @@
from typing import List, Optional

from PIL import Image

from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext


class BaseStage:
    max_tile = SizeChart.auto

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        *args,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
        raise NotImplementedError()

    def steps(
        self,
        _params: ImageParams,
        size: Size,
    ) -> int:
        raise NotImplementedError()
@@ -0,0 +1,64 @@
+from logging import getLogger
+
+from .base import BaseStage
+from .blend_denoise import BlendDenoiseStage
+from .blend_grid import BlendGridStage
+from .blend_img2img import BlendImg2ImgStage
+from .blend_linear import BlendLinearStage
+from .blend_mask import BlendMaskStage
+from .correct_codeformer import CorrectCodeformerStage
+from .correct_gfpgan import CorrectGFPGANStage
+from .persist_disk import PersistDiskStage
+from .persist_s3 import PersistS3Stage
+from .reduce_crop import ReduceCropStage
+from .reduce_thumbnail import ReduceThumbnailStage
+from .source_noise import SourceNoiseStage
+from .source_s3 import SourceS3Stage
+from .source_txt2img import SourceTxt2ImgStage
+from .source_url import SourceURLStage
+from .upscale_bsrgan import UpscaleBSRGANStage
+from .upscale_highres import UpscaleHighresStage
+from .upscale_outpaint import UpscaleOutpaintStage
+from .upscale_resrgan import UpscaleRealESRGANStage
+from .upscale_simple import UpscaleSimpleStage
+from .upscale_stable_diffusion import UpscaleStableDiffusionStage
+from .upscale_swinir import UpscaleSwinIRStage
+
+logger = getLogger(__name__)
+
+CHAIN_STAGES = {
+    "blend-denoise": BlendDenoiseStage,
+    "blend-img2img": BlendImg2ImgStage,
+    "blend-inpaint": UpscaleOutpaintStage,
+    "blend-grid": BlendGridStage,
+    "blend-linear": BlendLinearStage,
+    "blend-mask": BlendMaskStage,
+    "correct-codeformer": CorrectCodeformerStage,
+    "correct-gfpgan": CorrectGFPGANStage,
+    "persist-disk": PersistDiskStage,
+    "persist-s3": PersistS3Stage,
+    "reduce-crop": ReduceCropStage,
+    "reduce-thumbnail": ReduceThumbnailStage,
+    "source-noise": SourceNoiseStage,
+    "source-s3": SourceS3Stage,
+    "source-txt2img": SourceTxt2ImgStage,
+    "source-url": SourceURLStage,
+    "upscale-bsrgan": UpscaleBSRGANStage,
+    "upscale-highres": UpscaleHighresStage,
+    "upscale-outpaint": UpscaleOutpaintStage,
+    "upscale-resrgan": UpscaleRealESRGANStage,
+    "upscale-simple": UpscaleSimpleStage,
+    "upscale-stable-diffusion": UpscaleStableDiffusionStage,
+    "upscale-swinir": UpscaleSwinIRStage,
+}
+
+
+def add_stage(name: str, stage: BaseStage) -> bool:
+    global CHAIN_STAGES
+
+    if name in CHAIN_STAGES:
+        logger.warning("cannot replace stage: %s", name)
+        return False
+    else:
+        CHAIN_STAGES[name] = stage
+        return True
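The registry makes the chain pipeline extensible at runtime: `add_stage` refuses to replace an existing name and returns `False`. A hedged sketch of registering a custom stage; `WatermarkStage` and `apply_watermark` are hypothetical names, with the `run` signature mirroring the other stages in this diff:

```python
class WatermarkStage(BaseStage):
    def run(self, worker, server, stage, params, sources, **kwargs) -> StageResult:
        # apply_watermark is an assumed helper, applied one image at a time
        outputs = [apply_watermark(image) for image in sources.as_image()]
        return StageResult(images=outputs)


# True on success, False (with a warning) if the name is already taken
add_stage("blend-watermark", WatermarkStage)
```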
@@ -2,13 +2,14 @@ import itertools
 from enum import Enum
 from logging import getLogger
 from math import ceil
-from typing import List, Optional, Protocol, Tuple
+from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 from PIL import Image
 
 from ..image.noise_source import noise_source_histogram
 from ..params import Size, TileOrder
+from .result import StageResult
 
 # from skimage.exposure import match_histograms
 
@@ -16,12 +17,15 @@ from ..params import Size, TileOrder
 logger = getLogger(__name__)
 
 
+TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
+
+
 class TileCallback(Protocol):
     """
     Definition for a tile job function.
     """
 
-    def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
+    def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
         """
         Run this stage against a single tile.
         """
@@ -32,6 +36,9 @@ def complete_tile(
     source: Image.Image,
     tile: int,
 ) -> Image.Image:
+    """
+    TODO: clean up
+    """
     if source is None:
         return source
 
@@ -50,6 +57,12 @@ def needs_tile(
     source: Optional[Image.Image] = None,
 ) -> bool:
     tile = min(max_tile, stage_tile)
+    logger.trace(
+        "checking image tile dimensions: %s, %s, %s",
+        tile,
+        source.width > tile or source.height > tile if source is not None else False,
+        size.width > tile or size.height > tile if size is not None else False,
+    )
 
     if source is not None:
         return source.width > tile or source.height > tile
 
@@ -60,7 +73,7 @@ def needs_tile(
     return False
 
 
-def get_tile_grads(
+def make_tile_grads(
     left: int,
     top: int,
     tile: int,
@@ -85,6 +98,60 @@ def get_tile_grads(
     return (grad_x, grad_y)
 
 
+def make_tile_mask(
+    shape: Any,
+    tile: Tuple[int, int],
+    overlap: float,
+    edges: Tuple[bool, bool, bool, bool],
+) -> np.ndarray:
+    mask = np.ones(shape)
+
+    tile_h, tile_w = tile
+
+    adj_tile_h = int(float(tile_h) * (1.0 - overlap))
+    adj_tile_w = int(float(tile_w) * (1.0 - overlap))
+
+    # sort gradient points
+    p1_h = adj_tile_h - 1
+    p2_h = tile_h - adj_tile_h
+    points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]
+
+    p1_w = adj_tile_w - 1
+    p2_w = tile_w - adj_tile_w
+    points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]
+
+    # build gradients
+    edge_t, edge_l, edge_b, edge_r = edges
+    grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
+        int(not edge_t),
+        1,
+        1,
+        int(not edge_b),
+    ]
+    logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
+
+    mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
+    mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)]
+
+    mask = ((mask * mult_x).T * mult_y).T
+
+    return mask
+
+
+def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
+    if isinstance(image, np.ndarray):
+        return image.shape[-1]
+
+    if image.mode == "RGBA":
+        return 4
+    elif image.mode == "RGB":
+        return 3
+    elif image.mode == "L":
+        return 1
+
+    raise ValueError("unknown image format")
+
+
 def blend_tiles(
     tiles: List[Tuple[int, int, Image.Image]],
     scale: int,
@@ -98,23 +165,24 @@ def blend_tiles(
         "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
     )
 
-    scaled_size = (height * scale, width * scale, 3)
+    channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
+    scaled_size = (height * scale, width * scale, channels)
 
     count = np.zeros(scaled_size)
     value = np.zeros(scaled_size)
 
     for left, top, tile_image in tiles:
         # histogram equalization
        equalized = np.array(tile_image).astype(np.float32)
         mask = np.ones_like(equalized[:, :, 0])
 
         if adj_tile < tile:
             # sort gradient points
-            p1 = adj_tile * scale
-            p2 = (tile - adj_tile) * scale
-            points = [0, min(p1, p2), max(p1, p2), tile * scale]
+            p1 = (adj_tile * scale) - 1
+            p2 = (tile - adj_tile - 1) * scale
+            points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
 
             # gradient blending
-            grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
+            grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
             logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
 
             mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
@@ -169,7 +237,7 @@ def blend_tiles(
                 margin_left : equalized.shape[1] + margin_right,
                 np.newaxis,
             ],
-            3,
+            channels,
             axis=2,
         )
 
@@ -178,60 +246,18 @@ def blend_tiles(
     return Image.fromarray(np.uint8(pixels))
 
 
-def process_tile_grid(
-    source: Image.Image,
-    tile: int,
-    scale: int,
-    filters: List[TileCallback],
-    overlap: float = 0.0,
-    **kwargs,
-) -> Image.Image:
-    width, height = kwargs.get("size", source.size if source else None)
-
-    adj_tile = int(float(tile) * (1.0 - overlap))
-    tiles_x = ceil(width / adj_tile)
-    tiles_y = ceil(height / adj_tile)
-    total = tiles_x * tiles_y
-    logger.debug(
-        "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
-        total,
-        tiles_x,
-        tiles_y,
-        adj_tile,
-        overlap,
-    )
-
-    tiles: List[Tuple[int, int, Image.Image]] = []
-
-    for y in range(tiles_y):
-        for x in range(tiles_x):
-            idx = (y * tiles_x) + x
-            left = x * adj_tile
-            top = y * adj_tile
-            logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
-
-            tile_image = (
-                source.crop((left, top, left + tile, top + tile)) if source else None
-            )
-            tile_image = complete_tile(tile_image, tile)
-
-            for filter in filters:
-                tile_image = filter(tile_image, (left, top, tile))
-
-            tiles.append((left, top, tile_image))
-
-    return blend_tiles(tiles, scale, width, height, tile, overlap)
-
-
-def process_tile_spiral(
-    source: Image.Image,
+def process_tile_stack(
+    stack: StageResult,
     tile: int,
     scale: int,
     filters: List[TileCallback],
+    tile_generator: TileGenerator,
     overlap: float = 0.5,
     **kwargs,
-) -> Image.Image:
-    width, height = kwargs.get("size", source.size if source else None)
+) -> List[Image.Image]:
+    sources = stack.as_image()
+
+    width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
     mask = kwargs.get("mask", None)
     noise_source = kwargs.get("noise_source", noise_source_histogram)
     fill_color = kwargs.get("fill_color", None)
@@ -239,18 +265,10 @@ def process_tile_spiral(
     tile_mask = None
 
     tiles: List[Tuple[int, int, Image.Image]] = []
+    tile_coords = tile_generator(width, height, tile, overlap)
+    single_tile = len(tile_coords) == 1
 
     # tile tuples is source, multiply by scale for dest
-    counter = 0
-    tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
-
-    if len(tile_coords) == 1:
-        single_tile = True
-    else:
-        single_tile = False
-
-    for left, top in tile_coords:
-        counter += 1
+    for counter, (left, top) in enumerate(tile_coords):
         logger.info(
             "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
         )
@@ -274,37 +292,33 @@ def process_tile_spiral(
             needs_margin = True
             bottom_margin = height - bottom
 
-        # if no source given, we don't have a source image
-        if not source:
-            tile_image = None
-        elif needs_margin:
         # in the special case where the image is smaller than the specified tile size, just use the image
         if single_tile:
-            logger.debug("creating and processing single-tile subtile")
-            tile_image = source
+            logger.debug("using single tile")
+            tile_stack = sources
             if mask:
                 tile_mask = mask
         # otherwise use add histogram noise outside of the image border
-        else:
+        elif needs_margin:
             logger.debug(
-                "tiling and adding margins: %s, %s, %s, %s",
+                "tiling with added margins: %s, %s, %s, %s",
                 left_margin,
                 top_margin,
                 right_margin,
                 bottom_margin,
             )
-            base_image = source.crop(
-                (
-                    left + left_margin,
-                    top + top_margin,
-                    right + right_margin,
-                    bottom + bottom_margin,
-                )
-            )
-            tile_image = noise_source(
-                base_image, (tile, tile), (0, 0), fill=fill_color
-            )
-            tile_image.paste(base_image, (left_margin, top_margin))
+            tile_stack = add_margin(
+                stack.as_image(),
+                left,
+                top,
+                right,
+                bottom,
+                left_margin,
+                top_margin,
+                right_margin,
+                bottom_margin,
+                tile,
+                noise_source,
+                fill_color,
+            )
 
             if mask:
                 base_mask = mask.crop(
@@ -320,38 +334,55 @@ def process_tile_spiral(
 
         else:
             logger.debug("tiling normally")
-            tile_image = source.crop((left, top, right, bottom))
+            tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
             if mask:
                 tile_mask = mask.crop((left, top, right, bottom))
 
         for image_filter in filters:
-            tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
+            tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
 
-        tiles.append((left, top, tile_image))
+        if isinstance(tile_stack, list):
+            tile_stack = StageResult.from_images(tile_stack)
 
-    if single_tile:
-        return tile_image
-    else:
-        return blend_tiles(tiles, scale, width, height, tile, overlap)
+        tiles.append((left, top, tile_stack.as_image()))
+
+    lefts, tops, stacks = list(zip(*tiles))
+    coords = list(zip(lefts, tops))
+    stacks = list(zip(*stacks))
+
+    result = []
+    for stack in stacks:
+        stack_tiles = zip(coords, stack)
+        stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
+        result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
+
+    return result
 
 
 def process_tile_order(
     order: TileOrder,
-    source: Image.Image,
+    stack: StageResult,
     tile: int,
     scale: int,
     filters: List[TileCallback],
     **kwargs,
 ) -> Image.Image:
+    """
+    TODO: needs to handle more than one image
+    """
     if order == TileOrder.grid:
         logger.debug("using grid tile order with tile size: %s", tile)
-        return process_tile_grid(source, tile, scale, filters, **kwargs)
+        return process_tile_stack(
+            stack, tile, scale, filters, generate_tile_grid, **kwargs
+        )
     elif order == TileOrder.kernel:
         logger.debug("using kernel tile order with tile size: %s", tile)
         raise NotImplementedError()
     elif order == TileOrder.spiral:
         logger.debug("using spiral tile order with tile size: %s", tile)
-        return process_tile_spiral(source, tile, scale, filters, **kwargs)
+        return process_tile_stack(
+            stack, tile, scale, filters, generate_tile_spiral, **kwargs
+        )
     else:
         logger.warning("unknown tile order: %s", order)
         raise ValueError()
@@ -445,3 +476,77 @@ def generate_tile_spiral(
         height_tile_target -= abs(state.value[1])
 
     return tile_coords
+
+
+def generate_tile_grid(
+    width: int,
+    height: int,
+    tile: int,
+    overlap: float = 0.0,
+) -> List[Tuple[int, int]]:
+    adj_tile = int(float(tile) * (1.0 - overlap))
+    tiles_x = ceil(width / adj_tile)
+    tiles_y = ceil(height / adj_tile)
+    total = tiles_x * tiles_y
+    logger.debug(
+        "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
+        total,
+        tiles_x,
+        tiles_y,
+        adj_tile,
+        overlap,
+    )
+
+    tiles: List[Tuple[int, int, Image.Image]] = []
+
+    for y in range(tiles_y):
+        for x in range(tiles_x):
+            left = x * adj_tile
+            top = y * adj_tile
+
+            tiles.append((int(left), int(top)))
+
+    return tiles
+
+
+def get_result_tile(
+    result: StageResult,
+    origin: Tuple[int, int],
+    tile: Size,
+) -> List[Image.Image]:
+    top, left = origin
+    return [
+        layer.crop((top, left, top + tile.height, left + tile.width))
+        for layer in result.as_image()
+    ]
+
+
+def add_margin(
+    stack: List[Image.Image],
+    left: int,
+    top: int,
+    right: int,
+    bottom: int,
+    left_margin: int,
+    top_margin: int,
+    right_margin: int,
+    bottom_margin: int,
+    tile: int,
+    noise_source,
+    fill_color,
+) -> List[Image.Image]:
+    results = []
+    for source in stack:
+        base_image = source.crop(
+            (
+                left + left_margin,
+                top + top_margin,
+                right + right_margin,
+                bottom + bottom_margin,
+            )
+        )
+        tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
+        tile_image.paste(base_image, (left_margin, top_margin))
+        results.append(tile_image)
+
+    return results
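`make_tile_mask` above produces a per-tile weight ramp: sides flagged in `edges` fade toward zero across the overlap region, while unflagged sides keep full weight, so overlapping tiles sum back to roughly one. A small sketch of the expected behavior, assuming the edge order `(top, left, bottom, right)` from the unpacking in that function:

```python
# 64x64 tile with 25% overlap; the top and left sides overlap neighboring tiles
mask = make_tile_mask((64, 64), (64, 64), 0.25, (True, True, False, False))

assert mask.shape == (64, 64)
assert mask[-1, -1] == 1.0  # bottom-right corner keeps full weight
assert mask[0, 0] < 1.0     # top-left corner ramps down across the overlap
```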
@@ -1,22 +1,30 @@
 from logging import getLogger
 from os import path
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
 from PIL import Image
 
 from ..models.onnx import OnnxModel
-from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
+from ..params import (
+    DeviceParams,
+    ImageParams,
+    Size,
+    SizeChart,
+    StageParams,
+    UpscaleParams,
+)
 from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
 
 class UpscaleBSRGANStage(BaseStage):
-    max_tile = 64
+    max_tile = SizeChart.micro
 
     def load(
         self,
@@ -54,12 +62,12 @@ class UpscaleBSRGANStage(BaseStage):
         server: ServerContext,
         stage: StageParams,
         _params: ImageParams,
-        sources: List[Image.Image],
+        sources: StageResult,
         *,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         upscale = upscale.with_args(**kwargs)
 
         if upscale.upscale_model is None:
@@ -71,40 +79,38 @@ class UpscaleBSRGANStage(BaseStage):
         bsrgan = self.load(server, stage, upscale, device)
 
         outputs = []
-        for source in sources:
-            image = np.array(source) / 255.0
+        for source in sources.as_numpy():
+            image = source / 255.0
             image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
             image = np.expand_dims(image, axis=0)
             logger.trace("BSRGAN input shape: %s", image.shape)
 
             scale = upscale.outscale
-            dest = np.zeros(
+            logger.trace(
+                "BSRGAN output shape: %s",
                 (
                     image.shape[0],
                     image.shape[1],
                     image.shape[2] * scale,
                     image.shape[3] * scale,
-                )
+                ),
             )
-            logger.trace("BSRGAN output shape: %s", dest.shape)
 
-            dest = bsrgan(image)
+            output = bsrgan(image)
 
-            dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
-            dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
-            dest = (dest * 255.0).round().astype(np.uint8)
-
-            output = Image.fromarray(dest, "RGB")
-            logger.debug("output image size: %s x %s", output.width, output.height)
+            output = np.clip(np.squeeze(output, axis=0), 0, 1)
+            output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
+            output = (output * 255.0).round().astype(np.uint8)
+
+            logger.debug("output image shape: %s", output.shape)
             outputs.append(output)
 
-        return outputs
+        return StageResult(arrays=outputs)
 
     def steps(
         self,
         params: ImageParams,
         size: Size,
     ) -> int:
-        tile = min(params.tiles, self.max_tile)
+        tile = min(params.unet_tile, self.max_tile)
        return size.width // tile * size.height // tile
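The `steps` estimate is simply the number of tiles at the effective tile size. As a worked example, assuming `SizeChart.micro` keeps the old literal value of 64 that it replaces here: a 256x256 source with a `unet_tile` of 512 uses `tile = min(512, 64) = 64`, so the stage reports `256 // 64 * 256 // 64 = 16` tiles.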
@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import List, Optional
+from typing import Optional
 
 from PIL import Image
 
@@ -8,7 +8,8 @@ from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..worker import WorkerContext
 from ..worker.context import ProgressCallback
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
@@ -20,20 +21,20 @@ class UpscaleHighresStage(BaseStage):
         server: ServerContext,
         stage: StageParams,
         params: ImageParams,
-        sources: List[Image.Image],
-        *args,
+        sources: StageResult,
+        *,
         highres: HighresParams,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
         callback: Optional[ProgressCallback] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         if highres.scale <= 1:
             return sources
 
         chain = stage_highres(stage, params, highres, upscale)
 
-        return [
+        outputs = [
             chain(
                 worker,
                 server,
@@ -41,5 +42,7 @@ class UpscaleHighresStage(BaseStage):
                 source,
                 callback=callback,
             )
-            for source in sources
+            for source in sources.as_image()
         ]
+
+        return StageResult(images=outputs)

@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Optional, Tuple
 
 import numpy as np
 import torch
@@ -18,13 +18,14 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
 from ..server import ServerContext
 from ..utils import is_debug
 from ..worker import ProgressCallback, WorkerContext
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
 
 class UpscaleOutpaintStage(BaseStage):
-    max_tile = SizeChart.unlimited
+    max_tile = SizeChart.max
 
     def run(
         self,
@@ -32,7 +33,7 @@ class UpscaleOutpaintStage(BaseStage):
         server: ServerContext,
         stage: StageParams,
         params: ImageParams,
-        sources: List[Image.Image],
+        sources: StageResult,
         *,
         border: Border,
         dims: Tuple[int, int, int],
@@ -45,7 +46,7 @@ class UpscaleOutpaintStage(BaseStage):
         stage_source: Optional[Image.Image] = None,
         stage_mask: Optional[Image.Image] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
             params
         )
@@ -56,12 +57,12 @@ class UpscaleOutpaintStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )
 
         outputs = []
-        for source in sources:
+        for source in sources.as_image():
             if is_debug():
                 save_image(server, "tile-source.png", source)
                 save_image(server, "tile-mask.png", tile_mask)
@@ -71,7 +72,7 @@ class UpscaleOutpaintStage(BaseStage):
                 outputs.append(source)
                 continue
 
-            tile_size = params.tiles
+            tile_size = params.unet_tile
             size = Size(*source.size)
             latent_size = size.min(tile_size, tile_size)
 
@@ -99,6 +100,7 @@ class UpscaleOutpaintStage(BaseStage):
                 )
             else:
+                # encode and record alternative prompts outside of LPW
                 if not params.is_xl():
                     prompt_embeds = encode_prompt(
                         pipe, prompt_pairs, params.batch, params.do_cfg()
                     )
@@ -121,4 +123,4 @@ class UpscaleOutpaintStage(BaseStage):
 
             outputs.extend(result.images)
 
-        return outputs
+        return StageResult(images=outputs)

@@ -1,8 +1,7 @@
 from logging import getLogger
 from os import path
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
-from PIL import Image
 
 from ..onnx import OnnxRRDBNet
@@ -10,7 +9,8 @@ from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
 from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
@@ -77,25 +77,22 @@ class UpscaleRealESRGANStage(BaseStage):
         server: ServerContext,
         stage: StageParams,
         _params: ImageParams,
-        sources: List[Image.Image],
+        sources: StageResult,
         *,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
 
-        outputs = []
-        for source in sources:
-            output = np.array(source)
-            upsampler = self.load(
-                server, upscale, worker.get_device(), tile=stage.tile_size
-            )
-
-            output, _ = upsampler.enhance(output, outscale=upscale.outscale)
-
-            output = Image.fromarray(output, "RGB")
-            logger.info("final output image size: %sx%s", output.width, output.height)
+        upsampler = self.load(
+            server, upscale, worker.get_device(), tile=stage.tile_size
+        )
+
+        outputs = []
+        for source in sources.as_numpy():
+            output, _ = upsampler.enhance(source, outscale=upscale.outscale)
+            logger.info("final output image size: %s", output.shape)
             outputs.append(output)
 
-        return outputs
+        return StageResult(arrays=outputs)

@@ -1,12 +1,13 @@
 from logging import getLogger
-from typing import List, Optional
+from typing import Optional
 
 from PIL import Image
 
 from ..params import ImageParams, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..worker import WorkerContext
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
@@ -18,13 +19,13 @@ class UpscaleSimpleStage(BaseStage):
         _server: ServerContext,
         _stage: StageParams,
         _params: ImageParams,
-        sources: List[Image.Image],
+        sources: StageResult,
         *,
         method: str,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         if upscale.scale <= 1:
             logger.debug(
                 "simple upscale stage run with scale of %s, skipping", upscale.scale
@@ -32,18 +33,20 @@ class UpscaleSimpleStage(BaseStage):
             return sources
 
         outputs = []
-        for source in sources:
+        for source in sources.as_image():
             scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
 
             if method == "bilinear":
                 logger.debug("using bilinear interpolation for highres")
-                source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
+                outputs.append(
+                    source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
+                )
             elif method == "lanczos":
                 logger.debug("using Lanczos interpolation for highres")
-                source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
+                outputs.append(
+                    source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
+                )
             else:
                 logger.warning("unknown upscaling method: %s", method)
                 outputs.append(source)
 
-            outputs.append(source)
-
-        return outputs
+        return StageResult(images=outputs)

@@ -1,8 +1,8 @@
 from logging import getLogger
 from os import path
-from typing import List, Optional
+from typing import Optional
 
-import torch
+import numpy as np
 from PIL import Image
 
 from ..diffusers.load import load_pipeline
@@ -10,7 +10,8 @@ from ..diffusers.utils import encode_prompt, parse_prompt
 from ..params import ImageParams, StageParams, UpscaleParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
@@ -22,13 +23,13 @@ class UpscaleStableDiffusionStage(BaseStage):
         server: ServerContext,
         _stage: StageParams,
         params: ImageParams,
-        sources: List[Image.Image],
+        sources: StageResult,
         *,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
         callback: Optional[ProgressCallback] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         params = params.with_args(**kwargs)
         upscale = upscale.with_args(**kwargs)
         logger.info(
@@ -46,8 +47,9 @@ class UpscaleStableDiffusionStage(BaseStage):
             worker.get_device(),
             model=path.join(server.model_path, upscale.upscale_model),
         )
-        generator = torch.manual_seed(params.seed)
+        rng = np.random.RandomState(params.seed)
 
+        if not params.is_xl():
             prompt_embeds = encode_prompt(
                 pipeline,
                 prompt_pairs,
@@ -57,11 +59,11 @@ class UpscaleStableDiffusionStage(BaseStage):
             pipeline.unet.set_prompts(prompt_embeds)
 
         outputs = []
-        for source in sources:
+        for source in sources.as_image():
             result = pipeline(
                 prompt,
                 source,
-                generator=generator,
+                generator=rng,
                 guidance_scale=params.cfg,
                 negative_prompt=negative_prompt,
                 num_inference_steps=params.steps,
@@ -71,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
             )
             outputs.extend(result.images)
 
-        return outputs
+        return StageResult(images=outputs)

@@ -1,22 +1,23 @@
 from logging import getLogger
 from os import path
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
 from PIL import Image
 
 from ..models.onnx import OnnxModel
-from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
 from ..server import ModelTypes, ServerContext
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .stage import BaseStage
+from .base import BaseStage
+from .result import StageResult
 
 logger = getLogger(__name__)
 
 
 class UpscaleSwinIRStage(BaseStage):
-    max_tile = 64
+    max_tile = SizeChart.micro
 
     def load(
         self,
@@ -54,12 +55,12 @@ class UpscaleSwinIRStage(BaseStage):
         server: ServerContext,
         stage: StageParams,
         _params: ImageParams,
-        sources: List[Image.Image],
+        sources: StageResult,
         *,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
         **kwargs,
-    ) -> List[Image.Image]:
+    ) -> StageResult:
         upscale = upscale.with_args(**kwargs)
 
         if upscale.upscale_model is None:
@@ -71,31 +72,30 @@ class UpscaleSwinIRStage(BaseStage):
         swinir = self.load(server, stage, upscale, device)
 
         outputs = []
-        for source in sources:
+        for source in sources.as_numpy():
             # TODO: add support for grayscale (1-channel) images
-            image = np.array(source) / 255.0
+            image = source / 255.0
             image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
             image = np.expand_dims(image, axis=0)
             logger.trace("SwinIR input shape: %s", image.shape)
 
             scale = upscale.outscale
-            dest = np.zeros(
+            logger.trace(
+                "SwinIR output shape: %s",
                 (
                     image.shape[0],
                     image.shape[1],
                     image.shape[2] * scale,
                     image.shape[3] * scale,
-                )
+                ),
             )
-            logger.trace("SwinIR output shape: %s", dest.shape)
 
-            dest = swinir(image)
-            dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
-            dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
-            dest = (dest * 255.0).round().astype(np.uint8)
+            output = swinir(image)
+            output = np.clip(np.squeeze(output, axis=0), 0, 1)
+            output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
+            output = (output * 255.0).round().astype(np.uint8)
 
-            output = Image.fromarray(dest, "RGB")
-            logger.info("output image size: %s x %s", output.width, output.height)
+            logger.info("output image size: %s", output.shape)
             outputs.append(output)
 
-        return outputs
+        return StageResult(images=outputs)

@@ -1,2 +1,5 @@
 ONNX_MODEL = "model.onnx"
 ONNX_WEIGHTS = "weights.pb"
+
+LATENT_FACTOR = 8
+LATENT_CHANNELS = 4
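These constants capture the Stable Diffusion latent geometry: the VAE downsamples each spatial dimension by a factor of 8, and latents carry 4 channels. A quick sketch of the relationship:

```python
# a 512x512 image corresponds to a 4x64x64 latent
width, height = 512, 512
latent_shape = (LATENT_CHANNELS, height // LATENT_FACTOR, width // LATENT_FACTOR)
assert latent_shape == (4, 64, 64)
```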
@@ -15,7 +15,8 @@ from ..constants import ONNX_MODEL, ONNX_WEIGHTS
 from ..utils import load_config
 from .correction.gfpgan import convert_correction_gfpgan
 from .diffusion.control import convert_diffusion_control
-from .diffusion.diffusers import convert_diffusion_diffusers
+from .diffusion.diffusion import convert_diffusion_diffusers
+from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
 from .diffusion.lora import blend_loras
 from .diffusion.textual_inversion import blend_textual_inversions
 from .upscaling.bsrgan import convert_upscaling_bsrgan
@@ -357,6 +358,16 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                 conversion, name, model["source"], format=model_format
             )
 
+            pipeline = model.get("pipeline", "txt2img")
+            if pipeline.endswith("-sdxl"):
+                converted, dest = convert_diffusion_diffusers_xl(
+                    conversion,
+                    model,
+                    source,
+                    model_format,
+                    hf=hf,
+                )
+            else:
                 converted, dest = convert_diffusion_diffusers(
                     conversion,
                     model,
@@ -588,7 +599,7 @@ def main(args=None) -> int:
     logger.info("CLI arguments: %s", args)
 
     server = ConversionContext.from_environ()
-    server.half = args.half or "onnx-fp16" in server.optimizations
+    server.half = args.half or server.has_optimization("onnx-fp16")
     server.opset = args.opset
     server.token = args.token
     logger.info(

@@ -20,7 +20,6 @@ from diffusers import (
     AutoencoderKL,
     OnnxRuntimeModel,
-    OnnxStableDiffusionPipeline,
     StableDiffusionControlNetPipeline,
     StableDiffusionInstructPix2PixPipeline,
     StableDiffusionPipeline,
     StableDiffusionUpscalePipeline,
@@ -32,17 +31,25 @@ from onnx import load_model, save_model
 
 from ...constants import ONNX_MODEL, ONNX_WEIGHTS
 from ...diffusers.load import optimize_pipeline
 from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
 from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ...diffusers.version_safe_diffusers import AttnProcessor
 from ...models.cnet import UNet2DConditionModel_CNet
 from ...utils import run_gc
-from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
+from ..utils import (
+    RESOLVE_FORMATS,
+    ConversionContext,
+    check_ext,
+    is_torch_2_0,
+    load_tensor,
+    onnx_export,
+)
 from .checkpoint import convert_extract_checkpoint
 
 logger = getLogger(__name__)
 
-available_pipelines = {
-    "controlnet": StableDiffusionControlNetPipeline,
+CONVERT_PIPELINES = {
+    "controlnet": OnnxStableDiffusionControlNetPipeline,
     "img2img": StableDiffusionPipeline,
     "inpaint": StableDiffusionPipeline,
     "lpw": StableDiffusionPipeline,
@@ -96,7 +103,6 @@ def get_model_version(
             opts["prediction_type"] = "epsilon"
     except Exception:
         logger.debug("unable to load tensor for version check")
-        pass
 
     return (v2, opts)
 
@@ -314,7 +320,7 @@ def convert_diffusion_diffusers(
         logger.info("ONNX model already exists, skipping")
         return (False, dest_path)
 
-    pipe_class = available_pipelines.get(pipe_type)
+    pipe_class = CONVERT_PIPELINES.get(pipe_type)
     v2, pipe_args = get_model_version(
         source, conversion.map_location, size=image_size, version=version
     )
@@ -360,7 +366,6 @@ def convert_diffusion_diffusers(
             source,
             original_config_file=config_path,
             pipeline_class=pipe_class,
-            vae_path=replace_vae,
             **pipe_args,
         ).to(device, torch_dtype=dtype)
     elif hf:
@@ -374,6 +379,17 @@ def convert_diffusion_diffusers(
         logger.warning("pipeline source not found or not recognized: %s", source)
         raise ValueError(f"pipeline source not found or not recognized: {source}")
 
+    if replace_vae is not None:
+        vae_path = path.join(conversion.model_path, replace_vae)
+        if check_ext(replace_vae, RESOLVE_FORMATS):
+            pipeline.vae = AutoencoderKL.from_single_file(vae_path)
+        else:
+            pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
+
+    if is_torch_2_0:
+        pipeline.unet.set_attn_processor(AttnProcessor())
+        pipeline.vae.set_attn_processor(AttnProcessor())
+
     optimize_pipeline(conversion, pipeline)
 
     output_path = Path(dest_path)
@@ -424,9 +440,6 @@ def convert_diffusion_diffusers(
     unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
     unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
 
-    if is_torch_2_0:
-        pipeline.unet.set_attn_processor(AttnProcessor())
-
     unet_in_channels = pipeline.unet.config.in_channels
     unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / ONNX_MODEL
@@ -526,19 +539,6 @@ def convert_diffusion_diffusers(
     del unet
     run_gc()
 
-    # VAE
-    if replace_vae is not None:
-        if replace_vae.startswith("."):
-            logger.debug(
-                "custom VAE appears to be a local path, making it relative to the model path"
-            )
-            replace_vae = path.join(conversion.model_path, replace_vae)
-
-        logger.info("loading custom VAE: %s", replace_vae)
-        vae = AutoencoderKL.from_pretrained(replace_vae)
-        pipeline.vae = vae
-        run_gc()
-
     if single_vae:
         logger.debug("VAE config: %s", pipeline.vae.config)
@@ -0,0 +1,116 @@
+from logging import getLogger
+from os import path
+from typing import Dict, Optional, Tuple
+
+import onnx
+import torch
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+from onnx.shape_inference import infer_shapes_path
+from onnxruntime.transformers.float16 import convert_float_to_float16
+from optimum.exporters.onnx import main_export
+
+from ...constants import ONNX_MODEL
+from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
+
+logger = getLogger(__name__)
+
+
+@torch.no_grad()
+def convert_diffusion_diffusers_xl(
+    conversion: ConversionContext,
+    model: Dict,
+    source: str,
+    format: Optional[str],
+    hf: bool = False,
+) -> Tuple[bool, str]:
+    """
+    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+    """
+    name = model.get("name")
+    replace_vae = model.get("vae", None)
+
+    device = conversion.training_device
+    dtype = conversion.torch_dtype()
+    logger.debug("using Torch dtype %s for pipeline", dtype)
+
+    dest_path = path.join(conversion.model_path, name)
+    model_index = path.join(dest_path, "model_index.json")
+    model_hash = path.join(dest_path, "hash.txt")
+
+    # diffusers go into a directory rather than .onnx file
+    logger.info(
+        "converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
+    )
+
+    if path.exists(dest_path) and path.exists(model_index):
+        logger.info("ONNX model already exists, skipping conversion")
+
+        if "hash" in model and not path.exists(model_hash):
+            logger.info("ONNX model does not have hash file, adding one")
+            with open(model_hash, "w") as f:
+                f.write(model["hash"])
+
+        return (False, dest_path)
+
+    # safetensors -> diffusers directory with torch models
+    temp_path = path.join(conversion.cache_path, f"{name}-torch")
+
+    if format == "safetensors":
+        pipeline = StableDiffusionXLPipeline.from_single_file(
+            source, use_safetensors=True
+        )
+    else:
+        pipeline = StableDiffusionXLPipeline.from_pretrained(source)
+
+    if replace_vae is not None:
+        vae_path = path.join(conversion.model_path, replace_vae)
+        if check_ext(replace_vae, RESOLVE_FORMATS):
+            logger.debug("loading VAE from single tensor file: %s", vae_path)
+            pipeline.vae = AutoencoderKL.from_single_file(vae_path)
+        else:
+            logger.debug("loading pretrained VAE from path: %s", vae_path)
+            pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
+
+    if path.exists(temp_path):
+        logger.debug("torch model already exists for %s: %s", source, temp_path)
+    else:
+        logger.debug("exporting torch model for %s: %s", source, temp_path)
+        pipeline.save_pretrained(temp_path)
+
+    # directory -> onnx using optimum exporters
+    main_export(
+        temp_path,
+        output=dest_path,
+        task="stable-diffusion-xl",
+        device=device,
+        fp16=conversion.has_optimization(
+            "torch-fp16"
+        ),  # optimum's fp16 mode only works on CUDA or ROCm
+        framework="pt",
+    )
+
+    if "hash" in model:
+        logger.debug("adding hash file to ONNX model")
+        with open(model_hash, "w") as f:
+            f.write(model["hash"])
+
+    if conversion.half:
+        unet_path = path.join(dest_path, "unet", ONNX_MODEL)
+        infer_shapes_path(unet_path)
+        unet = onnx.load(unet_path)
+        opt_model = convert_float_to_float16(
+            unet,
+            disable_shape_infer=True,
+            force_fp16_initializers=True,
+            keep_io_types=True,
+            op_block_list=["Attention", "MultiHeadAttention"],
+        )
+        onnx.save_model(
+            opt_model,
+            unet_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location="weights.pb",
+        )
+
+    return False, dest_path
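A hedged sketch of calling the new SDXL converter directly; in practice it is reached through `convert_models` when a model's `pipeline` ends in `-sdxl`, and only the dict keys this function reads (`name`, `vae`, `hash`) are meaningful:

```python
# illustrative call; conversion is a ConversionContext as in __main__.py
converted, dest = convert_diffusion_diffusers_xl(
    conversion,
    {"name": "sdxl-example"},  # hypothetical entry; "vae" and "hash" are optional
    "stabilityai/stable-diffusion-xl-base-1.0",
    None,  # format; "safetensors" switches to StableDiffusionXLPipeline.from_single_file
    hf=True,
)
```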
@@ -1,22 +1,15 @@
 from argparse import ArgumentParser
 from logging import getLogger
 from os import path
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from onnx import ModelProto, load, numpy_helper
-from onnx.checker import check_model
-from onnx.external_data_helper import (
-    convert_model_to_external_data,
-    set_external_data,
-    write_external_data_tensors,
-)
-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnx import ModelProto, NodeProto, TensorProto, load, numpy_helper
+from onnx.external_data_helper import set_external_data
+from onnxruntime import OrtValue
+from scipy import interpolate
 
 from ...server.context import ServerContext
-from ..utils import ConversionContext, load_tensor
+from ..utils import load_tensor
 
 logger = getLogger(__name__)
 
@@ -39,7 +32,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
         lr = a
 
     if kernel == (1, 1):
-        lr = np.expand_dims(lr, axis=(2, 3))
+        lr = np.expand_dims(lr, axis=(2, 3))  # TODO: generate axis
 
     return hr + lr
 
@@ -78,13 +71,15 @@ def fix_node_name(key: str):
     return fixed_name
 
 
-def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
+def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
     fixed = {}
+    names = [fix_node_name(node.name) for node in nodes]
 
     for key, value in keys.items():
-        root, *rest = key.split(".")
-        logger.debug("fixing XL node name: %s -> %s", key, root)  # TODO: move to trace
+        root, *_rest = key.split(".")
+        logger.trace("fixing XL node name: %s -> %s", key, root)
 
+        simple = False
         if root.startswith("input"):
             block = "down_blocks"
         elif root.startswith("middle"):
@@ -93,6 +88,15 @@ def fix_xl_names(
             block = "up_blocks"
         elif root.startswith("text_model"):
             block = "text_model"
+        elif root.startswith("down_blocks"):
+            block = "down_blocks"
+            simple = True
+        elif root.startswith("mid_block"):
+            block = "mid_block"
+            simple = True
+        elif root.startswith("up_blocks"):
+            block = "up_blocks"
+            simple = True
         else:
             logger.warning("unknown XL key name: %s", key)
             fixed[key] = value
@@ -100,6 +104,10 @@ def fix_xl_names(
 
         suffix = None
         for s in [
+            "conv",
+            "conv_shortcut",
+            "conv1",
+            "conv2",
             "fc1",
             "fc2",
             "ff_net_0_proj",
@@ -119,18 +127,21 @@ def fix_xl_names(
             logger.warning("new XL key type: %s", root)
             continue
 
-        logger.debug("searching for XL node: /%s/*/%s", block, suffix)
-        match = None
-        if block == "text_model":
-            match = next(
-                node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
-            )
+        logger.trace("searching for XL node: %s -> /%s/*/%s", root, block, suffix)
+        match: Optional[str] = None
+        if "conv" in suffix:
+            match = next(node for node in names if node == f"{root}_Conv")
+        elif "time_emb_proj" in root:
+            match = next(node for node in names if node == f"{root}_Gemm")
+        elif block == "text_model" or simple:
+            match = next(node for node in names if node == f"{root}_MatMul")
         else:
             # search in order. one side has sparse indices, so they will not match.
             match = next(
                 node
-                for node in nodes
-                if node.name.startswith(f"/{block}")
-                and fix_node_name(node.name).endswith(
+                for node in names
+                if node.startswith(block)
+                and node.endswith(
                     f"{suffix}_MatMul"
                 )  # needs to be fixed because some places use to_out.0
             )
@@ -138,18 +149,28 @@ def fix_xl_names(
         if match is None:
             logger.warning("no matches for XL key: %s", root)
             continue
+        else:
+            logger.trace("matched key: %s -> %s", key, match)
 
-        name: str = match.name
-        name = fix_node_name(name.rstrip("/MatMul"))
+        name = match
+        if name.endswith("_MatMul"):
+            name = name[:-7]
+        elif name.endswith("_Gemm"):
+            name = name[:-5]
+        elif name.endswith("_Conv"):
+            name = name[:-5]
 
         if name.endswith("proj_o"):
             # wtf
             name = f"{name}ut"
 
-        logger.debug("matching XL key with node: %s -> %s", key, match.name)
+        logger.trace("matching XL key with node: %s -> %s, %s", key, match, name)
 
         fixed[name] = value
-        nodes.remove(match)
+        names.remove(match)
+
+    logger.debug(
+        "SDXL LoRA key fixup matched %s of %s keys, %s nodes remaining",
+        len(fixed.keys()),
+        len(keys.keys()),
+        len(names),
+    )
 
     return fixed
 
@@ -161,39 +182,9 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
     )
 
 
-def blend_loras(
-    _conversion: ServerContext,
-    base_name: Union[str, ModelProto],
-    loras: List[Tuple[str, float]],
-    model_type: Literal["text_encoder", "unet"],
-    model_index: Optional[int] = None,
-    xl: Optional[bool] = False,
-):
-    # always load to CPU for blending
-    device = torch.device("cpu")
-    dtype = torch.float32
-
-    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
-    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
-
-    if model_type == "text_encoder":
-        if model_index is None:
-            lora_prefix = "lora_te_"
-        else:
-            lora_prefix = f"lora_te{model_index}_"
-    else:
-        lora_prefix = f"lora_{model_type}_"
-
-    blended: Dict[str, np.ndarray] = {}
-    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
-        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
-        if lora_model is None:
-            logger.warning("unable to load tensor for LoRA")
-            continue
-
-        for key in lora_model.keys():
-            if ".hada_w1_a" in key and lora_prefix in key:
-                # LoHA
+def blend_weights_loha(
+    key: str, lora_prefix: str, lora_model: Dict, dtype
+) -> Tuple[str, np.ndarray]:
     base_key = key[: key.index(".hada_w1_a")].replace(lora_prefix, "")
 
     t1_key = key.replace("hada_w1_a", "hada_t1")
@@ -260,26 +251,18 @@ def blend_loras(
     weights = (w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)
     np_weights = weights.numpy() * (alpha / dim)
 
-            np_weights *= lora_weight
-            if base_key in blended:
-                logger.trace(
-                    "summing LoHA weights: %s + %s",
-                    blended[base_key].shape,
-                    np_weights.shape,
-                )
-                blended[base_key] += sum_weights(blended[base_key], np_weights)
-            else:
-                blended[base_key] = np_weights
-        elif ".lora_down" in key and lora_prefix in key:
-            # LoRA or LoCON
+    return base_key, np_weights
+
+
+def blend_weights_lora(
+    key: str, lora_prefix: str, lora_model: Dict, dtype
+) -> Tuple[str, np.ndarray]:
     base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
 
     mid_key = key.replace("lora_down", "lora_mid")
     up_key = key.replace("lora_down", "lora_up")
     alpha_key = key[: key.index("lora_down")] + "alpha"
-    logger.trace(
-        "blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key
-    )
+    logger.trace("blending weights for LoRA keys: %s, %s, %s", key, up_key, alpha_key)
 
     down_weight = lora_model[key].to(dtype=dtype)
     up_weight = lora_model[up_key].to(dtype=dtype)
@@ -320,10 +303,7 @@ def blend_loras(
             alpha,
         )
         weights = (
-            (
-                up_weight.squeeze(3).squeeze(2)
-                @ down_weight.squeeze(3).squeeze(2)
-            )
+            (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
             .unsqueeze(2)
             .unsqueeze(3)
         )
@@ -341,15 +321,12 @@ def blend_loras(
             mid_weight.shape,
             alpha,
         )
-        weights = torch.zeros(
-            (up_weight.shape[0], down_weight.shape[1], *kernel)
-        )
+        weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
 
         for w in range(kernel[0]):
             for h in range(kernel[1]):
                 weights[:, :, w, h] = (
-                    up_weight.squeeze(3).squeeze(2)
-                    @ mid_weight[:, :, w, h]
+                    up_weight.squeeze(3).squeeze(2) @ mid_weight[:, :, w, h]
                 ) @ down_weight.squeeze(3).squeeze(2)
 
         np_weights = weights.numpy() * (alpha / dim)
@@ -361,9 +338,7 @@ def blend_loras(
             up_weight.shape,
             alpha,
         )
-        weights = torch.zeros(
-            (up_weight.shape[0], down_weight.shape[1], *kernel)
-        )
+        weights = torch.zeros((up_weight.shape[0], down_weight.shape[1], *kernel))
 
         for w in range(kernel[0]):
             for h in range(kernel[1]):
@@ -371,8 +346,7 @@ def blend_loras(
                 up_w, up_h = kernel_slice(w, h, up_weight.shape)
 
                 weights[:, :, w, h] = (
-                    up_weight[:, :, up_w, up_h]
-                    @ down_weight[:, :, down_w, down_h]
+                    up_weight[:, :, up_w, up_h] @ down_weight[:, :, down_w, down_h]
                 )
 
         np_weights = weights.numpy() * (alpha / dim)
@@ -382,48 +356,165 @@ def blend_loras(
             base_key,
             up_weight.shape[-2:],
         )
+        # TODO: should this be None?
+        np_weights = np.zeros((1, 1, 1, 1))
+
+    return base_key, np_weights
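Both helpers reduce to the same low-rank update. With rank $d$ and scaling factor $\alpha$ read from the tensor file, and $w$ the per-LoRA blend weight applied by the caller below, the delta added to a base weight $W$ is:

$$W' = W + w \cdot \frac{\alpha}{d} \, (W_{\text{up}} W_{\text{down}})$$

The LoHA case swaps the single product for the elementwise product of two such factorizations, matching `(w1a_weight @ w1b_weight) * (w2a_weight @ w2b_weight)` above.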
+
+
+def blend_node_conv_gemm(weight_node, weights) -> TensorProto:
+    # blending
+    onnx_weights = numpy_helper.to_array(weight_node)
+    logger.trace(
+        "found blended weights for conv: %s, %s",
+        onnx_weights.shape,
+        weights.shape,
+    )
+
+    if onnx_weights.shape[-2:] == (1, 1):
+        if weights.shape[-2:] == (1, 1):
+            blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+        else:
+            blended = onnx_weights.squeeze((3, 2)) + weights
+
+        blended = np.expand_dims(blended, (2, 3))
+    else:
+        if onnx_weights.shape != weights.shape:
+            logger.warning(
+                "reshaping weights for mismatched Conv node: %s, %s",
+                onnx_weights.shape,
+                weights.shape,
+            )
+            # TODO: test if this can be replaced with interpolation, simply reshaping is pretty sus
+            blended = onnx_weights + weights.reshape(onnx_weights.shape)
+        else:
+            blended = onnx_weights + weights
+
+    logger.trace("blended weight shape: %s", blended.shape)
+
+    # replace the original initializer
+    return numpy_helper.from_array(blended.astype(onnx_weights.dtype), weight_node.name)
+
+
+def blend_node_matmul(matmul_node, weights, matmul_key) -> TensorProto:
+    onnx_weights = numpy_helper.to_array(matmul_node)
+    logger.trace(
+        "found blended weights for matmul: %s, %s",
+        weights.shape,
+        onnx_weights.shape,
+    )
+
+    t_weights = weights.transpose()
+    if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
+        logger.warning(
+            "weight shapes do not match for %s: %s vs %s",
+            matmul_key,
+            weights.shape,
+            onnx_weights.shape,
+        )
+        t_weights = interp_to_match(weights, onnx_weights).transpose()
+
+    blended = onnx_weights + t_weights
+    logger.trace("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
+
+    # replace the original initializer
+    return numpy_helper.from_array(blended.astype(onnx_weights.dtype), matmul_node.name)
+
+
+def blend_loras(
+    _conversion: ServerContext,
+    base_name: Union[str, ModelProto],
+    loras: List[Tuple[str, float]],
+    model_type: Literal["text_encoder", "unet"],
+    model_index: Optional[int] = None,
+    xl: Optional[bool] = False,
+):
+    # always load to CPU for blending
+    device = torch.device("cpu")
+    dtype = torch.float32
+
+    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
+    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
+
+    if model_type == "text_encoder":
+        if model_index is None:
+            lora_prefix = "lora_te_"
+        else:
+            lora_prefix = f"lora_te{model_index}_"
+    else:
+        lora_prefix = f"lora_{model_type}_"
+
+    layers = []
+    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
+        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        if lora_model is None:
+            logger.warning("unable to load tensor for LoRA")
+            continue
 
-            np_weights *= lora_weight
-            if base_key in blended:
-                logger.trace(
-                    "summing weights: %s + %s",
-                    blended[base_key].shape,
-                    np_weights.shape,
-                )
-                blended[base_key] = sum_weights(blended[base_key], np_weights)
-            else:
-                blended[base_key] = np_weights
+        blended: Dict[str, np.ndarray] = {}
+        layers.append(blended)
+
+        for key in lora_model.keys():
+            if ".hada_w1_a" in key and lora_prefix in key:
+                # LoHA
+                base_key, np_weights = blend_weights_loha(
+                    key, lora_prefix, lora_model, dtype
+                )
+                np_weights = np_weights * lora_weight
+                logger.trace(
+                    "adding LoHA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
+            elif ".lora_down" in key and lora_prefix in key:
+                # LoRA or LoCON
+                base_key, np_weights = blend_weights_lora(
+                    key, lora_prefix, lora_model, dtype
+                )
+                np_weights = np_weights * lora_weight
+                logger.trace(
+                    "adding LoRA weights: %s",
+                    np_weights.shape,
+                )
+                blended[base_key] = np_weights
 
-    # rewrite node names for XL
+    # rewrite node names for XL and flatten layers
+    weights: Dict[str, np.ndarray] = {}
+
+    for blended in layers:
         if xl:
             nodes = list(base_model.graph.node)
             blended = fix_xl_names(blended, nodes)
 
+        logger.trace(
+            "updating %s of %s initializers",
+            len(blended.keys()),
+            len(base_model.graph.initializer),
+        )
+        for key, value in blended.items():
+            if key in weights:
+                weights[key] = sum_weights(weights[key], value)
+            else:
+                weights[key] = value
+
+    # fix node names once
     fixed_initializer_names = [
         fix_initializer_name(node.name) for node in base_model.graph.initializer
     ]
     logger.trace("fixed initializer names: %s", fixed_initializer_names)
 
     fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
     logger.trace("fixed node names: %s", fixed_node_names)
 
     logger.debug(
         "updating %s of %s initializers",
-        len(blended.keys()),
+        len(weights.keys()),
         len(base_model.graph.initializer),
     )
 
     unmatched_keys = []
-    for base_key, weights in blended.items():
+    for base_key, weights in weights.items():
         conv_key = base_key + "_Conv"
+        gemm_key = base_key + "_Gemm"
         matmul_key = base_key + "_MatMul"
 
         logger.trace(
-            "key %s has conv: %s, matmul: %s",
+            "key %s has conv: %s, gemm: %s, matmul: %s",
             base_key,
             conv_key in fixed_node_names,
+            gemm_key in fixed_node_names,
             matmul_key in fixed_node_names,
         )
@ -449,38 +540,9 @@ def blend_loras(
|
|||
            weight_node = base_model.graph.initializer[weight_idx]
            logger.trace("found weight initializer: %s", weight_node.name)

            # blending
            onnx_weights = numpy_helper.to_array(weight_node)
            logger.trace(
                "found blended weights for conv: %s, %s",
                onnx_weights.shape,
                weights.shape,
            )
            # replace the previous node
            updated_node = blend_node_conv_gemm(weight_node, weights)

            if onnx_weights.shape[-2:] == (1, 1):
                if weights.shape[-2:] == (1, 1):
                    blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
                else:
                    blended = onnx_weights.squeeze((3, 2)) + weights

                blended = np.expand_dims(blended, (2, 3))
            else:
                if onnx_weights.shape != weights.shape:
                    logger.warning(
                        "reshaping weights for mismatched Conv node: %s, %s",
                        onnx_weights.shape,
                        weights.shape,
                    )
                    blended = onnx_weights + weights.reshape(onnx_weights.shape)
                else:
                    blended = onnx_weights + weights

            logger.trace("blended weight shape: %s", blended.shape)

            # replace the original initializer
            updated_node = numpy_helper.from_array(
                blended.astype(onnx_weights.dtype), weight_node.name
            )
            del base_model.graph.initializer[weight_idx]
            base_model.graph.initializer.insert(weight_idx, updated_node)
        elif matmul_key in fixed_node_names:
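
One detail worth calling out from the Conv branch above: a 1x1 convolution kernel stored as (out, in, 1, 1) can be blended like a plain matrix by squeezing the spatial dims and expanding them back afterwards. A minimal numpy sketch of that trick, with illustrative shapes:

import numpy as np

kernel = np.zeros((320, 320, 1, 1), dtype=np.float32)  # Conv initializer
delta = np.ones((320, 320), dtype=np.float32)          # 2D LoRA delta

blended = kernel.squeeze((3, 2)) + delta
blended = np.expand_dims(blended, (2, 3))
print(blended.shape)  # (320, 320, 1, 1)
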
@@ -497,42 +559,15 @@ def blend_loras(
            matmul_node = base_model.graph.initializer[matmul_idx]
            logger.trace("found matmul initializer: %s", matmul_node.name)

            # blending
            onnx_weights = numpy_helper.to_array(matmul_node)
            logger.trace(
                "found blended weights for matmul: %s, %s",
                weights.shape,
                onnx_weights.shape,
            )
            # replace the previous node
            updated_node = blend_node_matmul(matmul_node, weights, matmul_key)

            t_weights = weights.transpose()
            if (
                weights.shape != onnx_weights.shape
                and t_weights.shape != onnx_weights.shape
            ):
                logger.warning(
                    "weight shapes do not match for %s: %s vs %s",
                    matmul_key,
                    weights.shape,
                    onnx_weights.shape,
                )
                t_weights = interp_to_match(weights, onnx_weights).transpose()

            blended = onnx_weights + t_weights
            logger.debug(
                "blended weight shape: %s, %s", blended.shape, onnx_weights.dtype
            )

            # replace the original initializer
            updated_node = numpy_helper.from_array(
                blended.astype(onnx_weights.dtype), matmul_node.name
            )
            del base_model.graph.initializer[matmul_idx]
            base_model.graph.initializer.insert(matmul_idx, updated_node)
        else:
            unmatched_keys.append(base_key)

    logger.debug(
    logger.trace(
        "node counts: %s -> %s, %s -> %s",
        len(fixed_initializer_names),
        len(base_model.graph.initializer),

@@ -541,10 +576,7 @@ def blend_loras(
    )

    if len(unmatched_keys) > 0:
        logger.warning("could not find nodes for some keys: %s", unmatched_keys)

    # if model_type == "unet":
    #     save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
        logger.warning("could not find nodes for some LoRA keys: %s", unmatched_keys)

    return base_model


@@ -568,63 +600,3 @@ def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
logger.debug("weights after interpolation: %s", output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context = ConversionContext.from_environ()
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--base", type=str)
|
||||
parser.add_argument("--dest", type=str)
|
||||
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
||||
parser.add_argument("--lora_models", nargs="+", type=str, default=[])
|
||||
parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(
|
||||
"merging %s with %s with weights: %s",
|
||||
args.lora_models,
|
||||
args.base,
|
||||
args.lora_weights,
|
||||
)
|
||||
|
||||
default_weight = 1.0 / len(args.lora_models)
|
||||
while len(args.lora_weights) < len(args.lora_models):
|
||||
args.lora_weights.append(default_weight)
|
||||
|
||||
blend_model = blend_loras(
|
||||
context,
|
||||
args.base,
|
||||
list(zip(args.lora_models, args.lora_weights)),
|
||||
args.type,
|
||||
)
|
||||
if args.dest is None or args.dest == "" or args.dest == ":load":
|
||||
# convert to external data and save to memory
|
||||
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||
logger.info("saved external data for %s nodes", len(external_data))
|
||||
|
||||
external_names, external_values = zip(*external_data)
|
||||
opts = SessionOptions()
|
||||
opts.add_external_initializers(list(external_names), list(external_values))
|
||||
sess = InferenceSession(
|
||||
bare_model.SerializeToString(),
|
||||
sess_options=opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
logger.info(
|
||||
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
|
||||
)
|
||||
else:
|
||||
convert_model_to_external_data(
|
||||
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
|
||||
)
|
||||
bare_model = write_external_data_tensors(blend_model, args.dest)
|
||||
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
||||
|
||||
with open(dest_file, "w+b") as model_file:
|
||||
model_file.write(bare_model.SerializeToString())
|
||||
|
||||
logger.info("successfully saved blended model: %s", dest_file)
|
||||
|
||||
check_model(dest_file)
|
||||
|
||||
logger.info("checked blended model")
|
||||
|
|
|
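
For reference, a hedged sketch of calling the blending entry point directly rather than through the CLI above; the paths are made up, and the positional order follows the signature shown earlier (context, base model, list of (name, weight) pairs, model type):

context = ConversionContext.from_environ()
blended = blend_loras(
    context,
    "./models/stable-diffusion-v1-5/unet/model.onnx",  # hypothetical path
    [("./models/lora/detail.safetensors", 0.8)],       # hypothetical LoRA
    "unet",
)
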
@@ -14,55 +14,23 @@ from ..utils import ConversionContext, load_tensor

logger = getLogger(__name__)


@torch.no_grad()
def blend_textual_inversions(
    server: ServerContext,
    text_encoder: ModelProto,
    tokenizer: CLIPTokenizer,
    inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
    # always load to CPU for blending
    device = torch.device("cpu")
    dtype = np.float32
    embeds = {}

    for name, weight, base_token, inversion_format in inversions:
        if base_token is None:
            logger.debug("no base token provided, using name: %s", name)
            base_token = name

        logger.info(
            "blending Textual Inversion %s with weight of %s for token %s",
            name,
            weight,
            base_token,
        )

        loaded_embeds = load_tensor(name, map_location=device)
        if loaded_embeds is None:
            logger.warning("unable to load tensor")
            continue

        if inversion_format is None:
def detect_embedding_format(loaded_embeds) -> str:
    keys: List[str] = list(loaded_embeds.keys())
    if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
        logger.debug("detected Textual Inversion concept: %s", keys)
        inversion_format = "concept"
        return "concept"
    elif "emb_params" in keys:
        logger.debug(
            "detected Textual Inversion parameter embeddings: %s", keys
        )
        inversion_format = "parameters"
        logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
        return "parameters"
    elif "string_to_token" in keys and "string_to_param" in keys:
        logger.debug("detected Textual Inversion token embeddings: %s", keys)
        inversion_format = "embeddings"
        return "embeddings"
    else:
        logger.error(
            "unknown Textual Inversion format, no recognized keys: %s", keys
        )
        continue
        logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
        return None

        if inversion_format == "concept":

def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
    # separate token and the embeds
    token = list(loaded_embeds.keys())[0]
@@ -78,11 +46,13 @@ def blend_textual_inversions(
        embeds[token] += layer
    else:
        embeds[token] = layer
    elif inversion_format == "parameters":


def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
    emb_params = loaded_embeds["emb_params"]

    num_tokens = emb_params.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, name)
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(emb_params[0, :].shape)

@@ -108,7 +78,9 @@ def blend_textual_inversions(
        embeds[sum_token] += sum_layer
    else:
        embeds[sum_token] = sum_layer
    elif inversion_format == "embeddings":


def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
    string_to_token = loaded_embeds["string_to_token"]
    string_to_param = loaded_embeds["string_to_param"]

@@ -117,7 +89,7 @@ def blend_textual_inversions(
    trained_embeds = string_to_param[token]

    num_tokens = trained_embeds.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, name)
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(trained_embeds[0, :].shape)

@@ -143,23 +115,9 @@ def blend_textual_inversions(
        embeds[sum_token] += sum_layer
    else:
        embeds[sum_token] = sum_layer
    else:
        raise ValueError(f"unknown Textual Inversion format: {inversion_format}")

    # add the tokens to the tokenizer
    logger.debug(
        "found embeddings for %s tokens: %s",
        len(embeds.keys()),
        list(embeds.keys()),
    )
    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    logger.trace("added %s tokens", num_added_tokens)

def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
    # resize the token embeddings
    # text_encoder.resize_token_embeddings(len(tokenizer))
    embedding_node = [

@@ -191,6 +149,59 @@ def blend_textual_inversions(
    del text_encoder.graph.initializer[i]
    text_encoder.graph.initializer.insert(i, new_initializer)


@torch.no_grad()
def blend_textual_inversions(
    server: ServerContext,
    text_encoder: ModelProto,
    tokenizer: CLIPTokenizer,
    embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
    # always load to CPU for blending
    device = torch.device("cpu")
    dtype = np.float32
    embeds = {}

    for name, weight, base_token, format in embeddings:
        if base_token is None:
            logger.debug("no base token provided, using name: %s", name)
            base_token = name

        logger.info(
            "blending Textual Inversion %s with weight of %s for token %s",
            name,
            weight,
            base_token,
        )

        loaded_embeds = load_tensor(name, map_location=device)
        if loaded_embeds is None:
            logger.warning("unable to load tensor")
            continue

        if format is None:
            format = detect_embedding_format(loaded_embeds)

        if format == "concept":
            blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
        elif format == "parameters":
            blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
        elif format == "embeddings":
            blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
        else:
            raise ValueError(f"unknown Textual Inversion format: {format}")

    # add the tokens to the tokenizer
    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
    if num_added_tokens == 0:
        raise ValueError(
            "The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
        )

    logger.trace("added %s tokens", num_added_tokens)

    blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)

    return (text_encoder, tokenizer)
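
The refactored blend_textual_inversions above dispatches on the detected embedding format. A quick check of the detection logic, assuming detect_embedding_format from the hunk above is in scope; the token and tensor are placeholders:

import torch

fake_concept = {"<my-style>": torch.zeros(768)}  # placeholder embedding
assert detect_embedding_format(fake_concept) == "concept"
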
@@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
class ConversionContext(ServerContext):
    def __init__(
        self,
        model_path: Optional[str] = None,
        model_path: str = ".",
        cache_path: Optional[str] = None,
        device: Optional[str] = None,
        half: bool = False,

@@ -69,7 +69,7 @@ class ConversionContext(ServerContext):
    def from_environ(cls):
        context = super().from_environ()
        context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
        context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", True)
        context.extract = get_boolean(environ, "ONNX_WEB_CONVERT_EXTRACT", False)
        context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
        context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
        context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))

@@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):

def tuple_to_source(model: Union[ModelDict, LegacyModel]):
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        name, source, *_rest = model

        return {
            "name": name,

@@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        scale = rest[0] if len(rest) > 0 else 1
        half = rest[0] if len(rest) > 0 else False
        opset = rest[0] if len(rest) > 0 else None
        scale = rest.pop(0) if len(rest) > 0 else 1
        half = rest.pop(0) if len(rest) > 0 else False
        opset = rest.pop(0) if len(rest) > 0 else None

        return {
            "name": name,

@@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        single_vae = rest[0] if len(rest) > 0 else False
        half = rest[0] if len(rest) > 0 else False
        opset = rest[0] if len(rest) > 0 else None
        single_vae = rest.pop(0) if len(rest) > 0 else False
        half = rest.pop(0) if len(rest) > 0 else False
        opset = rest.pop(0) if len(rest) > 0 else None

        return {
            "name": name,

@@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
    if isinstance(model, list) or isinstance(model, tuple):
        name, source, *rest = model
        scale = rest[0] if len(rest) > 0 else 1
        half = rest[0] if len(rest) > 0 else False
        opset = rest[0] if len(rest) > 0 else None
        scale = rest.pop(0) if len(rest) > 0 else 1
        half = rest.pop(0) if len(rest) > 0 else False
        opset = rest.pop(0) if len(rest) > 0 else None

        return {
            "name": name,
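
The switch from rest[0] to rest.pop(0) in these tuple helpers fixes a real bug: indexing always read the same first element for every optional field, while pop consumes them in order. A small illustration with an invented legacy tuple:

model = ("esrgan-x4", "https://example.com/model.pth", 4, True, 14)  # hypothetical
name, source, *rest = model

scale = rest.pop(0) if len(rest) > 0 else 1     # 4
half = rest.pop(0) if len(rest) > 0 else False  # True
opset = rest.pop(0) if len(rest) > 0 else None  # 14
# with rest[0] everywhere, scale, half, and opset would all have been 4
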
@@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):

MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]


def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
    _name, ext = path.splitext(name)
    ext = ext.strip(".")

    return (ext in exts, ext)


def source_format(model: Dict) -> Optional[str]:

@@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]:
        return model["format"]

    if "source" in model:
        _name, ext = path.splitext(model["source"])
        if ext in MODEL_FORMATS:
        valid, ext = check_ext(model["source"], MODEL_FORMATS)
        if valid:
            return ext

    return None

@@ -298,6 +305,7 @@ def onnx_export(
    half=False,
    external_data=False,
    v2=False,
    op_block_list=None,
):
    """
    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py

@@ -316,8 +324,7 @@ def onnx_export(
        opset_version=opset,
    )

    op_block_list = None
    if v2:
    if v2 and op_block_list is None:
        op_block_list = ["Attention", "MultiHeadAttention"]

    if half:
@@ -1,16 +1,15 @@
from logging import getLogger
from os import path
from typing import Any, List, Optional, Tuple
from typing import Any, List, Literal, Optional, Tuple

from onnx import load_model
from optimum.onnxruntime import (  # ORTStableDiffusionXLInpaintPipeline,
    ORTStableDiffusionXLImg2ImgPipeline,
    ORTStableDiffusionXLPipeline,
)
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
from transformers import CLIPTokenizer

from ..constants import ONNX_MODEL
from ..constants import LATENT_FACTOR, ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline

@@ -24,6 +23,7 @@ from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from .version_safe_diffusers import (
    DDIMScheduler,

@@ -38,6 +38,7 @@ from .version_safe_diffusers import (
    KarrasVeScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    LCMScheduler,
    LMSDiscreteScheduler,
    OnnxRuntimeModel,
    OnnxStableDiffusionImg2ImgPipeline,

@@ -58,6 +59,7 @@ available_pipelines = {
    # "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
    "lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
    "panorama": OnnxStableDiffusionPanoramaPipeline,
    "panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline,
    "pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
    "txt2img-sdxl": ORTStableDiffusionXLPipeline,
    "txt2img": OnnxStableDiffusionPipeline,

@@ -77,12 +79,25 @@ pipeline_schedulers = {
    "k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
    "k-dpm-2": KDPM2DiscreteScheduler,
    "karras-ve": KarrasVeScheduler,
    "lcm": LCMScheduler,
    "lms-discrete": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
    "unipc-multi": UniPCMultistepScheduler,
}

def add_pipeline(name: str, pipeline: Any) -> bool:
    global available_pipelines

    if name in available_pipelines:
        # TODO: decide if this should be allowed or not
        logger.warning("cannot replace existing pipeline: %s", name)
        return False
    else:
        available_pipelines[name] = pipeline
        return True


def get_available_pipelines() -> List[str]:
    return list(available_pipelines.keys())
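
Plugins can use add_pipeline to register new pipelines at runtime; a hedged usage sketch with a made-up pipeline class:

class MyCustomPipeline:  # placeholder pipeline class
    pass

assert add_pipeline("custom", MyCustomPipeline) is True
assert add_pipeline("custom", MyCustomPipeline) is False  # already registered
print(get_available_pipelines())
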
@@ -99,16 +114,19 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
    return None


VAE_COMPONENTS = ["vae", "vae_decoder", "vae_encoder"]


def load_pipeline(
    server: ServerContext,
    params: ImageParams,
    pipeline: str,
    device: DeviceParams,
    inversions: Optional[List[Tuple[str, float]]] = None,
    embeddings: Optional[List[Tuple[str, float]]] = None,
    loras: Optional[List[Tuple[str, float]]] = None,
    model: Optional[str] = None,
):
    inversions = inversions or []
    embeddings = embeddings or []
    loras = loras or []
    model = model or params.model

@@ -122,7 +140,7 @@ def load_pipeline(
        device.device,
        device.provider,
        control_key,
        inversions,
        embeddings,
        loras,
    )
    scheduler_key = (params.scheduler, model)
@@ -159,26 +177,128 @@ def load_pipeline(
        run_gc([device])

        logger.debug("loading new diffusion pipeline from %s", model)
        components = {
            "scheduler": scheduler_type.from_pretrained(
        scheduler = scheduler_type.from_pretrained(
            model,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
            subfolder="scheduler",
            torch_dtype=torch_dtype,
        )
        components = {
            "scheduler": scheduler,
        }

        # shared components
        text_encoder = None
        unet_type = "unet"

        # ControlNet component
        if params.is_control() and params.control is not None:
            cnet_path = path.join(
                server.model_path, "control", f"{params.control.name}.onnx"
            logger.debug("loading ControlNet components")
            control_components = load_controlnet(server, device, params)
            components.update(control_components)
            unet_type = "cnet"

        # load various pipeline components
        encoder_components = load_text_encoders(
            server, device, model, embeddings, loras, torch_dtype, params
        )
        components.update(encoder_components)

        unet_components = load_unet(server, device, model, loras, unet_type, params)
        components.update(unet_components)

        vae_components = load_vae(server, device, model, params)
        components.update(vae_components)

        pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)

        if params.is_xl():
            logger.debug("assembling SDXL pipeline for %s", pipeline_class.__name__)
            pipe = pipeline_class(
                components["vae_decoder_session"],
                components["text_encoder_session"],
                components["unet_session"],
                {
                    "force_zeros_for_empty_prompt": True,
                    "requires_aesthetics_score": False,
                },
                components["tokenizer"],
                scheduler,
                vae_encoder_session=components.get("vae_encoder_session", None),
                text_encoder_2_session=components.get("text_encoder_2_session", None),
                tokenizer_2=components.get("tokenizer_2", None),
            )
        else:
            if "vae" in components:
                # upscale uses a single VAE
                logger.debug(
                    "assembling SD pipeline for %s with single VAE",
                    pipeline_class.__name__,
                )
                pipe = pipeline_class(
                    components["vae"],
                    components["text_encoder"],
                    components["tokenizer"],
                    components["unet"],
                    scheduler,
                    scheduler,
                )
            else:
                logger.debug(
                    "assembling SD pipeline for %s with VAE codec",
                    pipeline_class.__name__,
                )
                pipe = pipeline_class(
                    components["vae_encoder"],
                    components["vae_decoder"],
                    components["text_encoder"],
                    components["tokenizer"],
                    components["unet"],
                    scheduler,
                    None,
                    None,
                    requires_safety_checker=False,
                )

        if not server.show_progress:
            pipe.set_progress_bar_config(disable=True)

        optimize_pipeline(server, pipe)
        patch_pipeline(server, pipe, pipeline_class, params)

        server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
        server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)

        for vae in VAE_COMPONENTS:
            if hasattr(pipe, vae):
                vae_model = getattr(pipe, vae)
                if isinstance(vae_model, VAEWrapper):
                    vae_model.set_tiled(tiled=params.tiled_vae)
                    vae_model.set_window_size(
                        params.vae_tile // LATENT_FACTOR, params.vae_overlap
                    )

    # update panorama params
    if params.is_panorama():
        unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // LATENT_FACTOR
        logger.debug(
            "setting panorama window parameters: %s/%s for UNet, %s/%s for VAE",
            params.unet_tile,
            unet_stride,
            params.vae_tile,
            params.vae_overlap,
        )
        pipe.set_window_size(params.unet_tile // LATENT_FACTOR, unet_stride)

    run_gc([device])

    return pipe
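
Worked example of the panorama stride arithmetic above, using illustrative values (unet_tile=512, unet_overlap=0.25; LATENT_FACTOR is 8 in this codebase):

LATENT_FACTOR = 8
unet_tile, unet_overlap = 512, 0.25

unet_window = unet_tile // LATENT_FACTOR                         # 64 latent px
unet_stride = (unet_tile * (1 - unet_overlap)) // LATENT_FACTOR  # 48.0
print(unet_window, unet_stride)
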
def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
    cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
    logger.debug("loading ControlNet weights from %s", cnet_path)
    components = {}
    components["controlnet"] = OnnxRuntimeModel(
        OnnxRuntimeModel.load_model(
            cnet_path,
@@ -186,63 +306,89 @@ def load_pipeline(
            sess_options=device.sess_options(),
        )
    )
    return components

    unet_type = "cnet"

    # Textual Inversion blending
    if inversions is not None and len(inversions) > 0:
        logger.debug("blending Textual Inversions from %s", inversions)
        inversion_names, inversion_weights = zip(*inversions)

        inversion_models = [
            path.join(server.model_path, "inversion", name)
            for name in inversion_names
        ]
def load_text_encoders(
    server: ServerContext,
    device: DeviceParams,
    model: str,
    embeddings: Optional[List[Tuple[str, float]]],
    loras: Optional[List[Tuple[str, float]]],
    torch_dtype,
    params: ImageParams,
):
    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
    tokenizer = CLIPTokenizer.from_pretrained(
        model,
        subfolder="tokenizer",
        torch_dtype=torch_dtype,
    )

    components = {
        "tokenizer": tokenizer,
    }

    if params.is_xl():
        text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
        tokenizer_2 = CLIPTokenizer.from_pretrained(
            model,
            subfolder="tokenizer_2",
            torch_dtype=torch_dtype,
        )
        components["tokenizer_2"] = tokenizer_2

    # blend embeddings, if any
    if embeddings is not None and len(embeddings) > 0:
        embedding_names, embedding_weights = zip(*embeddings)
        embedding_models = [
            path.join(server.model_path, "inversion", name) for name in embedding_names
        ]
        logger.debug(
            "blending base model %s with embeddings from %s", model, embedding_models
        )

        # TODO: blend text_encoder_2 as well
        text_encoder, tokenizer = blend_textual_inversions(
            server,
            text_encoder,
            tokenizer,
            list(
                zip(
                    inversion_models,
                    inversion_weights,
                    inversion_names,
                    [None] * len(inversion_models),
                    embedding_models,
                    embedding_weights,
                    embedding_names,
                    [None] * len(embedding_models),
                )
            ),
        )

        components["tokenizer"] = tokenizer

        # should be pretty small and should not need external data
        if loras is None or len(loras) == 0:
            # TODO: handle XL encoders
            components["text_encoder"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(
                    text_encoder.SerializeToString(),
                    provider=device.ort_provider("text-encoder"),
                    sess_options=device.sess_options(),
        if params.is_xl():
            text_encoder_2, tokenizer_2 = blend_textual_inversions(
                server,
                text_encoder_2,
                tokenizer_2,
                list(
                    zip(
                        embedding_models,
                        embedding_weights,
                        embedding_names,
                        [None] * len(embedding_models),
                    )
                ),
            )
            components["tokenizer_2"] = tokenizer_2

    # LoRA blending
    # blend LoRAs, if any
    if loras is not None and len(loras) > 0:
        lora_names, lora_weights = zip(*loras)
        lora_models = [
            path.join(server.model_path, "lora", name) for name in lora_names
        ]
        logger.info(
            "blending base model %s with LoRA models: %s", model, lora_models
        )
        logger.info("blending base model %s with LoRAs from %s", model, lora_models)

        # blend and load text encoder
        text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
        text_encoder = blend_loras(
            server,
            text_encoder,
@@ -251,34 +397,8 @@ def load_pipeline(
            1 if params.is_xl() else None,
            params.is_xl(),
        )
        (text_encoder, text_encoder_data) = buffer_external_data_tensors(
            text_encoder
        )
        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
        text_encoder_opts = device.sess_options(cache=False)
        text_encoder_opts.add_external_initializers(
            list(text_encoder_names), list(text_encoder_values)
        )

        if params.is_xl():
            text_encoder_session = InferenceSession(
                text_encoder.SerializeToString(),
                providers=[device.ort_provider("text-encoder")],
                sess_options=text_encoder_opts,
            )
            text_encoder_session._model_path = path.join(model, "text_encoder")
            components["text_encoder_session"] = text_encoder_session
        else:
            components["text_encoder"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(
                    text_encoder.SerializeToString(),
                    provider=device.ort_provider("text-encoder"),
                    sess_options=text_encoder_opts,
                )
            )

        if params.is_xl():
            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
            text_encoder_2 = blend_loras(
                server,
                text_encoder_2,

@@ -287,6 +407,17 @@ def load_pipeline(
                2,
                params.is_xl(),
            )

    # prepare external data for sessions
    (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
    text_encoder_opts = device.sess_options(cache=False)
    text_encoder_opts.add_external_initializers(
        list(text_encoder_names), list(text_encoder_values)
    )

    if params.is_xl():
        # encoder 2 only exists in XL
        (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
            text_encoder_2
        )

@@ -296,6 +427,16 @@ def load_pipeline(
            list(text_encoder_2_names), list(text_encoder_2_values)
        )

        # session for te1
        text_encoder_session = InferenceSession(
            text_encoder.SerializeToString(),
            providers=[device.ort_provider("text-encoder")],
            sess_options=text_encoder_opts,
        )
        text_encoder_session._model_path = path.join(model, "text_encoder")
        components["text_encoder_session"] = text_encoder_session

        # session for te2
        text_encoder_2_session = InferenceSession(
            text_encoder_2.SerializeToString(),
            providers=[device.ort_provider("text-encoder")],

@@ -303,17 +444,48 @@ def load_pipeline(
        )
        text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
        components["text_encoder_2_session"] = text_encoder_2_session
    else:
        # session for te
        components["text_encoder"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                text_encoder.SerializeToString(),
                provider=device.ort_provider("text-encoder"),
                sess_options=text_encoder_opts,
            )
        )

    return components


def load_unet(
    server: ServerContext,
    device: DeviceParams,
    model: str,
    loras: List[Tuple[str, float]],
    unet_type: Literal["cnet", "unet"],
    params: ImageParams,
):
    components = {}
    unet = load_model(path.join(model, unet_type, ONNX_MODEL))

    # LoRA blending
    if loras is not None and len(loras) > 0:
        lora_names, lora_weights = zip(*loras)
        lora_models = [
            path.join(server.model_path, "lora", name) for name in lora_names
        ]
        logger.info("blending base model %s with LoRA models: %s", model, lora_models)

        # blend and load unet
        unet = path.join(model, unet_type, ONNX_MODEL)
        blended_unet = blend_loras(
        unet = blend_loras(
            server,
            unet,
            list(zip(lora_models, lora_weights)),
            "unet",
            xl=params.is_xl(),
        )
        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)

    (unet_model, unet_data) = buffer_external_data_tensors(unet)
    unet_names, unet_values = zip(*unet_data)
    unet_opts = device.sess_options(cache=False)
    unet_opts.add_external_initializers(list(unet_names), list(unet_values))

@@ -335,23 +507,18 @@ def load_pipeline(
        )
    )

    # make sure a UNet has been loaded
    if not params.is_xl() and "unet" not in components:
        unet = path.join(model, unet_type, ONNX_MODEL)
        logger.debug("loading UNet (%s) from %s", unet_type, unet)
        components["unet"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                unet,
                provider=device.ort_provider("unet"),
                sess_options=device.sess_options(),
            )
        )
    return components

def load_vae(
    _server: ServerContext, device: DeviceParams, model: str, params: ImageParams
):
    # one or more VAE models need to be loaded
    vae = path.join(model, "vae", ONNX_MODEL)
    vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
    vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)

    components = {}
    if not params.is_xl() and path.exists(vae):
        logger.debug("loading VAE from %s", vae)
        components["vae"] = OnnxRuntimeModel(

@@ -361,9 +528,25 @@ def load_pipeline(
            sess_options=device.sess_options(),
        )
    )
    elif (
        not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
    ):
    elif path.exists(vae_decoder) and path.exists(vae_encoder):
        if params.is_xl():
            logger.debug("loading VAE decoder from %s", vae_decoder)
            components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
                vae_decoder,
                provider=device.ort_provider("vae"),
                sess_options=device.sess_options(),
            )
            components["vae_decoder_session"]._model_path = vae_decoder

            logger.debug("loading VAE encoder from %s", vae_encoder)
            components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
                vae_encoder,
                provider=device.ort_provider("vae"),
                sess_options=device.sess_options(),
            )
            components["vae_encoder_session"]._model_path = vae_encoder

        else:
            logger.debug("loading VAE decoder from %s", vae_decoder)
            components["vae_decoder"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(

@@ -382,119 +565,44 @@ def load_pipeline(
            )
        )

    # additional options for panorama pipeline
    if params.is_panorama():
        components["window"] = params.tiles // 8
        components["stride"] = params.stride // 8

    pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
    logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
    pipe = pipeline_class.from_pretrained(
        model,
        provider=device.ort_provider(),
        sess_options=device.sess_options(),
        safety_checker=None,
        torch_dtype=torch_dtype,
        **components,
    )

    # make sure XL models are actually being used
    # TODO: why is this needed?
    if "text_encoder_session" in components:
        logger.info(
            "text encoder matches: %s, %s",
            pipe.text_encoder.session == components["text_encoder_session"],
            type(pipe.text_encoder),
        )
        pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)

    if "text_encoder_2_session" in components:
        logger.info(
            "text encoder 2 matches: %s, %s",
            pipe.text_encoder_2.session == components["text_encoder_2_session"],
            type(pipe.text_encoder_2),
        )
        pipe.text_encoder_2 = ORTModelTextEncoder(
            text_encoder_2_session, text_encoder_2
        )

    if "unet_session" in components:
        logger.info(
            "unet matches: %s, %s",
            pipe.unet.session == components["unet_session"],
            type(pipe.unet),
        )
        pipe.unet = ORTModelUnet(unet_session, unet_model)

    if not server.show_progress:
        pipe.set_progress_bar_config(disable=True)

    optimize_pipeline(server, pipe)

    if not params.is_xl():
        patch_pipeline(server, pipe, pipeline, pipeline_class, params)

    server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
    server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])

    if not params.is_xl() and hasattr(pipe, "vae_decoder"):
        pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)

    if not params.is_xl() and hasattr(pipe, "vae_encoder"):
        pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)

    # update panorama params
    if params.is_panorama():
        latent_window = params.tiles // 8
        latent_stride = params.stride // 8

        pipe.set_window_size(latent_window, latent_stride)
        if hasattr(pipe, "vae_decoder"):
            pipe.vae_decoder.set_window_size(latent_window, params.overlap)
        if hasattr(pipe, "vae_encoder"):
            pipe.vae_encoder.set_window_size(latent_window, params.overlap)

    run_gc([device])

    return pipe
    return components

def optimize_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
) -> None:
    if (
        "diffusers-attention-slicing" in server.optimizations
        or "diffusers-attention-slicing-auto" in server.optimizations
    ):
    if server.has_optimization(
        "diffusers-attention-slicing"
    ) or server.has_optimization("diffusers-attention-slicing-auto"):
        logger.debug("enabling auto attention slicing on SD pipeline")
        try:
            pipe.enable_attention_slicing(slice_size="auto")
        except Exception as e:
            logger.warning("error while enabling auto attention slicing: %s", e)

    if "diffusers-attention-slicing-max" in server.optimizations:
    if server.has_optimization("diffusers-attention-slicing-max"):
        logger.debug("enabling max attention slicing on SD pipeline")
        try:
            pipe.enable_attention_slicing(slice_size="max")
        except Exception as e:
            logger.warning("error while enabling max attention slicing: %s", e)

    if "diffusers-vae-slicing" in server.optimizations:
    if server.has_optimization("diffusers-vae-slicing"):
        logger.debug("enabling VAE slicing on SD pipeline")
        try:
            pipe.enable_vae_slicing()
        except Exception as e:
            logger.warning("error while enabling VAE slicing: %s", e)

    if "diffusers-cpu-offload-sequential" in server.optimizations:
    if server.has_optimization("diffusers-cpu-offload-sequential"):
        logger.debug("enabling sequential CPU offload on SD pipeline")
        try:
            pipe.enable_sequential_cpu_offload()
        except Exception as e:
            logger.warning("error while enabling sequential CPU offload: %s", e)

    elif "diffusers-cpu-offload-model" in server.optimizations:
    elif server.has_optimization("diffusers-cpu-offload-model"):
        # TODO: check for accelerate
        logger.debug("enabling model CPU offload on SD pipeline")
        try:

@@ -502,7 +610,7 @@ def optimize_pipeline(
        except Exception as e:
            logger.warning("error while enabling model CPU offload: %s", e)

    if "diffusers-memory-efficient-attention" in server.optimizations:
    if server.has_optimization("diffusers-memory-efficient-attention"):
        # TODO: check for xformers
        logger.debug("enabling memory efficient attention for SD pipeline")
        try:

@@ -514,17 +622,17 @@ def optimize_pipeline(
def patch_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
    pipe_type: str,
    pipeline: Any,
    params: ImageParams,
) -> None:
    logger.debug("patching SD pipeline")

    if pipe_type != "lpw":
    if not params.is_lpw() and not params.is_xl():
        pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)

    original_unet = pipe.unet
    pipe.unet = UNetWrapper(server, original_unet)
    pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
    logger.debug("patched UNet with wrapper")

    if hasattr(pipe, "vae_decoder"):
        original_decoder = pipe.vae_decoder

@@ -532,18 +640,21 @@ def patch_pipeline(
            server,
            original_decoder,
            decoder=True,
            window=params.tiles,
            overlap=params.overlap,
            window=params.unet_tile,
            overlap=params.vae_overlap,
        )
        logger.debug("patched VAE decoder with wrapper")

    if hasattr(pipe, "vae_encoder"):
        original_encoder = pipe.vae_encoder
        pipe.vae_encoder = VAEWrapper(
            server,
            original_encoder,
            decoder=False,
            window=params.tiles,
            overlap=params.overlap,
            window=params.unet_tile,
            overlap=params.vae_overlap,
        )
    elif hasattr(pipe, "vae"):
        pass  # TODO: current wrapper does not work with upscaling VAE
    else:
        logger.debug("no VAE found to patch")
        logger.debug("patched VAE encoder with wrapper")

    if hasattr(pipe, "vae"):
        logger.warning("not patching single VAE, tiled VAE may not work")
@@ -14,20 +14,23 @@ class UNetWrapper(object):
    prompt_index: int = 0
    server: ServerContext
    wrapped: OnnxRuntimeModel
    xl: bool

    def __init__(
        self,
        server: ServerContext,
        wrapped: OnnxRuntimeModel,
        xl: bool,
    ):
        self.server = server
        self.wrapped = wrapped
        self.xl = xl

    def __call__(
        self,
        sample: np.ndarray = None,
        timestep: np.ndarray = None,
        encoder_hidden_states: np.ndarray = None,
        sample: Optional[np.ndarray] = None,
        timestep: Optional[np.ndarray] = None,
        encoder_hidden_states: Optional[np.ndarray] = None,
        **kwargs,
    ):
        logger.trace(

@@ -43,6 +46,14 @@ class UNetWrapper(object):
            encoder_hidden_states = self.prompt_embeds[step_index]
            self.prompt_index += 1

        if self.xl:
            if sample.dtype != encoder_hidden_states.dtype:
                logger.trace(
                    "converting UNet sample to hidden state dtype for XL: %s",
                    encoder_hidden_states.dtype,
                )
                sample = sample.astype(encoder_hidden_states.dtype)
        else:
            if sample.dtype != timestep.dtype:
                logger.trace("converting UNet sample to timestep dtype")
                sample = sample.astype(timestep.dtype)
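
A minimal illustration of the dtype alignment the wrapper performs for SDXL, assuming fp16 hidden states from the text encoder and an fp32 latent sample (shapes here are illustrative):

import numpy as np

sample = np.zeros((2, 4, 64, 64), dtype=np.float32)
encoder_hidden_states = np.zeros((2, 77, 2048), dtype=np.float16)

if sample.dtype != encoder_hidden_states.dtype:
    sample = sample.astype(encoder_hidden_states.dtype)
print(sample.dtype)  # float16
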
@@ -12,8 +12,6 @@ from ...server import ServerContext

logger = getLogger(__name__)

LATENT_CHANNELS = 4


class VAEWrapper(object):
    def __init__(

@@ -39,11 +37,17 @@ class VAEWrapper(object):
        self.tile_overlap_factor = overlap

    def __call__(self, latent_sample=None, sample=None, **kwargs):
        model = (
            self.wrapped.model
            if hasattr(self.wrapped, "model")
            else self.wrapped.session
        )

        # set timestep dtype to input type
        sample_dtype = next(
            (
                input.type
                for input in self.wrapped.model.get_inputs()
                for input in model.get_inputs()
                if input.name == "sample" or input.name == "latent_sample"
            ),
            "tensor(float)",

@@ -13,8 +13,8 @@ import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
@@ -13,25 +13,36 @@
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union
from math import ceil
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from diffusers.configuration_utils import FrozenDict
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer

from ...chain.tile import make_tile_mask
from ...constants import LATENT_CHANNELS, LATENT_FACTOR
from ...params import Size
from ..utils import (
    expand_latents,
    parse_regions,
    random_seed,
    repair_nan,
    resize_latent_shape,
)

logger = logging.get_logger(__name__)


# inpaint constants
NUM_UNET_INPUT_CHANNELS = 9
NUM_LATENT_CHANNELS = 4

DEFAULT_WINDOW = 32
DEFAULT_STRIDE = 8

@@ -346,13 +357,16 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def get_views(self, panorama_height, panorama_width, window_size, stride):
|
||||
def get_views(
|
||||
self, panorama_height: int, panorama_width: int, window_size: int, stride: int
|
||||
) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
|
||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||
panorama_height /= 8
|
||||
panorama_width /= 8
|
||||
|
||||
num_blocks_height = abs((panorama_height - window_size) // stride) + 1
|
||||
num_blocks_width = abs((panorama_width - window_size) // stride) + 1
|
||||
num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
|
||||
num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1
|
||||
|
||||
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
||||
logger.debug(
|
||||
"panorama generated %s views, %s by %s blocks",
|
||||
|
@ -369,7 +383,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
w_end = w_start + window_size
|
||||
views.append((h_start, h_end, w_start, w_end))
|
||||
|
||||
return views
|
||||
return (views, (h_end * 8, w_end * 8))
|
||||
|
||||
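
Worked example of the ceil-based block count above, for an illustrative 1024x512 panorama with the default window 32 and stride 8 (in latent units after the division by 8):

from math import ceil

height, width = 512 / 8, 1024 / 8  # 64.0, 128.0 latent px
window_size, stride = 32, 8

num_blocks_height = ceil(abs((height - window_size) / stride)) + 1  # 5
num_blocks_width = ceil(abs((width - window_size) / stride)) + 1    # 13
print(num_blocks_height * num_blocks_width)  # 65 views
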
    @torch.no_grad()
    def text2img(

@@ -479,6 +493,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt, regions = parse_regions(prompt)

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,

@@ -488,9 +504,30 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 3.b. Encode region prompts
        region_embeds: List[np.ndarray] = []

        for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
            if region_prompt.endswith("+"):
                region_prompt = region_prompt[:-1] + " " + prompt

            region_prompt_embeds = self._encode_prompt(
                region_prompt,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
            )

            region_embeds.append(region_prompt_embeds)

        # get the initial random noise unless the user supplied it
        latents_dtype = prompt_embeds.dtype
        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        latents_shape = (
            batch_size * num_images_per_prompt,
            LATENT_CHANNELS,
            height // LATENT_FACTOR,
            width // LATENT_FACTOR,
        )
        if latents is None:
            latents = generator.randn(*latents_shape).astype(latents_dtype)
        elif latents.shape != latents_shape:

@@ -525,11 +562,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        # panorama additions
        views = self.get_views(height, width, self.window, self.stride)
        count = np.zeros_like(latents)
        value = np.zeros_like(latents)
        views, resize = self.get_views(height, width, self.window, self.stride)
        logger.trace("panorama resized latents to %s", resize)

        count = np.zeros(resize_latent_shape(latents, resize))
        value = np.zeros(resize_latent_shape(latents, resize))

        # adjust latents
        latents = expand_latents(
            latents,
            random_seed(generator),
            Size(resize[1], resize[0]),
            sigma=self.scheduler.init_noise_sigma,
        )

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            last = i == (len(self.scheduler.timesteps) - 1)
            count.fill(0)
            value.fill(0)

@@ -576,13 +624,115 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||
|
||||
if not last:
|
||||
for r, region in enumerate(regions):
|
||||
top, left, bottom, right, weight, feather, prompt = region
|
||||
logger.debug(
|
||||
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||
top,
|
||||
left,
|
||||
bottom,
|
||||
right,
|
||||
weight,
|
||||
feather,
|
||||
prompt,
|
||||
)
|
||||
|
||||
# convert coordinates to latent space
|
||||
h_start = top // LATENT_FACTOR
|
||||
h_end = bottom // LATENT_FACTOR
|
||||
w_start = left // LATENT_FACTOR
|
||||
w_end = right // LATENT_FACTOR
|
||||
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
logger.trace(
|
||||
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
||||
h_start,
|
||||
h_end,
|
||||
w_start,
|
||||
w_end,
|
||||
latents_for_region.shape,
|
||||
)
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_region_input = (
|
||||
np.concatenate([latents_for_region] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents_for_region
|
||||
)
|
||||
latent_region_input = self.scheduler.scale_model_input(
|
||||
torch.from_numpy(latent_region_input), t
|
||||
)
|
||||
latent_region_input = latent_region_input.cpu().numpy()
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
region_noise_pred = self.unet(
|
||||
sample=latent_region_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=region_embeds[r],
|
||||
)
|
||||
region_noise_pred = region_noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||
region_noise_pred, 2
|
||||
)
|
||||
region_noise_pred = (
|
||||
region_noise_pred_uncond
|
||||
+ guidance_scale
|
||||
* (region_noise_pred_text - region_noise_pred_uncond)
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(region_noise_pred),
|
||||
t,
|
||||
torch.from_numpy(latents_for_region),
|
||||
**extra_step_kwargs,
|
||||
)
|
||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if feather[0] > 0.0:
|
||||
mask = make_tile_mask(
|
||||
(h_end - h_start, w_end - w_start),
|
||||
(h_end - h_start, w_end - w_start),
|
||||
feather[0],
|
||||
feather[1],
|
||||
)
|
||||
mask = np.expand_dims(mask, axis=0)
|
||||
mask = np.repeat(mask, 4, axis=0)
|
||||
mask = np.expand_dims(mask, axis=0)
|
||||
else:
|
||||
mask = 1
|
||||
|
||||
if weight >= 100.0:
|
||||
value[:, :, h_start:h_end, w_start:w_end] = (
|
||||
latents_region_denoised * mask
|
||||
)
|
||||
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||
else:
|
||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||
latents_region_denoised * weight * mask
|
||||
)
|
||||
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||
|
||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||
latents = np.where(count > 0, value / count, value)
|
||||
latents = repair_nan(latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# remove extra margins
|
||||
latents = latents[
|
||||
:, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
|
||||
]
|
||||
|
||||
latents = np.clip(latents, -4, +4)
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
||||
|
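
Worked example of the MultiDiffusion average taken above: overlapping window results are summed into value with their weights tracked in count, then normalized. Plain numpy, with shapes reduced for clarity:

import numpy as np

value = np.zeros((1, 4, 8, 8))
count = np.zeros((1, 4, 8, 8))

value[:, :, 0:6, :] += 1.0  # first window predicts 1.0
count[:, :, 0:6, :] += 1
value[:, :, 4:8, :] += 3.0  # overlapping window predicts 3.0
count[:, :, 4:8, :] += 1

latents = np.where(count > 0, value / count, value)
print(latents[0, 0, 5, 0])  # 2.0 in the overlap, the average of 1 and 3
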
@ -828,9 +978,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# panorama additions
|
||||
views = self.get_views(height, width, self.window, self.stride)
|
||||
count = np.zeros_like(latents)
|
||||
value = np.zeros_like(latents)
|
||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||
logger.trace("panorama resized latents to %s", resize)
|
||||
|
||||
count = np.zeros(resize_latent_shape(latents, resize))
|
||||
value = np.zeros(resize_latent_shape(latents, resize))
|
||||
|
||||
# adjust latents
|
||||
latents = expand_latents(
|
||||
latents,
|
||||
random_seed(generator),
|
||||
Size(resize[1], resize[0]),
|
||||
sigma=self.scheduler.init_noise_sigma,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
count.fill(0)
|
||||
@@ -886,6 +1046,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # remove extra margins
        latents = latents[
            :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
        ]

        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1
@@ -1053,12 +1218,12 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            negative_prompt_embeds=negative_prompt_embeds,
        )

        num_channels_latents = NUM_LATENT_CHANNELS
        num_channels_latents = LATENT_CHANNELS
        latents_shape = (
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height // 8,
            width // 8,
            height // LATENT_FACTOR,
            width // LATENT_FACTOR,
        )
        latents_dtype = prompt_embeds.dtype
        if latents is None:
@@ -1136,9 +1301,19 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        # panorama additions
        views = self.get_views(height, width, self.window, self.stride)
        count = np.zeros_like(latents)
        value = np.zeros_like(latents)
        views, resize = self.get_views(height, width, self.window, self.stride)
        logger.trace("panorama resized latents to %s", resize)

        count = np.zeros(resize_latent_shape(latents, resize))
        value = np.zeros(resize_latent_shape(latents, resize))

        # adjust latents
        latents = expand_latents(
            latents,
            random_seed(generator),
            Size(resize[1], resize[0]),
            sigma=self.scheduler.init_noise_sigma,
        )

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            count.fill(0)
@@ -1201,6 +1376,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # remove extra margins
        latents = latents[
            :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
        ]

        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result when using the half-precision VAE decoder if batch size > 1
@@ -0,0 +1,955 @@
import inspect
import logging
from math import ceil
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from optimum.onnxruntime.modeling_diffusion import ORTStableDiffusionXLPipelineBase
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
    StableDiffusionXLImg2ImgPipelineMixin,
)
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg

from ...chain.tile import make_tile_mask
from ...constants import LATENT_FACTOR
from ...params import Size
from ..utils import (
    expand_latents,
    parse_regions,
    random_seed,
    repair_nan,
    resize_latent_shape,
)

logger = logging.getLogger(__name__)


DEFAULT_WINDOW = 64
DEFAULT_STRIDE = 16


class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMixin):
    def __init__(
        self,
        *args,
        window: int = DEFAULT_WINDOW,
        stride: int = DEFAULT_STRIDE,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.window = window
        self.stride = stride

    def set_window_size(self, window: int, stride: int):
        self.window = window
        self.stride = stride

    def get_views(
        self, panorama_height: int, panorama_width: int, window_size: int, stride: int
    ) -> Tuple[List[Tuple[int, int, int, int]], Tuple[int, int]]:
        # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
        panorama_height /= 8
        panorama_width /= 8

        num_blocks_height = ceil(abs((panorama_height - window_size) / stride)) + 1
        num_blocks_width = ceil(abs((panorama_width - window_size) / stride)) + 1

        total_num_blocks = int(num_blocks_height * num_blocks_width)
        logger.debug(
            "panorama generated %s views, %s by %s blocks",
            total_num_blocks,
            num_blocks_height,
            num_blocks_width,
        )

        views = []
        for i in range(total_num_blocks):
            h_start = int((i // num_blocks_width) * stride)
            h_end = h_start + window_size
            w_start = int((i % num_blocks_width) * stride)
            w_end = w_start + window_size
            views.append((h_start, h_end, w_start, w_end))

        return (views, (h_end * 8, w_end * 8))
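
    # Example (illustrative): for a 1024x512 image with the default window 64 and
    # stride 16, the latent panorama is 128x64: ceil((64 - 64) / 16) + 1 = 1 block
    # vertically and ceil((128 - 64) / 16) + 1 = 5 blocks horizontally, so five
    # overlapping 64x64 views. The final h_end/w_end can extend past the requested
    # size, which is why the resized pixel extent is returned alongside the views.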

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents_img2img(
        self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None
    ):
        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image
        else:
            init_latents = self.vae_encoder(sample=image)[
                0
            ] * self.vae_decoder.config.get("scaling_factor", 0.18215)

        if (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] == 0
        ):
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = np.concatenate(
                [init_latents] * additional_image_per_prompt, axis=0
            )
        elif (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] != 0
        ):
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = np.concatenate([init_latents], axis=0)

        # add noise to latents using the timesteps
        noise = generator.randn(*init_latents.shape).astype(dtype)
        init_latents = self.scheduler.add_noise(
            torch.from_numpy(init_latents),
            torch.from_numpy(noise),
            torch.from_numpy(timestep),
        )
        return init_latents.numpy()

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents_text2img(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = generator.randn(*shape).astype(dtype)
        elif latents.shape != shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {shape}"
            )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * np.float64(self.scheduler.init_noise_sigma)

        return latents

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        extra_step_kwargs = {}

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        return extra_step_kwargs

    # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
    def text2img(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        pooled_prompt_embeds: Optional[np.ndarray] = None,
        negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`Optional[Union[str, List[str]]]`, defaults to None):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`Optional[int]`, defaults to None):
                The height in pixels of the generated image.
            width (`Optional[int]`, defaults to None):
                The width in pixels of the generated image.
            num_inference_steps (`int`, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`Optional[Union[str, list]]`):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`Optional[np.random.RandomState]`, defaults to `None`):
                A np.random.RandomState to make generation deterministic.
            latents (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
            output_type (`str`, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
                plain tuple.
            callback (`Optional[Callable]`, defaults to `None`):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            guidance_rescale (`float`, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
                the same paper. Guidance rescale should fix overexposure when using zero terminal SNR.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            plain `tuple` whose first element is a list with the generated images.
        """

        # 0. Default height and width to unet
        height = height or self.unet.config["sample_size"] * self.vae_scale_factor
        width = width or self.unet.config["sample_size"] * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            1.0,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt, regions = parse_regions(prompt)

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )

        # 3.b. Encode region prompts
        region_embeds: List[np.ndarray] = []
        add_region_embeds: List[np.ndarray] = []

        for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
            if region_prompt.endswith("+"):
                region_prompt = region_prompt[:-1] + " " + prompt

            (
                region_prompt_embeds,
                region_negative_prompt_embeds,
                region_pooled_prompt_embeds,
                region_negative_pooled_prompt_embeds,
            ) = self._encode_prompt(
                region_prompt,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
            )

            if do_classifier_free_guidance:
                region_prompt_embeds = np.concatenate(
                    (region_negative_prompt_embeds, region_prompt_embeds), axis=0
                )
                add_region_embeds.append(
                    np.concatenate(
                        (
                            region_negative_pooled_prompt_embeds,
                            region_pooled_prompt_embeds,
                        ),
                        axis=0,
                    )
                )

            region_embeds.append(region_prompt_embeds)
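
        # note: a trailing "+" on a region prompt appends the base prompt to the
        # region prompt before encoding, and each region is a
        # (top, left, bottom, right, weight, feather, prompt) tuple from parse_regions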

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        latents = self.prepare_latents_text2img(
            batch_size * num_images_per_prompt,
            self.unet.config.get("in_channels", 4),
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = (original_size + crops_coords_top_left + target_size,)
        add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype)

        if do_classifier_free_guidance:
            prompt_embeds = np.concatenate(
                (negative_prompt_embeds, prompt_embeds), axis=0
            )
            add_text_embeds = np.concatenate(
                (negative_pooled_prompt_embeds, add_text_embeds), axis=0
            )
            add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)

        add_time_ids = np.repeat(
            add_time_ids, batch_size * num_images_per_prompt, axis=0
        )

        # Adapted from diffusers to extend it for other runtimes than ORT
        timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)

        # 8. Panorama additions
        views, resize = self.get_views(height, width, self.window, self.stride)
        logger.trace("panorama resized latents to %s", resize)

        count = np.zeros(resize_latent_shape(latents, resize))
        value = np.zeros(resize_latent_shape(latents, resize))

        # adjust latents
        latents = expand_latents(
            latents,
            random_seed(generator),
            Size(resize[1], resize[0]),
            sigma=self.scheduler.init_noise_sigma,
        )
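
        # note: the working latents are padded out to the resized extent so the last
        # row and column of views are fully covered; the added margin is fresh noise
        # (seeded from the generator) and is cropped off again after the loop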

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        for i, t in enumerate(self.progress_bar(timesteps)):
            last = i == (len(timesteps) - 1)
            count.fill(0)
            value.fill(0)

            for h_start, h_end, w_start, w_end in views:
                # get the latents corresponding to the current view coordinates
                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    np.concatenate([latents_for_view] * 2)
                    if do_classifier_free_guidance
                    else latents_for_view
                )
                latent_model_input = self.scheduler.scale_model_input(
                    torch.from_numpy(latent_model_input), t
                )
                latent_model_input = latent_model_input.cpu().numpy()

                # predict the noise residual
                timestep = np.array([t], dtype=timestep_dtype)
                noise_pred = self.unet(
                    sample=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    text_embeds=add_text_embeds,
                    time_ids=add_time_ids,
                )
                noise_pred = noise_pred[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                    if guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=guidance_rescale,
                        )

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(
                    torch.from_numpy(noise_pred),
                    t,
                    torch.from_numpy(latents_for_view),
                    **extra_step_kwargs,
                )
                latents_view_denoised = scheduler_output.prev_sample.numpy()

                value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                count[:, :, h_start:h_end, w_start:w_end] += 1

            if not last:
                for r, region in enumerate(regions):
                    top, left, bottom, right, weight, feather, prompt = region
                    logger.debug(
                        "running region prompt: %s, %s, %s, %s, %s, %s, %s",
                        top,
                        left,
                        bottom,
                        right,
                        weight,
                        feather,
                        prompt,
                    )

                    # convert coordinates to latent space
                    h_start = top // LATENT_FACTOR
                    h_end = bottom // LATENT_FACTOR
                    w_start = left // LATENT_FACTOR
                    w_end = right // LATENT_FACTOR

                    # get the latents corresponding to the current view coordinates
                    latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
                    logger.trace(
                        "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
                        h_start,
                        h_end,
                        w_start,
                        w_end,
                        latents_for_region.shape,
                    )

                    # expand the latents if we are doing classifier free guidance
                    latent_region_input = (
                        np.concatenate([latents_for_region] * 2)
                        if do_classifier_free_guidance
                        else latents_for_region
                    )
                    latent_region_input = self.scheduler.scale_model_input(
                        torch.from_numpy(latent_region_input), t
                    )
                    latent_region_input = latent_region_input.cpu().numpy()

                    # predict the noise residual
                    timestep = np.array([t], dtype=timestep_dtype)
                    region_noise_pred = self.unet(
                        sample=latent_region_input,
                        timestep=timestep,
                        encoder_hidden_states=region_embeds[r],
                        text_embeds=add_region_embeds[r],
                        time_ids=add_time_ids,
                    )
                    region_noise_pred = region_noise_pred[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        region_noise_pred_uncond, region_noise_pred_text = np.split(
                            region_noise_pred, 2
                        )
                        region_noise_pred = (
                            region_noise_pred_uncond
                            + guidance_scale
                            * (region_noise_pred_text - region_noise_pred_uncond)
                        )
                        if guidance_rescale > 0.0:
                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                            region_noise_pred = rescale_noise_cfg(
                                region_noise_pred,
                                region_noise_pred_text,
                                guidance_rescale=guidance_rescale,
                            )

                    # compute the previous noisy sample x_t -> x_t-1
                    scheduler_output = self.scheduler.step(
                        torch.from_numpy(region_noise_pred),
                        t,
                        torch.from_numpy(latents_for_region),
                        **extra_step_kwargs,
                    )
                    latents_region_denoised = scheduler_output.prev_sample.numpy()

                    if feather[0] > 0.0:
                        mask = make_tile_mask(
                            (h_end - h_start, w_end - w_start),
                            (h_end - h_start, w_end - w_start),
                            feather[0],
                            feather[1],
                        )
                        mask = np.expand_dims(mask, axis=0)
                        mask = np.repeat(mask, 4, axis=0)
                        mask = np.expand_dims(mask, axis=0)
                    else:
                        mask = 1
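
                    # note: make_tile_mask returns a 2D (h, w) feather ramp; the
                    # expand/repeat calls above reshape it to (1, 4, h, w) so it
                    # broadcasts over the batch and the four latent channels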

                    if weight >= 100.0:
                        value[:, :, h_start:h_end, w_start:w_end] = (
                            latents_region_denoised * mask
                        )
                        count[:, :, h_start:h_end, w_start:w_end] = mask
                    else:
                        value[:, :, h_start:h_end, w_start:w_end] += (
                            latents_region_denoised * weight * mask
                        )
                        count[:, :, h_start:h_end, w_start:w_end] += weight * mask
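
                    # note: a region weight of 100 or more overwrites the accumulated
                    # panorama views for that region, while lower weights blend into
                    # the value/count accumulators alongside the other views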

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = np.where(count > 0, value / count, value)
            latents = repair_nan(latents)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

        # remove extra margins
        latents = latents[
            :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
        ]

        if output_type == "latent":
            image = latents
        else:
            latents = np.clip(latents, -4, +4)
            latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
            # it seems like there is a strange result when using the half-precision
            # VAE decoder if batch size > 1, so decode one latent at a time
            image = np.concatenate(
                [
                    self.vae_decoder(latent_sample=latents[i : i + 1])[0]
                    for i in range(latents.shape[0])
                ]
            )
            image = self.watermark.apply_watermark(image)

            # TODO: add image_processor
            image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))

            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__
    def img2img(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        image: Union[np.ndarray, PIL.Image.Image] = None,
        strength: float = 0.3,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        pooled_prompt_embeds: Optional[np.ndarray] = None,
        negative_pooled_prompt_embeds: Optional[np.ndarray] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`Optional[Union[str, List[str]]]`, defaults to None):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`Union[np.ndarray, PIL.Image.Image]`):
                `Image`, or tensor representing an image batch, to be used as the starting point for generation.
            strength (`float`, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`Optional[Union[str, list]]`):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`Optional[np.random.RandomState]`, defaults to `None`):
                A np.random.RandomState to make generation deterministic.
            latents (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
            output_type (`str`, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
                plain tuple.
            callback (`Optional[Callable]`, defaults to `None`):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            guidance_rescale (`float`, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
                the same paper. Guidance rescale should fix overexposure when using zero terminal SNR.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            plain `tuple` whose first element is a list with the generated images.
        """
        # 0. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 1. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 2. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )

        # 3. Preprocess image
        processor = VaeImageProcessor()
        image = processor.preprocess(image).cpu().numpy()

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps, strength
        )
        latent_timestep = np.repeat(
            timesteps[:1], batch_size * num_images_per_prompt, axis=0
        )
        timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)

        latents_dtype = prompt_embeds.dtype
        image = image.astype(latents_dtype)

        # 5. Prepare latent variables
        latents = self.prepare_latents_img2img(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            latents_dtype,
            generator,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = {}
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        height, width = latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            dtype=prompt_embeds.dtype,
        )

        if do_classifier_free_guidance:
            prompt_embeds = np.concatenate(
                (negative_prompt_embeds, prompt_embeds), axis=0
            )
            add_text_embeds = np.concatenate(
                (negative_pooled_prompt_embeds, add_text_embeds), axis=0
            )
            add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
        add_time_ids = np.repeat(
            add_time_ids, batch_size * num_images_per_prompt, axis=0
        )

        # 8. Panorama additions
        views, resize = self.get_views(height, width, self.window, self.stride)
        logger.trace("panorama resized latents to %s", resize)

        count = np.zeros(resize_latent_shape(latents, resize))
        value = np.zeros(resize_latent_shape(latents, resize))

        latents = expand_latents(
            latents,
            random_seed(generator),
            Size(resize[1], resize[0]),
            sigma=self.scheduler.init_noise_sigma,
        )

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        for i, t in enumerate(self.progress_bar(timesteps)):
            count.fill(0)
            value.fill(0)

            for h_start, h_end, w_start, w_end in views:
                # get the latents corresponding to the current view coordinates
                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    np.concatenate([latents_for_view] * 2)
                    if do_classifier_free_guidance
                    else latents_for_view
                )
                latent_model_input = self.scheduler.scale_model_input(
                    torch.from_numpy(latent_model_input), t
                )
                latent_model_input = latent_model_input.cpu().numpy()

                # predict the noise residual
                timestep = np.array([t], dtype=timestep_dtype)
                noise_pred = self.unet(
                    sample=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    text_embeds=add_text_embeds,
                    time_ids=add_time_ids,
                )
                noise_pred = noise_pred[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                    if guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=guidance_rescale,
                        )

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(
                    torch.from_numpy(noise_pred),
                    t,
                    torch.from_numpy(latents_for_view),
                    **extra_step_kwargs,
                )
                latents_view_denoised = scheduler_output.prev_sample.numpy()

                value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = np.where(count > 0, value / count, value)

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

        # remove extra margins
        latents = latents[
            :, :, 0 : (height // LATENT_FACTOR), 0 : (width // LATENT_FACTOR)
        ]

        if output_type == "latent":
            image = latents
        else:
            latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
            # it seems like there is a strange result when using the half-precision
            # VAE decoder if batch size > 1, so decode one latent at a time
            image = np.concatenate(
                [
                    self.vae_decoder(latent_sample=latents[i : i + 1])[0]
                    for i in range(latents.shape[0])
                ]
            )
            image = self.watermark.apply_watermark(image)

            # TODO: add image_processor
            image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))

            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    def __call__(
        self,
        *args,
        **kwargs,
    ):
        if "image" in kwargs or (
            len(args) > 1
            and (
                isinstance(args[1], np.ndarray) or isinstance(args[1], PIL.Image.Image)
            )
        ):
            logger.debug("running img2img panorama XL pipeline")
            return self.img2img(*args, **kwargs)
        else:
            logger.debug("running txt2img panorama XL pipeline")
            return self.text2img(*args, **kwargs)


class ORTStableDiffusionXLPanoramaPipeline(
    ORTStableDiffusionXLPipelineBase, StableDiffusionXLPanoramaPipelineMixin
):
    def __call__(self, *args, **kwargs):
        return StableDiffusionXLPanoramaPipelineMixin.__call__(self, *args, **kwargs)
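
# Usage sketch (illustrative, not part of this diff): the mixin composes with
# optimum's ORT base class, so loading should follow the usual optimum
# from_pretrained flow; the model path below is a placeholder.
#
#   pipe = ORTStableDiffusionXLPanoramaPipeline.from_pretrained("./models/sdxl-panorama-onnx")
#   pipe.set_window_size(window=64, stride=16)
#   image = pipe("a sweeping mountain panorama at sunset", width=2048, height=512).images[0]
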
@@ -32,7 +32,7 @@ except ImportError:
}

from diffusers import OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
@@ -1,63 +1,25 @@
###
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
# See also: https://github.com/huggingface/diffusers/pull/2158
###

from logging import getLogger
from typing import Any, Callable, List, Optional, Union
from typing import Any, List

import numpy as np
import PIL
import torch
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import (
    OnnxStableDiffusionUpscalePipeline as BasePipeline,
)
from diffusers.schedulers import DDPMScheduler

logger = getLogger(__name__)


NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7

ORT_TO_PT_TYPE = {
    "float16": torch.float16,
    "float32": torch.float32,
}


def preprocess(image):
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64

        image = [np.array(i.resize((w, h)))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    return image


class FakeConfig:
    block_out_channels: List[int]
    scaling_factor: float

    def __init__(self) -> None:
        self.block_out_channels = [128, 256, 512]
        self.scaling_factor = 0.08333
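
# note: FakeConfig appears to stand in for the AutoencoderKL config that the
# diffusers base pipeline reads from its VAE model; the ONNX runtime model has no
# config object of its own, so block_out_channels and the x4 upscaler's latent
# scaling factor (0.08333) are hard-coded here (an assumption based on usage)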


class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
class OnnxStableDiffusionUpscalePipeline(BasePipeline):
    def __init__(
        self,
        vae: OnnxRuntimeModel,
@@ -80,260 +42,3 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
            scheduler,
            max_noise_level=max_noise_level,
        )

    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 20,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        # 1. Check inputs
        self.check_inputs(prompt, image, noise_level, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            # device, device only needed for Torch pipelines
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )

        latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]

        # 4. Preprocess image
        image = preprocess(image)
        image = image.cpu()

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Add noise to image
        noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
        noise = torch.randn(
            image.shape, generator=generator, device=device, dtype=latents_dtype
        )
        image = self.low_res_scheduler.add_noise(image, noise, noise_level)

        batch_multiplier = 2 if do_classifier_free_guidance else 1
        image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
        noise_level = np.concatenate([noise_level] * image.shape[0])

        # 6. Prepare latent variables
        height, width = image.shape[2:]
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            NUM_LATENT_CHANNELS,
            height,
            width,
            latents_dtype,
            device,
            generator,
            latents,
        )

        # 7. Check that sizes of image and latents match
        num_channels_image = image.shape[1]
        if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
            raise ValueError(
                "Incorrect configuration settings! The config of `pipeline.unet` expects"
                f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
                " `pipeline.unet` or your `image` input."
            )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        timestep_dtype = next(
            (
                input.type
                for input in self.unet.model.get_inputs()
                if input.name == "timestep"
            ),
            "tensor(float)",
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    np.concatenate([latents] * 2)
                    if do_classifier_free_guidance
                    else latents
                )

                # concat latents, mask, masked_image_latents in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                latent_model_input = np.concatenate([latent_model_input, image], axis=1)

                # timestep to tensor
                timestep = np.array([t], dtype=timestep_dtype)

                # predict the noise residual
                noise_pred = self.unet(
                    sample=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=text_embeddings,
                    class_labels=noise_level.astype(np.int64),
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 10. Post-processing
        image = self.decode_latents(latents.float())

        # 11. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def decode_latents(self, latents):
        latents = 1 / 0.08333 * latents
        image = self.vae(latent_sample=latents)[0]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        return image

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # no positional arguments to text_encoder
        text_embeddings = self.text_encoder(
            input_ids=text_input_ids.int().to(device),
        )
        text_embeddings = text_embeddings[0]

        bs_embed, seq_len, _ = text_embeddings.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
        text_embeddings = text_embeddings.reshape(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.int().to(device),
            )
            uncond_embeddings = uncond_embeddings[0]

            seq_len = uncond_embeddings.shape[1]
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
            uncond_embeddings = uncond_embeddings.reshape(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        return text_embeddings
@@ -4,15 +4,16 @@ from typing import Any, List, Optional

from PIL import Image, ImageOps

from onnx_web.chain.highres import stage_highres

from ..chain import (
    BlendDenoiseStage,
    BlendImg2ImgStage,
    BlendMaskStage,
    ChainPipeline,
    SourceTxt2ImgStage,
    UpscaleOutpaintStage,
)
from ..chain.highres import stage_highres
from ..chain.result import StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image
@@ -33,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt
logger = getLogger(__name__)


def get_base_tile(params: ImageParams, size: Size) -> int:
    if params.is_panorama():
        tile = max(params.unet_tile, size.width, size.height)
        logger.debug("adjusting tile size for panorama to %s", tile)
        return tile

    return params.unet_tile


def get_highres_tile(
    server: ServerContext, params: ImageParams, highres: HighresParams, tile: int
) -> int:
    if params.is_panorama() and server.has_feature("panorama-highres"):
        return tile * highres.scale

    return params.unet_tile
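
# Example (illustrative): with a 768x768 panorama request and unet_tile=512,
# get_base_tile returns max(512, 768, 768) = 768, so the chain stages do not tile
# the image and the panorama pipeline handles its own views; with highres scale 2
# and the server's "panorama-highres" feature enabled, get_highres_tile returns
# 768 * 2 = 1536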


def run_txt2img_pipeline(
    worker: WorkerContext,
    server: ServerContext,
@@ -43,10 +62,7 @@
    highres: HighresParams,
) -> None:
    # if using panorama, the pipeline will tile itself (views)
    if params.is_panorama() or params.is_xl():
        tile_size = max(params.tiles, size.width, size.height)
    else:
        tile_size = params.tiles
    tile_size = get_base_tile(params, size)

    # prepare the chain pipeline and first stage
    chain = ChainPipeline()
@@ -57,15 +73,21 @@
        ),
        size=size,
        prompt_index=0,
        overlap=params.overlap,
        overlap=params.vae_overlap,
    )

    # apply upscaling and correction, before highres
    stage = StageParams(tile_size=params.tiles)
    highres_size = get_highres_tile(server, params, highres, tile_size)
    if params.is_panorama():
        chain.stage(
            BlendDenoiseStage(),
            StageParams(tile_size=highres_size),
        )

    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
            stage,
            StageParams(outscale=first_upscale.outscale, tile_size=highres_size),
            params,
            chain=chain,
            upscale=first_upscale,
@@ -73,7 +95,7 @@

    # apply highres
    stage_highres(
        stage,
        StageParams(outscale=highres.scale, tile_size=highres_size),
        params,
        highres,
        upscale,
@ -83,7 +105,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
# apply upscaling and correction, after highres
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
StageParams(outscale=after_upscale.outscale, tile_size=highres_size),
|
||||
params,
|
||||
chain=chain,
|
||||
upscale=after_upscale,
|
||||
|
@ -92,11 +114,14 @@ def run_txt2img_pipeline(
|
|||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
|
||||
images = chain.run(
|
||||
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
logger.trace("saving output image %s: %s", output, image.size)
|
||||
dest = save_image(
|
||||
server,
|
||||
output,
|
||||
|
@ -136,23 +161,26 @@ def run_img2img_pipeline(
|
|||
source = f(server, source)
|
||||
|
||||
# prepare the chain pipeline and first stage
|
||||
tile_size = get_base_tile(params, Size(*source.size))
|
||||
chain = ChainPipeline()
|
||||
stage = StageParams(
|
||||
tile_size=params.tiles,
|
||||
)
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
StageParams(
|
||||
tile_size=tile_size,
|
||||
),
|
||||
prompt_index=0,
|
||||
strength=strength,
|
||||
overlap=params.overlap,
|
||||
overlap=params.vae_overlap,
|
||||
)
|
||||
|
||||
# apply upscaling and correction, before highres
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
if first_upscale:
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
StageParams(
|
||||
outscale=first_upscale.outscale,
|
||||
tile_size=tile_size,
|
||||
),
|
||||
params,
|
||||
upscale=first_upscale,
|
||||
chain=chain,
|
||||
|
@ -162,13 +190,16 @@ def run_img2img_pipeline(
|
|||
for _i in range(params.loopback):
|
||||
chain.stage(
|
||||
BlendImg2ImgStage(),
|
||||
stage,
|
||||
StageParams(
|
||||
tile_size=tile_size,
|
||||
),
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
# highres, if selected
|
||||
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||
stage_highres(
|
||||
stage,
|
||||
StageParams(tile_size=highres_size, outscale=highres.scale),
|
||||
params,
|
||||
highres,
|
||||
upscale,
|
||||
|
@ -178,7 +209,7 @@ def run_img2img_pipeline(
|
|||
|
||||
# apply upscaling and correction, after highres
|
||||
stage_upscale_correction(
|
||||
stage,
|
||||
StageParams(tile_size=tile_size, outscale=after_upscale.scale),
|
||||
params,
|
||||
upscale=after_upscale,
|
||||
chain=chain,
|
||||
|
@ -186,7 +217,9 @@ def run_img2img_pipeline(
|
|||
|
||||
# run and append the filtered source
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain(worker, server, params, [source], callback=progress)
|
||||
images = chain.run(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
)
|
||||
|
||||
if source_filter is not None and source_filter != "none":
|
||||
images.append(source)
|
||||
|
@ -235,7 +268,7 @@ def run_inpaint_pipeline(
|
|||
full_res_inpaint_padding: float,
|
||||
) -> None:
|
||||
logger.debug("building inpaint pipeline")
|
||||
tile_size = params.tiles
|
||||
tile_size = get_base_tile(params, size)
|
||||
|
||||
if mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
|
@ -264,8 +297,12 @@ def run_inpaint_pipeline(
|
|||
logger.debug("border zero: %s", border.isZero())
|
||||
full_res_inpaint = full_res_inpaint and border.isZero()
|
||||
if full_res_inpaint:
|
||||
mask_left, mask_top, mask_right, mask_bottom = mask.getbbox()
|
||||
logger.debug("mask bbox: %s", mask.getbbox())
|
||||
bbox = mask.getbbox()
|
||||
if bbox is None:
|
||||
bbox = (0, 0, source.width, source.height)
|
||||
|
||||
logger.debug("mask bounding box: %s", bbox)
|
||||
mask_left, mask_top, mask_right, mask_bottom = bbox
|
||||
mask_width = mask_right - mask_left
|
||||
mask_height = mask_bottom - mask_top
|
||||
# ensure we have some padding around the mask when we do the inpaint (and that the region size is even)
|
||||
|
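The bbox fallback matters because PIL's `Image.getbbox` returns `None` for an all-black mask, which the old code would unpack directly and crash on. A quick standalone check of that behavior:

```python
# demonstrates the Image.getbbox None case the new fallback guards against
from PIL import Image

mask = Image.new("L", (64, 64), 0)  # empty mask: no nonzero pixels
print(mask.getbbox())  # None, so unpacking it directly raises TypeError

bbox = mask.getbbox()
if bbox is None:
    # fall back to the full source size, as the pipeline now does
    bbox = (0, 0, mask.width, mask.height)
print(bbox)  # (0, 0, 64, 64)
```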
@@ -322,16 +359,15 @@ def run_inpaint_pipeline(

    # set up the chain pipeline and base stage
    chain = ChainPipeline()
-    stage = StageParams(tile_order=tile_order, tile_size=tile_size)
    chain.stage(
        UpscaleOutpaintStage(),
-        stage,
+        StageParams(tile_order=tile_order, tile_size=tile_size),
        border=border,
        mask=mask,
        fill_color=fill_color,
        mask_filter=mask_filter,
        noise_source=noise_source,
-        overlap=params.overlap,
+        overlap=params.vae_overlap,
        prompt_index=0,
    )

@@ -339,15 +375,16 @@ def run_inpaint_pipeline(
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
-            stage,
+            StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # apply highres
+    highres_size = get_highres_tile(server, params, highres, tile_size)
    stage_highres(
-        stage,
+        StageParams(outscale=highres.scale, tile_size=highres_size),
        params,
        highres,
        upscale,

@@ -357,7 +394,7 @@ def run_inpaint_pipeline(

    # apply upscaling and correction
    stage_upscale_correction(
-        stage,
+        StageParams(outscale=after_upscale.outscale),
        params,
        upscale=after_upscale,
        chain=chain,

@@ -366,7 +403,14 @@ def run_inpaint_pipeline(
    # run and save
    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
    progress = worker.get_progress_callback()
-    images = chain(worker, server, params, [source], callback=progress, latents=latents)
+    images = chain.run(
+        worker,
+        server,
+        params,
+        StageResult(images=[source]),
+        callback=progress,
+        latents=latents,
+    )

    _pairs, loras, inversions, _rest = parse_prompt(params)
    for image, output in zip(images, outputs):

@@ -409,21 +453,22 @@ def run_upscale_pipeline(
) -> None:
    # set up the chain pipeline, no base stage for upscaling
    chain = ChainPipeline()
-    stage = StageParams(tile_size=params.tiles)
+    tile_size = get_base_tile(params, size)

    # apply upscaling and correction, before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        stage_upscale_correction(
-            stage,
+            StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
            params,
            upscale=first_upscale,
            chain=chain,
        )

    # apply highres
+    highres_size = get_highres_tile(server, params, highres, tile_size)
    stage_highres(
-        stage,
+        StageParams(outscale=highres.scale, tile_size=highres_size),
        params,
        highres,
        upscale,

@@ -433,7 +478,7 @@ def run_upscale_pipeline(

    # apply upscaling and correction, after highres
    stage_upscale_correction(
-        stage,
+        StageParams(outscale=after_upscale.outscale, tile_size=tile_size),
        params,
        upscale=after_upscale,
        chain=chain,

@@ -441,7 +486,9 @@ def run_upscale_pipeline(

    # run and save
    progress = worker.get_progress_callback()
-    images = chain(worker, server, params, [source], callback=progress)
+    images = chain.run(
+        worker, server, params, StageResult(images=[source]), callback=progress
+    )

    _pairs, loras, inversions, _rest = parse_prompt(params)
    for image, output in zip(images, outputs):

@@ -478,12 +525,18 @@ def run_blend_pipeline(
) -> None:
    # set up the chain pipeline and base stage
    chain = ChainPipeline()
-    stage = StageParams()
-    chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
+    tile_size = get_base_tile(params, size)
+
+    chain.stage(
+        BlendMaskStage(),
+        StageParams(tile_size=tile_size),
+        stage_source=sources[1],
+        stage_mask=mask,
+    )

    # apply upscaling and correction
    stage_upscale_correction(
-        stage,
+        StageParams(outscale=upscale.outscale),
        params,
        upscale=upscale,
        chain=chain,

@@ -491,7 +544,9 @@ def run_blend_pipeline(

    # run and save
    progress = worker.get_progress_callback()
-    images = chain(worker, server, params, sources, callback=progress)
+    images = chain.run(
+        worker, server, params, StageResult(images=sources), callback=progress
+    )

    for image, output in zip(images, outputs):
        dest = save_image(server, output, image, params, size, upscale=upscale)
@@ -3,23 +3,27 @@ from copy import deepcopy
from logging import getLogger
from math import ceil
from re import Pattern, compile
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from diffusers import OnnxStableDiffusionPipeline

+from ..constants import LATENT_CHANNELS, LATENT_FACTOR
from ..params import ImageParams, Size

logger = getLogger(__name__)

-LATENT_CHANNELS = 4
-LATENT_FACTOR = 8
MAX_TOKENS_PER_GROUP = 77

+ANY_TOKEN = compile(r"\<([^\>]*)\>")
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
+REGION_TOKEN = compile(
+    r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
+)
+RESEED_TOKEN = compile(r"\<reseed:(\d+):(\d+):(\d+):(\d+):(-?\d+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")

INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")

@@ -84,8 +88,8 @@ def expand_prompt(
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
-    skip_clip_states: Optional[int] = 0,
-) -> "np.NDArray":
+    skip_clip_states: int = 0,
+) -> np.ndarray:
    # self provides:
    #   tokenizer: CLIPTokenizer
    #   encoder: OnnxRuntimeModel

@@ -140,6 +144,7 @@ def expand_prompt(

        last_state, _pooled_output, *hidden_states = text_result
        if skip_clip_states > 0:
+            # TODO: why is this normalized?
            layer_norm = torch.nn.LayerNorm(last_state.shape[2])
            norm_state = layer_norm(
                torch.from_numpy(

@@ -219,20 +224,25 @@ def expand_prompt(
    return prompt_embeds


+def parse_float_group(group: Tuple[str, str]) -> Tuple[str, float]:
+    name, weight = group
+    return (name, float(weight))
+
+
def get_tokens_from_prompt(
-    prompt: str, pattern: Pattern
+    prompt: str,
+    pattern: Pattern,
+    parser=parse_float_group,
) -> Tuple[str, List[Tuple[str, float]]]:
    """
    TODO: replace with Arpeggio
    """
    remaining_prompt = prompt

    tokens = []
    next_match = pattern.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found token in prompt: %s", next_match)
-        name, weight = next_match.groups()
-        tokens.append((name, float(weight)))
+        group = next_match.groups()
+        tokens.append(parser(group))

        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]

@@ -251,6 +261,13 @@ def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)


+def random_seed(generator=None) -> int:
+    if generator is None:
+        generator = np.random
+
+    return generator.randint(np.iinfo(np.int32).max)
+
+
def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
    """
    From https://www.travelneil.com/stable-diffusion-updates.html.

@@ -266,6 +283,25 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
    return image_latents


+def expand_latents(
+    latents: np.ndarray,
+    seed: int,
+    size: Size,
+    sigma: float = 1.0,
+) -> np.ndarray:
+    batch, _channels, height, width = latents.shape
+    extra_latents = get_latents_from_seed(seed, size, batch=batch)
+    extra_latents[:, :, 0:height, 0:width] = latents
+    return extra_latents * np.float64(sigma)
+
+
+def resize_latent_shape(
+    latents: np.ndarray,
+    size: Tuple[int, int],
+) -> Tuple[int, int, int, int]:
+    return (latents.shape[0], latents.shape[1], *size)
+
+
def get_tile_latents(
    full_latents: np.ndarray,
    seed: int,

@@ -290,14 +326,8 @@ def get_tile_latents(

    tile_latents = full_latents[:, :, y:yt, x:xt]

-    if tile_latents.shape != full_latents.shape and (
-        tile_latents.shape[2] < t or tile_latents.shape[3] < t
-    ):
-        extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
-        extra_latents[
-            :, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
-        ] = tile_latents
-        tile_latents = extra_latents
+    if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
+        tile_latents = expand_latents(tile_latents, seed, size)

    return tile_latents
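`expand_latents` replaces the inline padding block in `get_tile_latents`: edge tiles that come back smaller than the tile size are re-seeded to full size, with the real latents copied into the top-left corner. A small numpy sketch of the same padding idea, with `np.random.default_rng` standing in for onnx-web's seeded `get_latents_from_seed`:

```python
# sketch of the expand_latents padding rule; default_rng stands in for
# the seeded latent generator used by the real code
import numpy as np


def expand_latents_sketch(latents: np.ndarray, full: tuple, seed: int) -> np.ndarray:
    batch, channels, height, width = latents.shape
    rng = np.random.default_rng(seed)
    # fresh noise at the full tile size
    extra = rng.standard_normal((batch, channels, *full)).astype(latents.dtype)
    # keep the real latents in the top-left corner
    extra[:, :, 0:height, 0:width] = latents
    return extra


tile = np.zeros((1, 4, 48, 64), dtype=np.float32)  # short edge tile
padded = expand_latents_sketch(tile, (64, 64), seed=42)
print(padded.shape)                  # (1, 4, 64, 64)
print(padded[0, 0, :48, :64].sum())  # 0.0, original region preserved
```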
@@ -369,12 +399,15 @@ def encode_prompt(
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
) -> List[np.ndarray]:
+    """
+    TODO: does not work with SDXL, fix or turn into a pipeline patch
+    """
    return [
        pipe._encode_prompt(
-            prompt,
+            remove_tokens(prompt),
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
-            negative_prompt=neg_prompt,
+            negative_prompt=remove_tokens(neg_prompt),
        )
        for prompt, neg_prompt in prompt_pairs
    ]

@@ -444,3 +477,71 @@ def slice_prompt(prompt: str, slice: int) -> str:
        return parts[min(slice, len(parts) - 1)]
    else:
        return prompt


+Region = Tuple[
+    int, int, int, int, float, Tuple[float, Tuple[bool, bool, bool, bool]], str
+]
+
+
+def parse_region_group(group: Tuple[str, ...]) -> Region:
+    top, left, bottom, right, weight, feather, prompt = group
+
+    # break down the feather section
+    feather_radius, *feather_edges = feather.split("_")
+    if len(feather_edges) == 0:
+        feather_edges = "TLBR"
+    else:
+        feather_edges = "".join(feather_edges)
+
+    return (
+        int(top),
+        int(left),
+        int(bottom),
+        int(right),
+        float(weight),
+        (
+            float(feather_radius),
+            (
+                "T" in feather_edges,
+                "L" in feather_edges,
+                "B" in feather_edges,
+                "R" in feather_edges,
+            ),
+        ),
+        prompt,
+    )
+
+
+def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
+    return get_tokens_from_prompt(prompt, REGION_TOKEN, parser=parse_region_group)
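Region prompts are parsed from `<region:...>` tokens through the new pluggable `parser` argument. A standalone example of what one token yields, using the `REGION_TOKEN` pattern from this module (the sample prompt is illustrative):

```python
# parse one region token with the REGION_TOKEN pattern from the diff above
from re import compile

REGION_TOKEN = compile(
    r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+_?[TLBR]*):([^\>]+)\>"
)

prompt = "a city street <region:0:0:256:512:0.75:8.0_TL:neon signs>"
match = REGION_TOKEN.search(prompt)
top, left, bottom, right, weight, feather, region_prompt = match.groups()
print((int(top), int(left), int(bottom), int(right)))  # (0, 0, 256, 512)
print(float(weight), feather, region_prompt)           # 0.75 8.0_TL neon signs
```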
+Reseed = Tuple[int, int, int, int, int]
+
+
+def parse_reseed_group(group) -> Region:
+    top, left, bottom, right, seed = group
+    return (
+        int(top),
+        int(left),
+        int(bottom),
+        int(right),
+        int(seed),
+    )
+
+
+def parse_reseed(prompt: str) -> Tuple[str, List[Reseed]]:
+    return get_tokens_from_prompt(prompt, RESEED_TOKEN, parser=parse_reseed_group)
+
+
+def skip_group(group) -> Any:
+    return group
+
+
+def remove_tokens(prompt: Optional[str]) -> Optional[str]:
+    if prompt is None:
+        return prompt
+
+    remainder, tokens = get_tokens_from_prompt(prompt, ANY_TOKEN, parser=skip_group)
+    return remainder

@@ -12,6 +12,11 @@ try:
except ImportError:
    from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler

+try:
+    from diffusers import LCMScheduler
+except ImportError:
+    from ..diffusers.stub_scheduler import StubScheduler as LCMScheduler
+
try:
    from diffusers import UniPCMultistepScheduler
except ImportError:

@@ -8,7 +8,7 @@ def mask_filter_none(
) -> Image.Image:
    width, height = dims

-    noise = Image.new("RGB", (width, height), fill)
+    noise = Image.new(mask.mode, (width, height), fill)
    noise.paste(mask, origin)

    return noise

@@ -17,21 +17,21 @@ def noise_source_fill_edge(
    """
    width, height = dims

-    noise = Image.new("RGB", (width, height), fill)
+    noise = Image.new(source.mode, (width, height), fill)
    noise.paste(source, origin)

    return noise


def noise_source_fill_mask(
-    _source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
+    source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
) -> Image.Image:
    """
    Fill the whole canvas, no source or noise.
    """
    width, height = dims

-    noise = Image.new("RGB", (width, height), fill)
+    noise = Image.new(source.mode, (width, height), fill)

    return noise

@@ -52,7 +52,7 @@ def noise_source_gaussian(


def noise_source_uniform(
-    _source: Image.Image, dims: Point, _origin: Point, **kw
+    source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
    width, height = dims
    size = width * height

@@ -61,6 +61,7 @@ def noise_source_uniform(
    noise_g = random.uniform(0, 256, size=size)
    noise_b = random.uniform(0, 256, size=size)

+    # needs to be RGB for pixel manipulation
    noise = Image.new("RGB", (width, height))

    for x in range(width):

@@ -68,11 +69,11 @@ def noise_source_uniform(
            i = get_pixel_index(x, y, width)
            noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))

-    return noise
+    return noise.convert(source.mode)


def noise_source_normal(
-    _source: Image.Image, dims: Point, _origin: Point, **kw
+    source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
    width, height = dims
    size = width * height

@@ -81,6 +82,7 @@ def noise_source_normal(
    noise_g = random.normal(128, 32, size=size)
    noise_b = random.normal(128, 32, size=size)

+    # needs to be RGB for pixel manipulation
    noise = Image.new("RGB", (width, height))

    for x in range(width):

@@ -88,13 +90,13 @@ def noise_source_normal(
            i = get_pixel_index(x, y, width)
            noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))

-    return noise
+    return noise.convert(source.mode)


def noise_source_histogram(
    source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
-    r, g, b = source.split()
+    r, g, b, *_a = source.split()
    width, height = dims
    size = width * height

@@ -112,6 +114,7 @@ def noise_source_histogram(
        256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
    )

+    # needs to be RGB for pixel manipulation
    noise = Image.new("RGB", (width, height))

    for x in range(width):

@@ -119,4 +122,4 @@ def noise_source_histogram(
            i = get_pixel_index(x, y, width)
            noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))

-    return noise
+    return noise.convert(source.mode)
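These changes make the noise sources mode-aware so RGBA inpaint sources round-trip cleanly: noise is still built in RGB for per-pixel writes, then converted back to the source mode, and `r, g, b, *_a = source.split()` tolerates an alpha channel. A quick check of the split and convert behavior:

```python
# RGBA-safe channel split and mode round-trip, as used by the noise sources
from PIL import Image

source = Image.new("RGBA", (8, 8), (200, 100, 50, 255))
r, g, b, *_a = source.split()  # works for RGB (empty _a) and RGBA alike
print(len(_a))  # 1 for RGBA, 0 for RGB

noise = Image.new("RGB", source.size, (127, 127, 127))
print(noise.convert(source.mode).mode)  # RGBA, matching the source again
```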
@@ -47,7 +47,7 @@ def source_filter_noise(
    source: Image.Image,
    strength: float = 0.5,
):
-    noise = noise_source_histogram(source, source.size)
+    noise = noise_source_histogram(source, source.size, (0, 0))
    return ImageChops.blend(source, noise, strength)


@@ -1,3 +1,5 @@
+from typing import Tuple
+
from PIL import Image, ImageChops

from ..params import Border, Size

@@ -13,12 +15,12 @@ def expand_image(
    fill="white",
    noise_source=noise_source_histogram,
    mask_filter=mask_filter_none,
-):
+) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
    size = Size(*source.size).add_border(expand)
    size = tuple(size)
    origin = (expand.left, expand.top)

-    full_source = Image.new("RGB", size, fill)
+    full_source = Image.new(source.mode, size, fill)
    full_source.paste(source, origin)

    # new mask pixels need to be filled with white so they will be replaced

@@ -23,6 +23,7 @@ from .server.load import (
    load_platforms,
    load_wildcards,
)
+from .server.plugin import load_plugins, register_plugins
from .server.static import register_static_routes
from .server.utils import check_paths
from .utils import is_debug

@@ -43,15 +44,32 @@ def main():
    server = ServerContext.from_environ()
    apply_patches(server)
    check_paths(server)

+    # debug options
+    if server.debug:
+        import debugpy
+
+        debugpy.listen(5678)
+        logger.warning("waiting for debugger")
+        debugpy.wait_for_client()
+        gc.set_debug(gc.DEBUG_STATS)
+
+    # register plugins
+    exports = load_plugins(server)
+    success = register_plugins(exports)
+    if success:
+        logger.info("all plugins loaded successfully")
+    else:
+        logger.warning("error loading plugins")
+
+    # load additional resources
    load_extras(server)
    load_models(server)
    load_params(server)
    load_platforms(server)
    load_wildcards(server)

    if is_debug():
        gc.set_debug(gc.DEBUG_STATS)

+    # misc server options
    if not server.show_progress:
-        disable_progress_bar()
+        disable_progress_bars()

@@ -1,18 +1,21 @@
-from typing import Literal
+from typing import List, Literal

-NetworkType = Literal["inversion", "lora"]
+NetworkType = Literal["control", "inversion", "lora"]


class NetworkModel:
    name: str
+    tokens: List[str]
    type: NetworkType

-    def __init__(self, name: str, type: NetworkType) -> None:
+    def __init__(self, name: str, type: NetworkType, tokens=None) -> None:
        self.name = name
+        self.tokens = tokens or []
        self.type = type

    def tojson(self):
        return {
            "name": self.name,
+            "tokens": self.tokens,
            "type": self.type,
        }
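With token metadata attached, each network model can advertise its trigger words to the client. A trimmed copy of the class above, exercised with a hypothetical LoRA name:

```python
# mirrors the NetworkModel shape in the diff above, trimmed for illustration;
# "cyberpunk-v2" and its token are made-up example values
from typing import List, Literal, Optional

NetworkType = Literal["control", "inversion", "lora"]


class NetworkModel:
    def __init__(
        self, name: str, type: NetworkType, tokens: Optional[List[str]] = None
    ) -> None:
        self.name = name
        self.tokens = tokens or []
        self.type = type

    def tojson(self):
        return {"name": self.name, "tokens": self.tokens, "type": self.type}


model = NetworkModel("cyberpunk-v2", "lora", tokens=["cyberpunk style"])
print(model.tojson())
# {'name': 'cyberpunk-v2', 'tokens': ['cyberpunk style'], 'type': 'lora'}
```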
@@ -57,7 +57,7 @@ def json_params(
    upscale: Optional[UpscaleParams] = None,
    border: Optional[Border] = None,
    highres: Optional[HighresParams] = None,
-    parent: Dict = None,
+    parent: Optional[Dict] = None,
) -> Any:
    json = {
        "input_size": size.tojson(),

@@ -158,6 +158,7 @@ def make_output_name(
    size: Size,
    extras: Optional[List[Optional[Param]]] = None,
    count: Optional[int] = None,
+    offset: int = 0,
) -> List[str]:
    count = count or params.batch
    now = int(time())

@@ -183,7 +184,7 @@ def make_output_name(

    return [
        f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
-        for i in range(count)
+        for i in range(offset, count + offset)
    ]
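The new `offset` shifts the numeric suffix so a later batch can continue a sequence without colliding with earlier filenames. The indexing change in isolation, with a stand-in for the real mode/seed/hash/time prefix:

```python
# the filename index window before and after the offset parameter;
# "txt2img_1234_abcd_999" stands in for the real prefix components
prefix = "txt2img_1234_abcd_999"


def names(count: int, offset: int = 0):
    return [f"{prefix}_{i}.png" for i in range(offset, count + offset)]


print(names(2))            # ['..._0.png', '..._1.png']
print(names(2, offset=2))  # ['..._2.png', '..._3.png'], no collision
```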
@@ -14,7 +14,7 @@ Point = Tuple[int, int]


class SizeChart(IntEnum):
    unlimited = 0
    micro = 64
    mini = 128  # small tile for very expensive models
    half = 256  # half tile for outpainting
    auto = 512  # auto tile size

@@ -25,6 +25,7 @@ class SizeChart(IntEnum):
    hd16k = 2**14
    hd32k = 2**15
    hd64k = 2**16
+    max = 2**32  # should be a reasonable upper limit for now


class TileOrder:

@@ -140,7 +141,7 @@ class DeviceParams:
        if self.options is None:
            return self.provider
        else:
-            return self.provider  # (self.provider, self.options)
+            return (self.provider, self.options)

    def sess_options(self, cache=True) -> SessionOptions:
        if cache and self.sess_options_cache is not None:

@@ -201,11 +202,14 @@ class ImageParams:
    batch: int
    control: Optional[NetworkModel]
    input_prompt: str
-    input_negative_prompt: str
+    input_negative_prompt: Optional[str]
    loopback: int
    tiled_vae: bool
-    tiles: int
-    overlap: float
+    unet_tile: int
+    unet_overlap: float
+    vae_tile: int
+    vae_overlap: float
+    denoise: int

    def __init__(
        self,

@@ -224,9 +228,11 @@ class ImageParams:
        input_negative_prompt: Optional[str] = None,
        loopback: int = 0,
        tiled_vae: bool = False,
-        tiles: int = 512,
-        overlap: float = 0.25,
-        stride: int = 64,
+        unet_overlap: float = 0.25,
+        unet_tile: int = 512,
+        vae_overlap: float = 0.25,
+        vae_tile: int = 512,
+        denoise: int = 3,
    ) -> None:
        self.model = model
        self.pipeline = pipeline

@@ -243,14 +249,16 @@ class ImageParams:
        self.input_negative_prompt = input_negative_prompt or negative_prompt
        self.loopback = loopback
        self.tiled_vae = tiled_vae
-        self.tiles = tiles
-        self.overlap = overlap
-        self.stride = stride
+        self.unet_overlap = unet_overlap
+        self.unet_tile = unet_tile
+        self.vae_overlap = vae_overlap
+        self.vae_tile = vae_tile
+        self.denoise = denoise

    def do_cfg(self):
        return self.cfg > 1.0

-    def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
+    def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
        pipeline = pipeline or self.pipeline

        # if the correct pipeline was already requested, simply use that

@@ -259,7 +267,14 @@ class ImageParams:

        # otherwise, check for additional allowed pipelines
        if group == "img2img":
-            if pipeline in ["controlnet", "img2img-sdxl", "lpw", "panorama", "pix2pix"]:
+            if pipeline in [
+                "controlnet",
+                "img2img-sdxl",
+                "lpw",
+                "panorama",
+                "panorama-sdxl",
+                "pix2pix",
+            ]:
                return pipeline
            elif pipeline == "txt2img-sdxl":
                return "img2img-sdxl"

@@ -267,7 +282,7 @@ class ImageParams:
            if pipeline in ["controlnet", "lpw", "panorama"]:
                return pipeline
        elif group == "txt2img":
-            if pipeline in ["lpw", "panorama", "txt2img-sdxl"]:
+            if pipeline in ["lpw", "panorama", "panorama-sdxl", "txt2img-sdxl"]:
                return pipeline

        logger.debug("pipeline %s is not valid for %s", pipeline, group)

@@ -280,7 +295,7 @@ class ImageParams:
        return self.pipeline == "lpw"

    def is_panorama(self):
-        return self.pipeline == "panorama"
+        return self.pipeline in ["panorama", "panorama-sdxl"]

    def is_pix2pix(self):
        return self.pipeline == "pix2pix"

@@ -305,9 +320,11 @@ class ImageParams:
            "input_negative_prompt": self.input_negative_prompt,
            "loopback": self.loopback,
            "tiled_vae": self.tiled_vae,
-            "tiles": self.tiles,
-            "overlap": self.overlap,
-            "stride": self.stride,
+            "unet_overlap": self.unet_overlap,
+            "unet_tile": self.unet_tile,
+            "vae_overlap": self.vae_overlap,
+            "vae_tile": self.vae_tile,
+            "denoise": self.denoise,
        }

    def with_args(self, **kwargs):

@@ -327,9 +344,11 @@ class ImageParams:
            kwargs.get("input_negative_prompt", self.input_negative_prompt),
            kwargs.get("loopback", self.loopback),
            kwargs.get("tiled_vae", self.tiled_vae),
-            kwargs.get("tiles", self.tiles),
-            kwargs.get("overlap", self.overlap),
-            kwargs.get("stride", self.stride),
+            kwargs.get("unet_overlap", self.unet_overlap),
+            kwargs.get("unet_tile", self.unet_tile),
+            kwargs.get("vae_overlap", self.vae_overlap),
+            kwargs.get("vae_tile", self.vae_tile),
+            kwargs.get("denoise", self.denoise),
        )

@@ -351,6 +370,17 @@ class StageParams:
        self.tile_order = tile_order
        self.tile_size = tile_size

+    def with_args(
+        self,
+        **kwargs,
+    ):
+        return StageParams(
+            name=kwargs.get("name", self.name),
+            outscale=kwargs.get("outscale", self.outscale),
+            tile_order=kwargs.get("tile_order", self.tile_order),
+            tile_size=kwargs.get("tile_size", self.tile_size),
+        )
+

class UpscaleParams:
    def __init__(

@@ -459,10 +489,14 @@ class HighresParams:
        self.method = method
        self.iterations = iterations

+    def outscale(self) -> int:
+        return self.scale**self.iterations
+
    def resize(self, size: Size) -> Size:
+        outscale = self.outscale()
        return Size(
-            size.width * (self.scale**self.iterations),
-            size.height * (self.scale**self.iterations),
+            size.width * outscale,
+            size.height * outscale,
        )

    def tojson(self):
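`HighresParams.outscale` centralizes the highres scale math: each iteration multiplies the output size by `scale`, so the total factor is `scale ** iterations`. In numbers:

```python
# each highres iteration multiplies the output size by `scale`
scale, iterations = 2, 3
outscale = scale**iterations
print(outscale)                        # 8
print(512 * outscale, 768 * outscale)  # 4096 6144: a 512x768 base after 3 passes
```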
@@ -1,12 +1,14 @@
from io import BytesIO
from logging import getLogger
from os import path
+from typing import Any, Dict

from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image

from ..chain import CHAIN_STAGES, ChainPipeline
+from ..chain.result import StageResult
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.run import (
    run_blend_pipeline,

@@ -17,7 +19,7 @@ from ..diffusers.run import (
)
from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name
-from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
+from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
    base_join,

@@ -49,10 +51,11 @@ from .load import (
    get_wildcard_data,
)
from .params import (
-    border_from_request,
-    highres_from_request,
+    build_border,
+    build_highres,
+    build_upscale,
+    pipeline_from_json,
    pipeline_from_request,
-    upscale_from_request,
)
from .utils import wrap_route

@@ -167,8 +170,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
    size = Size(source.width, source.height)

    device, params, _size = pipeline_from_request(server, "img2img")
-    upscale = upscale_from_request()
-    highres = highres_from_request()
+    upscale = build_upscale()
+    highres = build_highres()
    source_filter = get_from_list(
        request.args, "sourceFilter", list(get_source_filters().keys())
    )

@@ -216,12 +219,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):

def txt2img(server: ServerContext, pool: DevicePoolExecutor):
    device, params, size = pipeline_from_request(server, "txt2img")
-    upscale = upscale_from_request()
-    highres = highres_from_request()
+    upscale = build_upscale()
+    highres = build_highres()

    replace_wildcards(params, get_wildcard_data())

-    output = make_output_name(server, "txt2img", params, size)
+    output = make_output_name(server, "txt2img", params, size, count=params.batch)

    job_name = output[0]
    pool.submit(

@@ -250,7 +253,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
    if mask_file is None:
        return error_reply("mask image is required")

-    source = Image.open(BytesIO(source_file.read())).convert("RGB")
+    source = Image.open(BytesIO(source_file.read())).convert("RGBA")
    size = Size(source.width, source.height)

    mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")

@@ -270,9 +273,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
    )

    device, params, _size = pipeline_from_request(server, "inpaint")
-    expand = border_from_request()
-    upscale = upscale_from_request()
-    highres = highres_from_request()
+    expand = build_border()
+    upscale = build_upscale()
+    highres = build_highres()

    fill_color = get_not_empty(request.args, "fillColor", "white")
    mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")

@@ -340,8 +343,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
    source = Image.open(BytesIO(source_file.read())).convert("RGB")

    device, params, size = pipeline_from_request(server)
-    upscale = upscale_from_request()
-    highres = highres_from_request()
+    upscale = build_upscale()
+    highres = build_highres()

    replace_wildcards(params, get_wildcard_data())

@@ -366,47 +369,70 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
    return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))


+# keys that are specially parsed by params and should not show up in with_args
+CHAIN_POP_KEYS = ["model", "control"]
+
+
def chain(server: ServerContext, pool: DevicePoolExecutor):
+    if request.is_json:
+        logger.debug("chain pipeline request with JSON body")
+        data = request.get_json()
+    else:
        logger.debug(
            "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
        )

        body = request.form.get("chain") or request.files.get("chain")
        if body is None:
            return error_reply("chain pipeline must have a body")

        data = load_config_str(body)
-    schema = load_config("./schemas/chain.yaml")

+    schema = load_config("./schemas/chain.yaml")
    logger.debug("validating chain request: %s against %s", data, schema)
    validate(data, schema)

-    # get defaults from the regular parameters
-    device, params, size = pipeline_from_request(server)
-    output = make_output_name(server, "chain", params, size)
-    job_name = output[0]
-
-    replace_wildcards(params, get_wildcard_data())
+    device, base_params, base_size = pipeline_from_json(
+        server, data=data.get("defaults")
+    )

    # start building the pipeline
    pipeline = ChainPipeline()
    for stage_data in data.get("stages", []):
        stage_class = CHAIN_STAGES[stage_data.get("type")]
-        kwargs = stage_data.get("params", {})
+        kwargs: Dict[str, Any] = stage_data.get("params", {})
        logger.info("request stage: %s, %s", stage_class.__name__, kwargs)

+        # TODO: combine base params with stage params
+        _device, params, size = pipeline_from_json(server, data=kwargs)
+        replace_wildcards(params, get_wildcard_data())
+
+        # remove parsed keys, like model names (which become paths)
+        for pop_key in CHAIN_POP_KEYS:
+            if pop_key in kwargs:
+                kwargs.pop(pop_key)
+
+        if "seed" in kwargs and kwargs["seed"] == -1:
+            kwargs.pop("seed")
+
+        # replace kwargs with parsed versions
+        kwargs["params"] = params
+        kwargs["size"] = size
+
+        border = build_border(kwargs)
+        kwargs["border"] = border
+
+        upscale = build_upscale(kwargs)
+        kwargs["upscale"] = upscale
+
+        # prepare the stage metadata
        stage = StageParams(
            stage_data.get("name", stage_class.__name__),
-            tile_size=get_size(kwargs.get("tile_size")),
+            tile_size=get_size(kwargs.get("tiles")),
            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
        )

-        if "border" in kwargs:
-            border = Border.even(int(kwargs.get("border")))
-            kwargs["border"] = border
-
-        if "upscale" in kwargs:
-            upscale = UpscaleParams(kwargs.get("upscale"))
-            kwargs["upscale"] = upscale
-
        # load any images related to this stage
        stage_source_name = "source:%s" % (stage.name)
        stage_mask_name = "mask:%s" % (stage.name)

@@ -436,20 +462,25 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):

    logger.info("running chain pipeline with %s stages", len(pipeline.stages))

+    output = make_output_name(
+        server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
+    )
+    job_name = output[0]
+
    # build and run chain pipeline
-    empty_source = Image.new("RGB", (size.width, size.height))
    pool.submit(
        job_name,
        pipeline,
        server,
-        params,
-        empty_source,
-        output=output[0],
-        size=size,
+        base_params,
+        StageResult.empty(),
+        output=output,
+        size=base_size,
        needs_device=device,
    )

-    return jsonify(json_params(output, params, size))
+    step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
+    return jsonify(json_params(output, step_params, base_size))


def blend(server: ServerContext, pool: DevicePoolExecutor):
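Since the chain endpoint now parses per-stage parameters itself, a request body carries shared defaults plus stage-level params. A hedged sketch of what such a body might look like, following only the keys read in the diff above (`defaults`, `stages`, `type`, `name`, `params`, `seed`, `tiles`, `outscale`); the stage type string is illustrative:

```python
# hypothetical chain request body, matching the keys read by chain():
# data.get("defaults"), data.get("stages"), stage_data.get("type"/"name"/"params")
chain_request = {
    "defaults": {
        "prompt": "a painting of a lighthouse",
        "steps": 30,
    },
    "stages": [
        {
            "name": "base",
            "type": "source-txt2img",  # illustrative stage type key
            "params": {
                "seed": -1,      # -1 is popped so the parsed params keep their seed
                "tiles": 512,    # read via get_size(kwargs.get("tiles"))
                "outscale": 1,
            },
        },
    ],
}
```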
@@ -471,7 +502,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
        sources.append(source)

    device, params, size = pipeline_from_request(server)
-    upscale = upscale_from_request()
+    upscale = build_upscale()

    output = make_output_name(server, "upscale", params, size)
    job_name = output[0]
@@ -5,18 +5,44 @@ from typing import List, Optional

import torch

-from ..utils import get_boolean
+from ..utils import get_boolean, get_list
from .model_cache import ModelCache

logger = getLogger(__name__)

+DEFAULT_ANY_PLATFORM = True
DEFAULT_CACHE_LIMIT = 5
DEFAULT_JOB_LIMIT = 10
+DEFAULT_IMAGE_FORMAT = "png"
DEFAULT_SERVER_VERSION = "v0.10.0"
+DEFAULT_SHOW_PROGRESS = True
DEFAULT_WORKER_RETRIES = 3


class ServerContext:
+    bundle_path: str
+    model_path: str
+    output_path: str
+    params_path: str
+    cors_origin: str
+    any_platform: bool
+    block_platforms: List[str]
+    default_platform: str
+    image_format: str
+    cache_limit: int
+    cache_path: str
+    show_progress: bool
+    optimizations: List[str]
+    extra_models: List[str]
+    job_limit: int
+    memory_limit: int
+    admin_token: str
+    server_version: str
+    worker_retries: int
+    feature_flags: List[str]
+    plugins: List[str]
+    debug: bool
+
    def __init__(
        self,
        bundle_path: str = ".",

@@ -24,19 +50,23 @@ class ServerContext:
        output_path: str = ".",
        params_path: str = ".",
        cors_origin: str = "*",
-        any_platform: bool = True,
+        any_platform: bool = DEFAULT_ANY_PLATFORM,
        block_platforms: Optional[List[str]] = None,
        default_platform: Optional[str] = None,
        image_format: str = DEFAULT_IMAGE_FORMAT,
        cache_limit: int = DEFAULT_CACHE_LIMIT,
        cache_path: Optional[str] = None,
-        show_progress: bool = True,
+        show_progress: bool = DEFAULT_SHOW_PROGRESS,
        optimizations: Optional[List[str]] = None,
        extra_models: Optional[List[str]] = None,
        job_limit: int = DEFAULT_JOB_LIMIT,
        memory_limit: Optional[int] = None,
        admin_token: Optional[str] = None,
        server_version: Optional[str] = DEFAULT_SERVER_VERSION,
        worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
        feature_flags: Optional[List[str]] = None,
+        plugins: Optional[List[str]] = None,
+        debug: bool = False,
    ) -> None:
        self.bundle_path = bundle_path
        self.model_path = model_path

@@ -56,6 +86,10 @@ class ServerContext:
        self.memory_limit = memory_limit
        self.admin_token = admin_token or token_urlsafe()
        self.server_version = server_version
        self.worker_retries = worker_retries
+        self.feature_flags = feature_flags or []
+        self.plugins = plugins or []
+        self.debug = debug

        self.cache = ModelCache(self.cache_limit)

@@ -72,26 +106,41 @@ class ServerContext:
            model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
            output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
            params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
            # others
-            cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
-            any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
-            block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
+            cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
+            any_platform=get_boolean(
+                environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
+            ),
+            block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
            default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
-            image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
+            image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
            cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
-            show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
-            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
-            extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","),
+            show_progress=get_boolean(
+                environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
+            ),
+            optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
+            extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
            job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
            memory_limit=memory_limit,
            admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
            server_version=environ.get(
                "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
            ),
            worker_retries=int(
                environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
            ),
+            feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
+            plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
+            debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
        )

+    def has_feature(self, flag: str) -> bool:
+        return flag in self.feature_flags
+
+    def has_optimization(self, opt: str) -> bool:
+        return opt in self.optimizations
+
    def torch_dtype(self):
-        if "torch-fp16" in self.optimizations:
+        if self.has_optimization("torch-fp16"):
            return torch.float16
        else:
            return torch.float32
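`get_list` replaces the ad-hoc `.split(",")` parsing, which returns `[""]` for an unset variable. A sketch of what such a helper presumably does; the real implementation lives in `onnx_web.utils` and is not shown in this diff, so treat the signature and behavior here as assumptions:

```python
# hedged sketch of a get_list-style helper; the real onnx_web.utils.get_list
# is not part of this diff, so this is an assumption about its behavior
from typing import Dict, List


def get_list(args: Dict[str, str], key: str, default: str = "") -> List[str]:
    value = args.get(key, default)
    # avoid the [''] result that "".split(",") produces
    return [part.strip() for part in value.split(",") if part.strip()]


print("".split(","))                           # [''] - the old bug
print(get_list({}, "ONNX_WEB_OPTIMIZATIONS"))  # []
print(get_list({"X": "a, b"}, "X"))            # ['a', 'b']
```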
@@ -134,25 +134,44 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:

def apply_patch_basicsr(server: ServerContext):
    logger.debug("patching BasicSR module")
+    try:
        import basicsr.utils.download_util

        basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
-        basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
+        basicsr.utils.download_util.load_file_from_url = partial(
+            patch_cache_path, server
+        )
+    except ImportError:
+        logger.info("unable to import basicsr utils for patching")
+    except AttributeError:
+        logger.warning("unable to patch basicsr utils")


def apply_patch_codeformer(server: ServerContext):
    logger.debug("patching CodeFormer module")
+    try:
        import codeformer.facelib.utils.misc

        codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
-        codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
+        codeformer.facelib.utils.misc.load_file_from_url = partial(
+            patch_cache_path, server
+        )
+    except ImportError:
+        logger.info("unable to import codeformer utils for patching")
+    except AttributeError:
+        logger.warning("unable to patch codeformer utils")


def apply_patch_facexlib(server: ServerContext):
    logger.debug("patching Facexlib module")
+    try:
+        import facexlib.utils
+
+        facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
+    except ImportError:
+        logger.info("unable to import facexlib for patching")
+    except AttributeError:
+        logger.warning("unable to patch facexlib utils")


def apply_patches(server: ServerContext):

@@ -96,6 +96,7 @@ wildcard_data: Dict[str, List[str]] = defaultdict(list)

# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {}
+extra_tokens: Dict[str, List[str]] = {}


def get_config_params():

@@ -160,9 +161,10 @@ def load_extras(server: ServerContext):
    """
    global extra_hashes
    global extra_strings
+    global extra_tokens

-    labels = {}
-    strings = {}
+    labels: Dict[str, str] = {}
+    strings: Dict[str, Any] = {}

    extra_schema = load_config("./schemas/extras.yaml")

@@ -210,6 +212,14 @@ def load_extras(server: ServerContext):
                else:
                    labels[model_name] = model["label"]

+                if "tokens" in model:
+                    logger.debug(
+                        "collecting tokens for model %s from %s",
+                        model_name,
+                        file,
+                    )
+                    extra_tokens[model_name] = model["tokens"]
+
                if "inversions" in model:
                    for inversion in model["inversions"]:
                        if "label" in inversion:

@@ -353,7 +363,10 @@ def load_models(server: ServerContext) -> None:
    )
    logger.debug("loaded Textual Inversion models from disk: %s", inversion_models)
    network_models.extend(
-        [NetworkModel(model, "inversion") for model in inversion_models]
+        [
+            NetworkModel(model, "inversion", tokens=extra_tokens.get(model, []))
+            for model in inversion_models
+        ]
    )

    lora_models = list_model_globs(

@@ -364,7 +377,12 @@ def load_models(server: ServerContext) -> None:
        base_path=path.join(server.model_path, "lora"),
    )
    logger.debug("loaded LoRA models from disk: %s", lora_models)
-    network_models.extend([NetworkModel(model, "lora") for model in lora_models])
+    network_models.extend(
+        [
+            NetworkModel(model, "lora", tokens=extra_tokens.get(model, []))
+            for model in lora_models
+        ]
+    )


def load_params(server: ServerContext) -> None:

@@ -397,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
    ):
        if potential == "cuda" or potential == "rocm":
            for i in range(torch.cuda.device_count()):
-                options = {
+                options: Dict[str, Union[int, str]] = {
                    "device_id": i,
                }

@@ -51,7 +51,7 @@ class ModelCache:
            return

        for i in range(len(cache)):
-            t, k, v = cache[i]
+            t, k, _v = cache[i]
            if tag == t and key != k:
                logger.debug("updating model cache: %s %s", tag, key)
                cache[i] = (tag, key, value)
@@ -1,10 +1,10 @@
from logging import getLogger
-from typing import Tuple
+from typing import Dict, Optional, Tuple

import numpy as np
from flask import request

from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
+from ..diffusers.utils import random_seed
from ..params import (
    Border,
    DeviceParams,

@@ -34,143 +34,122 @@ from .utils import get_model_path

logger = getLogger(__name__)


-def pipeline_from_request(
-    server: ServerContext,
-    default_pipeline: str = "txt2img",
-) -> Tuple[DeviceParams, ImageParams, Size]:
-    user = request.remote_addr
-
+def build_device(
+    _server: ServerContext,
+    data: Dict[str, str],
+) -> Optional[DeviceParams]:
    # platform stuff
    device = None
-    device_name = request.args.get("platform")
+    device_name = data.get("platform")

    if device_name is not None and device_name != "any":
        for platform in get_available_platforms():
            if platform.device == device_name:
                device = platform

+    return device
+
+
+def build_params(
+    server: ServerContext,
+    default_pipeline: str,
+    data: Dict[str, str],
+) -> ImageParams:
    # diffusion model
-    model = get_not_empty(request.args, "model", get_config_value("model"))
+    model = get_not_empty(data, "model", get_config_value("model"))
    model_path = get_model_path(server, model)

    control = None
-    control_name = request.args.get("control")
+    control_name = data.get("control")
    for network in get_network_models():
        if network.name == control_name:
            control = network

    # pipeline stuff
    pipeline = get_from_list(
-        request.args, "pipeline", get_available_pipelines(), default_pipeline
+        data, "pipeline", get_available_pipelines(), default_pipeline
    )
-    scheduler = get_from_list(request.args, "scheduler", get_pipeline_schedulers())
+    scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())

    if scheduler is None:
        scheduler = get_config_value("scheduler")

    # prompt does not come from config
-    prompt = request.args.get("prompt", "")
-    negative_prompt = request.args.get("negativePrompt", None)
+    prompt = data.get("prompt", "")
+    negative_prompt = data.get("negativePrompt", None)

    if negative_prompt is not None and negative_prompt.strip() == "":
        negative_prompt = None

    # image params
    batch = get_and_clamp_int(
-        request.args,
+        data,
        "batch",
        get_config_value("batch"),
        get_config_value("batch", "max"),
        get_config_value("batch", "min"),
    )
    cfg = get_and_clamp_float(
-        request.args,
+        data,
        "cfg",
        get_config_value("cfg"),
        get_config_value("cfg", "max"),
        get_config_value("cfg", "min"),
    )
    eta = get_and_clamp_float(
-        request.args,
+        data,
        "eta",
        get_config_value("eta"),
        get_config_value("eta", "max"),
        get_config_value("eta", "min"),
    )
    loopback = get_and_clamp_int(
-        request.args,
+        data,
        "loopback",
        get_config_value("loopback"),
        get_config_value("loopback", "max"),
        get_config_value("loopback", "min"),
    )
    steps = get_and_clamp_int(
-        request.args,
+        data,
        "steps",
        get_config_value("steps"),
        get_config_value("steps", "max"),
        get_config_value("steps", "min"),
    )
-    height = get_and_clamp_int(
-        request.args,
-        "height",
-        get_config_value("height"),
-        get_config_value("height", "max"),
-        get_config_value("height", "min"),
+    tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae"))
+    unet_overlap = get_and_clamp_float(
+        data,
+        "unet_overlap",
+        get_config_value("unet_overlap"),
+        get_config_value("unet_overlap", "max"),
+        get_config_value("unet_overlap", "min"),
    )
-    width = get_and_clamp_int(
-        request.args,
-        "width",
-        get_config_value("width"),
-        get_config_value("width", "max"),
-        get_config_value("width", "min"),
+    unet_tile = get_and_clamp_int(
+        data,
+        "unet_tile",
+        get_config_value("unet_tile"),
+        get_config_value("unet_tile", "max"),
+        get_config_value("unet_tile", "min"),
    )
-    tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
-    tiles = get_and_clamp_int(
-        request.args,
-        "tiles",
-        get_config_value("tiles"),
-        get_config_value("tiles", "max"),
-        get_config_value("tiles", "min"),
+    vae_overlap = get_and_clamp_float(
+        data,
+        "vae_overlap",
+        get_config_value("vae_overlap"),
+        get_config_value("vae_overlap", "max"),
+        get_config_value("vae_overlap", "min"),
    )
-    overlap = get_and_clamp_float(
-        request.args,
-        "overlap",
-        get_config_value("overlap"),
-        get_config_value("overlap", "max"),
-        get_config_value("overlap", "min"),
-    )
-    stride = get_and_clamp_int(
-        request.args,
-        "stride",
-        get_config_value("stride"),
-        get_config_value("stride", "max"),
-        get_config_value("stride", "min"),
+    vae_tile = get_and_clamp_int(
+        data,
+        "vae_tile",
+        get_config_value("vae_tile"),
+        get_config_value("vae_tile", "max"),
+        get_config_value("vae_tile", "min"),
    )

-    if stride > tiles:
-        logger.info("limiting stride to tile size, %s > %s", stride, tiles)
-        stride = tiles
-
-    seed = int(request.args.get("seed", -1))
+    seed = int(data.get("seed", -1))
    if seed == -1:
        # this one can safely use np.random because it produces a single value
-        seed = np.random.randint(np.iinfo(np.int32).max)
-
-    logger.info(
-        "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
-        user,
-        steps,
-        scheduler,
-        model_path,
-        pipeline,
-        device or "any device",
-        width,
-        height,
-        cfg,
-        seed,
-        prompt,
-    )
+        seed = random_seed()

    params = ImageParams(
        model_path,

@@ -186,38 +165,65 @@ def build_params(
        control=control,
        loopback=loopback,
        tiled_vae=tiled_vae,
-        tiles=tiles,
-        overlap=overlap,
-        stride=stride,
+        unet_overlap=unet_overlap,
+        unet_tile=unet_tile,
+        vae_overlap=vae_overlap,
+        vae_tile=vae_tile,
    )
-    size = Size(width, height)
-    return (device, params, size)

+    return params
+

-def border_from_request() -> Border:
+def build_size(
+    _server: ServerContext,
+    data: Dict[str, str],
+) -> Size:
+    height = get_and_clamp_int(
+        data,
+        "height",
+        get_config_value("height"),
+        get_config_value("height", "max"),
+        get_config_value("height", "min"),
+    )
+    width = get_and_clamp_int(
+        data,
+        "width",
+        get_config_value("width"),
+        get_config_value("width", "max"),
+        get_config_value("width", "min"),
+    )
+    return Size(width, height)
+
+
+def build_border(
+    data: Dict[str, str] = None,
+) -> Border:
+    if data is None:
+        data = request.args
+
    left = get_and_clamp_int(
-        request.args,
+        data,
        "left",
        get_config_value("left"),
        get_config_value("left", "max"),
        get_config_value("left", "min"),
    )
    right = get_and_clamp_int(
-        request.args,
+        data,
        "right",
        get_config_value("right"),
        get_config_value("right", "max"),
        get_config_value("right", "min"),
    )
    top = get_and_clamp_int(
-        request.args,
+        data,
        "top",
        get_config_value("top"),
        get_config_value("top", "max"),
        get_config_value("top", "min"),
    )
    bottom = get_and_clamp_int(
-        request.args,
+        data,
        "bottom",
        get_config_value("bottom"),
        get_config_value("bottom", "max"),

@@ -227,46 +233,51 @@ def build_border(
    return Border(left, right, top, bottom)


-def upscale_from_request() -> UpscaleParams:
+def build_upscale(
+    data: Dict[str, str] = None,
+) -> UpscaleParams:
+    if data is None:
+        data = request.args
+
    denoise = get_and_clamp_float(
-        request.args,
+        data,
        "denoise",
        get_config_value("denoise"),
        get_config_value("denoise", "max"),
        get_config_value("denoise", "min"),
    )
    scale = get_and_clamp_int(
-        request.args,
+        data,
        "scale",
        get_config_value("scale"),
        get_config_value("scale", "max"),
        get_config_value("scale", "min"),
    )
    outscale = get_and_clamp_int(
-        request.args,
+        data,
        "outscale",
        get_config_value("outscale"),
        get_config_value("outscale", "max"),
        get_config_value("outscale", "min"),
    )
-    upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
-    correction = get_from_list(request.args, "correction", get_correction_models())
-    faces = get_not_empty(request.args, "faces", "false") == "true"
+    upscaling = get_from_list(data, "upscaling", get_upscaling_models())
+    correction = get_from_list(data, "correction", get_correction_models())
+    faces = get_not_empty(data, "faces", "false") == "true"
    face_outscale = get_and_clamp_int(
-        request.args,
+        data,
        "faceOutscale",
        get_config_value("faceOutscale"),
        get_config_value("faceOutscale", "max"),
        get_config_value("faceOutscale", "min"),
    )
    face_strength = get_and_clamp_float(
-        request.args,
+        data,
        "faceStrength",
        get_config_value("faceStrength"),
        get_config_value("faceStrength", "max"),
        get_config_value("faceStrength", "min"),
    )
-    upscale_order = request.args.get("upscaleOrder", "correction-first")
+    upscale_order = data.get("upscaleOrder", "correction-first")

    return UpscaleParams(
        upscaling,
|
@ -282,37 +293,43 @@ def upscale_from_request() -> UpscaleParams:
|
|||
)
|
||||
|
||||
|
||||
def highres_from_request() -> HighresParams:
|
||||
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
|
||||
def build_highres(
|
||||
data: Dict[str, str] = None,
|
||||
) -> HighresParams:
|
||||
if data is None:
|
||||
data = request.args
|
||||
|
||||
enabled = get_boolean(data, "highres", get_config_value("highres"))
|
||||
iterations = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"highresIterations",
|
||||
get_config_value("highresIterations"),
|
||||
get_config_value("highresIterations", "max"),
|
||||
get_config_value("highresIterations", "min"),
|
||||
)
|
||||
method = get_from_list(request.args, "highresMethod", get_highres_methods())
|
||||
method = get_from_list(data, "highresMethod", get_highres_methods())
|
||||
scale = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"highresScale",
|
||||
get_config_value("highresScale"),
|
||||
get_config_value("highresScale", "max"),
|
||||
get_config_value("highresScale", "min"),
|
||||
)
|
||||
steps = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"highresSteps",
|
||||
get_config_value("highresSteps"),
|
||||
get_config_value("highresSteps", "max"),
|
||||
get_config_value("highresSteps", "min"),
|
||||
)
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
data,
|
||||
"highresStrength",
|
||||
get_config_value("highresStrength"),
|
||||
get_config_value("highresStrength", "max"),
|
||||
get_config_value("highresStrength", "min"),
|
||||
)
|
||||
|
||||
return HighresParams(
|
||||
enabled,
|
||||
scale,
|
||||
|
@ -321,3 +338,50 @@ def highres_from_request() -> HighresParams:
|
|||
method=method,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
|
||||
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
|
||||
|
||||
|
||||
def pipeline_from_json(
|
||||
server: ServerContext,
|
||||
data: Dict[str, str],
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> PipelineParams:
|
||||
"""
|
||||
Like pipeline_from_request but expects a nested structure.
|
||||
"""
|
||||
|
||||
device = build_device(server, data.get("device", data))
|
||||
params = build_params(server, default_pipeline, data.get("params", data))
|
||||
size = build_size(server, data.get("params", data))
|
||||
|
||||
return (device, params, size)
|
||||
|
||||
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> PipelineParams:
|
||||
user = request.remote_addr
|
||||
|
||||
device = build_device(server, request.args)
|
||||
params = build_params(server, default_pipeline, request.args)
|
||||
size = build_size(server, request.args)
|
||||
|
||||
logger.info(
|
||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||
user,
|
||||
params.steps,
|
||||
params.scheduler,
|
||||
params.model,
|
||||
params.pipeline,
|
||||
device or "any device",
|
||||
size.width,
|
||||
size.height,
|
||||
params.cfg,
|
||||
params.seed,
|
||||
params.prompt,
|
||||
)
|
||||
|
||||
return (device, params, size)
|
||||
|
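The JSON route mirrors the query-string route: pipeline_from_json pulls optional device and params sub-objects out of the body and falls back to the whole dict, so flat bodies keep working. A minimal sketch of a nested body; the keys inside "params" are assumptions based on the query parameters above, and server is an already-constructed ServerContext:

# hypothetical nested body; a flat dict also works because each
# data.get(...) call falls back to the full dict
body = {
    "device": {"platform": "any"},  # assumed shape for the device selector
    "params": {
        "prompt": "a giant muffin",
        "scheduler": "ddim",
        "steps": 25,
        "width": 512,
        "height": 512,
        "unet_tile": 512,  # the new tiling params replace tiles/overlap/stride
        "vae_tile": 512,
    },
}

device, params, size = pipeline_from_json(server, body)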
@@ -0,0 +1,61 @@
+from importlib import import_module
+from logging import getLogger
+from typing import Any, Callable, Dict
+
+from onnx_web.chain.stages import add_stage
+from onnx_web.diffusers.load import add_pipeline
+from onnx_web.server.context import ServerContext
+
+logger = getLogger(__name__)
+
+
+class PluginExports:
+    pipelines: Dict[str, Any]
+    stages: Dict[str, Any]
+
+    def __init__(self, pipelines=None, stages=None) -> None:
+        self.pipelines = pipelines or {}
+        self.stages = stages or {}
+
+
+PluginModule = Callable[[ServerContext], PluginExports]
+
+
+def load_plugins(server: ServerContext) -> PluginExports:
+    combined_exports = PluginExports()
+
+    for plugin in server.plugins:
+        logger.info("loading plugin module: %s", plugin)
+        try:
+            module: PluginModule = import_module(plugin)
+            exports = module(server)
+
+            for name, pipeline in exports.pipelines.items():
+                if name in combined_exports.pipelines:
+                    logger.warning(
+                        "multiple plugins exported a pipeline named %s", name
+                    )
+                else:
+                    combined_exports.pipelines[name] = pipeline
+
+            for name, stage in exports.stages.items():
+                if name in combined_exports.stages:
+                    logger.warning("multiple plugins exported a stage named %s", name)
+                else:
+                    combined_exports.stages[name] = stage
+        except Exception:
+            logger.exception("error importing plugin")
+
+    return combined_exports
+
+
+def register_plugins(exports: PluginExports) -> bool:
+    success = True
+
+    for name, pipeline in exports.pipelines.items():
+        success = success and add_pipeline(name, pipeline)
+
+    for name, stage in exports.stages.items():
+        success = success and add_stage(name, stage)
+
+    return success
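For reference, a sketch of a plugin module that this loader could consume; the PluginExports import path and the stage shape are assumptions, and note that load_plugins calls the imported module object directly, so the exact entry-point convention is an assumption as well:

# hypothetical plugin module, listed in server.plugins as "my_plugin"
from onnx_web.server.context import ServerContext
from onnx_web.server.plugin import PluginExports  # assumed import path


class EchoStage:
    # minimal stage sketch: returns its sources unchanged
    def run(self, worker, server, stage, params, sources, **kwargs):
        return sources


def plugin(server: ServerContext) -> PluginExports:
    # export one stage under the name "echo"
    return PluginExports(stages={"echo": EchoStage})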
@@ -18,6 +18,11 @@ logger = getLogger(__name__)
 SAFE_CHARS = "._-"


+def split_list(val: str) -> List[str]:
+    parts = [part.strip() for part in val.split(",")]
+    return [part for part in parts if len(part) > 0]
+
+
 def base_join(base: str, tail: str) -> str:
     tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
     return path.join(base, tail_path)
@@ -28,7 +33,16 @@ def is_debug() -> bool:


 def get_boolean(args: Any, key: str, default_value: bool) -> bool:
-    return args.get(key, str(default_value)).lower() in ("1", "t", "true", "y", "yes")
+    val = args.get(key, str(default_value))
+
+    if isinstance(val, bool):
+        return val
+
+    return val.lower() in ("1", "t", "true", "y", "yes")
+
+
+def get_list(args: Any, key: str, default="") -> List[str]:
+    return split_list(args.get(key, default))


 def get_and_clamp_float(
@@ -61,13 +75,13 @@ def get_from_list(


 def get_from_map(
-    args: Any, key: str, values: Dict[str, TElem], default: TElem
+    args: Any, key: str, values: Dict[str, TElem], default_key: str
 ) -> TElem:
-    selected = args.get(key, default)
+    selected = args.get(key, default_key)
     if selected in values:
         return values[selected]
     else:
-        return values[default]
+        return values[default_key]


 def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
@@ -195,6 +209,8 @@ def load_config(file: str) -> Dict:
         return load_yaml(file)
     elif ext in [".json"]:
         return load_json(file)
+    else:
+        raise ValueError("unknown config file extension")


 def load_config_str(raw: str) -> Dict:
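The get_from_map change is subtle: the default is now a key into the map rather than a fallback value, so an unknown selection resolves through values[default_key], and the default must itself be a valid key. get_boolean likewise now tolerates real booleans from JSON bodies as well as query strings. A quick sketch, assuming these helpers are imported from this utils module:

orders = {"spiral": 1, "grid": 2}

# an unknown selection falls back by key, not by raw value
value = get_from_map({"tileOrder": "bogus"}, "tileOrder", orders, "spiral")
assert value == 1

# JSON bodies can carry a real bool, query strings carry "true"/"false"
assert get_boolean({"highres": True}, "highres", False) is True
assert get_boolean({"highres": "true"}, "highres", False) is True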
@@ -25,6 +25,7 @@ class WorkerContext:
     idle: "Value[bool]"
     timeout: float
     retries: int
+    initial_retries: int

     def __init__(
         self,
@@ -36,6 +37,8 @@ class WorkerContext:
         progress: "Queue[ProgressCommand]",
         active_pid: "Value[int]",
         idle: "Value[bool]",
+        retries: int,
+        timeout: float,
     ):
         self.job = None
         self.name = name
@@ -47,12 +50,13 @@ class WorkerContext:
         self.active_pid = active_pid
         self.last_progress = None
         self.idle = idle
-        self.timeout = 1.0
-        self.retries = 3  # TODO: get from env
+        self.initial_retries = retries
+        self.retries = retries
+        self.timeout = timeout

     def start(self, job: str) -> None:
         self.job = job
-        self.retries = 3
+        self.retries = self.initial_retries
         self.set_cancel(cancel=False)
         self.set_idle(idle=False)

@@ -82,7 +86,7 @@ class WorkerContext:
         return 0

     def get_progress_callback(self) -> ProgressCallback:
-        from ..chain.base import ChainProgress
+        from ..chain.pipeline import ChainProgress

         def on_progress(step: int, timestep: int, latents: Any):
             on_progress.step = step
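The retry counter now resets from initial_retries at the start of each job, instead of the hard-coded 3. A sketch of the lifecycle, with the queue and Value arguments elided as None placeholders in the style of the tests later in this diff:

context = WorkerContext(
    "test", device, None, None, None, None, None, None, 3, 0.1
)
context.retries -= 1     # simulate a failed attempt consuming a retry
context.start("job-1")   # start() restores retries from initial_retries
assert context.retries == 3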
@@ -86,15 +86,15 @@ class DevicePoolExecutor:
         self.logs = Queue(self.max_pending_per_worker)
         self.rlock = Lock()

-    def start(self) -> None:
+    def start(self, *args) -> None:
         self.create_health_worker()
         self.create_logger_worker()
         self.create_progress_worker()

         for device in self.devices:
-            self.create_device_worker(device)
+            self.create_device_worker(device, *args)

-    def create_device_worker(self, device: DeviceParams) -> None:
+    def create_device_worker(self, device: DeviceParams, *args) -> None:
         name = device.device

         # always recreate queues
@@ -124,15 +124,17 @@ class DevicePoolExecutor:
             pending=self.pending[name],
             active_pid=current,
             idle=self.worker_idle[name],
+            retries=self.server.worker_retries,
+            timeout=self.progress_interval,
         )
         self.context[name] = context

         worker = Process(
             name=f"onnx-web worker: {name}",
             target=worker_main,
-            args=(context, self.server),
+            args=(context, self.server, *args),
+            daemon=True,
         )
-        worker.daemon = True
         self.workers[name] = worker

         logger.debug("starting worker for device %s", device)
@@ -27,10 +27,14 @@ MEMORY_ERRORS = [
 ]


-def worker_main(worker: WorkerContext, server: ServerContext):
-    apply_patches(server)
+def worker_main(
+    worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
+):
     setproctitle("onnx-web worker: %s" % (worker.device.device))

+    if patch:
+        apply_patches(server)
+
     logger.trace(
         "checking in from worker with providers: %s", get_available_providers()
     )
@@ -46,7 +50,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
                 getpid(),
                 worker.get_active(),
             )
-            exit(EXIT_REPLACED)
+            return exit(EXIT_REPLACED)

         # wait briefly for the next job
         job = worker.pending.get(timeout=worker.timeout)
@@ -69,15 +73,15 @@ def worker_main(worker: WorkerContext, server: ServerContext):
     except KeyboardInterrupt:
         logger.debug("worker got keyboard interrupt")
         worker.fail()
-        exit(EXIT_INTERRUPT)
+        return exit(EXIT_INTERRUPT)
     except RetryException:
         logger.exception("retry error in worker, exiting")
         worker.fail()
-        exit(EXIT_ERROR)
+        return exit(EXIT_ERROR)
     except ValueError:
         logger.exception("value error in worker, exiting")
         worker.fail()
-        exit(EXIT_ERROR)
+        return exit(EXIT_ERROR)
     except Exception as e:
         e_str = str(e)
         # restart the worker on memory errors
@@ -85,7 +89,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
         if e_mem in e_str:
             logger.error("detected out-of-memory error, exiting: %s", e)
             worker.fail()
-            exit(EXIT_MEMORY)
+            return exit(EXIT_MEMORY)

         # carry on for other errors
         logger.exception(
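Returning the injected exit's result lets the loop be exercised in-process; a sketch of a test capturing the exit code instead of killing the process (worker and server are assumed fixtures, and EXIT_REPLACED is assumed importable from the same module):

codes = []

def fake_exit(code):
    codes.append(code)

# patch=False skips apply_patches; extra *args are forwarded by the pool
worker_main(worker, server, exit=fake_exit, patch=False)
assert codes == [EXIT_REPLACED]  # hypothetical: another worker took the device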
@@ -98,7 +98,7 @@
     "highresSteps": {
       "default": 0,
       "min": 1,
-      "max": 200,
+      "max": 500,
       "step": 1
     },
     "highresStrength": {
@@ -141,12 +141,6 @@
       "max": 4,
       "step": 1
     },
-    "overlap": {
-      "default": 0.25,
-      "min": 0.0,
-      "max": 0.9,
-      "step": 0.01
-    },
     "pipeline": {
       "default": "",
       "keys": [
@@ -188,7 +182,7 @@
     "steps": {
       "default": 25,
       "min": 1,
-      "max": 200,
+      "max": 300,
       "step": 1
     },
     "strength": {
@@ -197,21 +191,9 @@
       "max": 1,
       "step": 0.01
     },
-    "stride": {
-      "default": 128,
-      "min": 64,
-      "max": 512,
-      "step": 64
-    },
-    "tiledVAE": {
+    "tiled_vae": {
       "default": false
     },
-    "tiles": {
-      "default": 512,
-      "min": 128,
-      "max": 2048,
-      "step": 128
-    },
     "tileOrder": {
       "default": "spiral",
       "keys": [
@@ -225,6 +207,18 @@
       "max": 1024,
       "step": 8
     },
+    "unet_overlap": {
+      "default": 0.25,
+      "min": 0.0,
+      "max": 0.9,
+      "step": 0.01
+    },
+    "unet_tile": {
+      "default": 512,
+      "min": 128,
+      "max": 2048,
+      "step": 128
+    },
     "upscaleOrder": {
       "default": "correction-first",
       "keys": [
@@ -237,6 +231,18 @@
       "default": "",
       "keys": []
     },
+    "vae_overlap": {
+      "default": 0.25,
+      "min": 0.0,
+      "max": 0.9,
+      "step": 0.01
+    },
+    "vae_tile": {
+      "default": 512,
+      "min": 256,
+      "max": 1024,
+      "step": 128
+    },
     "width": {
       "default": 512,
       "min": 128,
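Clients need the renamed keys: tiledVAE becomes tiled_vae, and the old tiles/overlap/stride settings split into unet_ and vae_ variants with the ranges above. A sketch using requests; the host and route are assumptions:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/txt2img",  # assumed host and route
    params={
        "prompt": "a giant muffin",
        "tiled_vae": "true",   # was tiledVAE
        "unet_tile": 768,      # was tiles; clamped to [128, 2048]
        "unet_overlap": 0.25,  # was overlap
        "vae_tile": 512,       # clamped to [256, 1024]
    },
)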
@@ -9,7 +9,9 @@ skip_glob = ["*/lpw.py"]
 [tool.mypy]
 # ignore_missing_imports = true
 exclude = [
-  "onnx_web.diffusers.lpw_stable_diffusion_onnx"
+  "onnx_web.diffusers.pipelines.controlnet",
+  "onnx_web.diffusers.pipelines.lpw",
+  "onnx_web.diffusers.pipelines.pix2pix"
 ]

 [[tool.mypy.overrides]]
@@ -27,8 +29,10 @@ module = [
   "compel",
   "controlnet_aux",
   "cv2",
+  "debugpy",
   "diffusers",
   "diffusers.configuration_utils",
+  "diffusers.image_processor",
   "diffusers.loaders",
   "diffusers.models.attention_processor",
   "diffusers.models.autoencoder_kl",
@@ -41,9 +45,10 @@ module = [
   "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
   "diffusers.pipelines.onnx_utils",
   "diffusers.pipelines.paint_by_example",
+  "diffusers.pipelines.pipeline_utils",
   "diffusers.pipelines.stable_diffusion",
   "diffusers.pipelines.stable_diffusion.convert_from_ckpt",
-  "diffusers.pipeline_utils",
+  "diffusers.pipelines.stable_diffusion_xl",
   "diffusers.schedulers",
   "diffusers.utils.logging",
   "facexlib.utils",
@@ -56,11 +61,17 @@ module = [
   "mediapipe",
   "onnxruntime",
   "onnxruntime.transformers.float16",
+  "optimum.exporters.onnx",
+  "optimum.onnxruntime",
+  "optimum.onnxruntime.modeling_diffusion",
+  "optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img",
+  "optimum.pipelines.diffusers.pipeline_utils",
   "piexif",
   "piexif.helper",
   "realesrgan",
   "realesrgan.archs.srvgg_arch",
+  "safetensors",
   "scipy",
   "timm.models.layers",
   "transformers",
   "win10toast"
@@ -46,17 +46,31 @@ $defs:
     patternProperties:
       "^[-_A-Za-z]+$":
         oneOf:
          - type: boolean
          - type: number
          - type: string
+          - type: "null"

   request_chain:
     type: array
     items:
       $ref: "#/$defs/request_stage"

+  request_defaults:
+    type: object
+    properties:
+      txt2img:
+        $ref: "#/$defs/image_params"
+      img2img:
+        $ref: "#/$defs/image_params"
+
 type: object
 additionalProperties: False
 required: [stages]
 properties:
+  defaults:
+    $ref: "#/$defs/request_defaults"
+  platform:
+    type: string
   stages:
     $ref: "#/$defs/request_chain"
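With request_defaults in place, a chain request can carry per-pipeline defaults next to its stages and platform. An illustrative document as a Python dict; the stage fields come from request_stage, which is defined elsewhere in this schema, so their exact names here are assumptions:

chain_request = {
    "defaults": {
        "txt2img": {"scheduler": "ddim", "steps": 25},  # image_params shape assumed
    },
    "platform": "any",
    "stages": [
        {
            "name": "generate",            # hypothetical stage fields
            "type": "source-txt2img",
            "params": {"prompt": "a giant muffin"},
        },
    ],
}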
@@ -10,34 +10,53 @@ $defs:
       - type: number
       - type: string

-  lora_network:
+  tensor_format:
+    type: string
+    enum: [bin, ckpt, onnx, pt, pth, safetensors]
+
+  embedding_network:
     type: object
     required: [name, source]
     properties:
-      name:
-        type: string
-      source:
-        type: string
       format:
-        type: string
+        $ref: "#/$defs/tensor_format"
       label:
         type: string
-      weight:
-        type: number
-
-  textual_inversion_network:
-    type: object
-    required: [name, source]
-    properties:
-      name:
-        type: string
-      source:
-        type: string
-      format:
+      model:
         type: string
         enum: [concept, embeddings]
-      label:
+      name:
         type: string
+      source:
+        type: string
       token:
         type: string
+      type:
+        type: string
+        const: inversion # TODO: add embedding
       weight:
         type: number
+
+  lora_network:
+    type: object
+    required: [name, source, type]
+    properties:
+      label:
+        type: string
+      model:
+        type: string
+        enum: [cloneofsimo, sd-scripts]
+      name:
+        type: string
+      source:
+        type: string
+      tokens:
+        type: array
+        items:
+          type: string
+      type:
+        type: string
+        const: lora
+      weight:
+        type: number
@@ -46,8 +65,7 @@ $defs:
     required: [name, source]
     properties:
       format:
-        type: string
-        enum: [bin, ckpt, onnx, pt, pth, safetensors]
+        $ref: "#/$defs/tensor_format"
       half:
         type: boolean
       label:
@@ -85,7 +103,7 @@ $defs:
     inversions:
       type: array
       items:
-        $ref: "#/$defs/textual_inversion_network"
+        $ref: "#/$defs/embedding_network"
     loras:
       type: array
       items:
@@ -100,6 +118,7 @@ $defs:
           panorama,
           pix2pix,
           txt2img,
+          txt2img-sdxl,
           upscale,
         ]
     vae:
@@ -141,31 +160,6 @@ $defs:
     source:
       type: string

-  source_network:
-    type: object
-    required: [name, source, type]
-    properties:
-      format:
-        type: string
-        enum: [bin, ckpt, onnx, pt, pth, safetensors]
-      model:
-        type: string
-        enum: [
-          # inversion
-          concept,
-          embeddings,
-          # lora
-          cloneofsimo,
-          sd-scripts
-        ]
-      name:
-        type: string
-      source:
-        type: string
-      type:
-        type: string
-        enum: [inversion, lora]
-
   translation:
     type: object
     additionalProperties: False
@@ -193,7 +187,9 @@ properties:
   networks:
     type: array
     items:
-      $ref: "#/$defs/source_network"
+      oneOf:
+        - $ref: "#/$defs/lora_network"
+        - $ref: "#/$defs/embedding_network"
   sources:
     type: array
     items:
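Under the split definitions, each entry in networks must match either lora_network or embedding_network, and LoRA entries must now carry type: lora. Illustrative entries as Python dicts; the names and sources are placeholders:

networks = [
    {
        "name": "detail-lora",
        "source": "https://example.com/detail.safetensors",  # placeholder
        "type": "lora",
        "model": "sd-scripts",
        "weight": 0.8,
    },
    {
        "name": "style-embedding",
        "source": "https://example.com/style.pt",  # placeholder
        "model": "concept",
        "token": "style-token",
        "type": "inversion",
        "weight": 1.0,
    },
]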
@@ -0,0 +1,74 @@
+from argparse import ArgumentParser
+from onnx_web.convert.diffusion.lora import blend_loras, buffer_external_data_tensors
+from os import path
+from onnx.checker import check_model
+from onnx.external_data_helper import (
+    convert_model_to_external_data,
+    write_external_data_tensors,
+)
+from onnxruntime import InferenceSession, SessionOptions
+from logging import getLogger
+
+from onnx_web.convert.utils import ConversionContext
+
+logger = getLogger(__name__)
+
+
+if __name__ == "__main__":
+    context = ConversionContext.from_environ()
+    parser = ArgumentParser()
+    parser.add_argument("--base", type=str)
+    parser.add_argument("--dest", type=str)
+    parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
+    parser.add_argument("--lora_models", nargs="+", type=str, default=[])
+    parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
+
+    args = parser.parse_args()
+    logger.info(
+        "merging %s with %s with weights: %s",
+        args.lora_models,
+        args.base,
+        args.lora_weights,
+    )
+
+    default_weight = 1.0 / len(args.lora_models)
+    while len(args.lora_weights) < len(args.lora_models):
+        args.lora_weights.append(default_weight)
+
+    blend_model = blend_loras(
+        context,
+        args.base,
+        list(zip(args.lora_models, args.lora_weights)),
+        args.type,
+    )
+    if args.dest is None or args.dest == "" or args.dest == ":load":
+        # convert to external data and save to memory
+        (bare_model, external_data) = buffer_external_data_tensors(blend_model)
+        logger.info("saved external data for %s nodes", len(external_data))
+
+        external_names, external_values = zip(*external_data)
+        opts = SessionOptions()
+        opts.add_external_initializers(list(external_names), list(external_values))
+        sess = InferenceSession(
+            bare_model.SerializeToString(),
+            sess_options=opts,
+            providers=["CPUExecutionProvider"],
+        )
+        logger.info(
+            "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
+        )
+    else:
+        convert_model_to_external_data(
+            blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
+        )
+        bare_model = write_external_data_tensors(blend_model, args.dest)
+        dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
+
+        with open(dest_file, "w+b") as model_file:
+            model_file.write(bare_model.SerializeToString())
+
+        logger.info("successfully saved blended model: %s", dest_file)
+
+        check_model(dest_file)
+
+        logger.info("checked blended model")
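A sketch of how this script might be invoked; the script filename, model paths, and LoRA filenames are all placeholders:

# assuming the script is saved as blend_lora.py:
#
#   python3 blend_lora.py \
#       --base ../models/stable-diffusion-onnx-v1-5/unet/model.onnx \
#       --dest /tmp/blend --type unet \
#       --lora_models lora-a.safetensors lora-b.safetensors \
#       --lora_weights 0.6 0.4
#
# omitting --dest (or passing ":load") skips writing to disk and only loads
# the blended model into an InferenceSession to verify its inputs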
BIN api/scripts/test-refs/blend-512-muffin-white-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/img2img-panorama-1024x768-pumpkin-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/outpaint-even-256-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/outpaint-horizontal-512-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/outpaint-panorama-horizontal-512-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/outpaint-panorama-vertical-512-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/outpaint-vertical-512-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/txt2img-panorama-1024x768-muffin-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/upscale-sd-x4-2048-muffin-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/upscale-sd-x4-codeformer-2048-muffin-0.png (Stored with Git LFS, binary file not shown)
BIN api/scripts/test-refs/upscale-sd-x4-gfpgan-2048-muffin-0.png (Stored with Git LFS, binary file not shown)
@@ -30,6 +30,10 @@ FAST_TEST = 10
 SLOW_TEST = 25
 VERY_SLOW_TEST = 75

+STRICT_TEST = 1e-4
+LOOSE_TEST = 1e-2
+VERY_LOOSE_TEST = 0.025
+

 def test_path(relpath: str) -> str:
     return path.join(path.dirname(__file__), relpath)
@@ -41,7 +45,7 @@ class TestCase:
         name: str,
         query: str,
         max_attempts: int = FAST_TEST,
-        mse_threshold: float = 1e-4,
+        mse_threshold: float = STRICT_TEST,
         source: Union[Image.Image, List[Image.Image]] = None,
         mask: Image.Image = None,
     ) -> None:
@@ -65,6 +69,7 @@ TEST_DATA = [
     TestCase(
         "txt2img-sd-v1-5-512-muffin-deis",
         "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis",
+        mse_threshold=LOOSE_TEST,
     ),
     TestCase(
         "txt2img-sd-v1-5-512-muffin-dpm",
@@ -73,10 +78,12 @@ TEST_DATA = [
     TestCase(
         "txt2img-sd-v1-5-512-muffin-heun",
         "txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun",
+        mse_threshold=LOOSE_TEST,
     ),
     TestCase(
         "txt2img-sd-v1-5-512-muffin-unipc",
         "txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi",
+        mse_threshold=LOOSE_TEST,
     ),
     TestCase(
         "txt2img-sd-v2-1-512-muffin",
@@ -84,7 +91,7 @@ TEST_DATA = [
     ),
     TestCase(
         "txt2img-sd-v2-1-768-muffin",
-        "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768",
+        "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768",
         max_attempts=SLOW_TEST,
     ),
     TestCase(
@@ -106,7 +113,7 @@ TEST_DATA = [
     ),
     TestCase(
         "img2img-sd-v1-5-256-pumpkin",
-        "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none",
+        "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256",
         source="txt2img-sd-v1-5-256-muffin-0",
     ),
     TestCase(
@@ -130,7 +137,7 @@ TEST_DATA = [
         source="txt2img-sd-v1-5-512-muffin-0",
         mask="mask-black",
         max_attempts=SLOW_TEST,
-        mse_threshold=0.025,
+        mse_threshold=VERY_LOOSE_TEST,
     ),
     TestCase(
         "outpaint-vertical-512",
@@ -141,7 +148,7 @@ TEST_DATA = [
         source="txt2img-sd-v1-5-512-muffin-0",
         mask="mask-black",
         max_attempts=SLOW_TEST,
-        mse_threshold=0.010,
+        mse_threshold=LOOSE_TEST,
     ),
     TestCase(
         "outpaint-horizontal-512",
@@ -152,7 +159,7 @@ TEST_DATA = [
         source="txt2img-sd-v1-5-512-muffin-0",
         mask="mask-black",
         max_attempts=SLOW_TEST,
-        mse_threshold=0.010,
+        mse_threshold=LOOSE_TEST,
     ),
     TestCase(
         "upscale-resrgan-x2-1024-muffin",
@@ -229,7 +236,7 @@ TEST_DATA = [
         source="txt2img-sd-v1-5-512-muffin-0",
         mask="mask-black",
         max_attempts=VERY_SLOW_TEST,
-        mse_threshold=0.025,
+        mse_threshold=VERY_LOOSE_TEST,
     ),
     TestCase(
         "outpaint-panorama-vertical-512",
@@ -240,7 +247,7 @@ TEST_DATA = [
         source="txt2img-sd-v1-5-512-muffin-0",
         mask="mask-black",
         max_attempts=VERY_SLOW_TEST,
-        mse_threshold=0.025,
+        mse_threshold=VERY_LOOSE_TEST,
     ),
     TestCase(
         "outpaint-panorama-horizontal-512",
@@ -251,7 +258,7 @@ TEST_DATA = [
         source="txt2img-sd-v1-5-512-muffin-0",
         mask="mask-black",
         max_attempts=VERY_SLOW_TEST,
-        mse_threshold=0.025,
+        mse_threshold=VERY_LOOSE_TEST,
     ),
     TestCase(
         "upscale-resrgan-x4-codeformer-2048-muffin",
@@ -260,6 +267,7 @@ TEST_DATA = [
             "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
         ),
         source="txt2img-sd-v1-5-512-muffin-0",
+        max_attempts=SLOW_TEST,
     ),
     TestCase(
         "upscale-resrgan-x4-gfpgan-2048-muffin",
@@ -268,6 +276,7 @@ TEST_DATA = [
             "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
         ),
         source="txt2img-sd-v1-5-512-muffin-0",
+        max_attempts=SLOW_TEST,
     ),
     TestCase(
         "upscale-swinir-x4-codeformer-2048-muffin",
@@ -276,6 +285,7 @@ TEST_DATA = [
             "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0"
         ),
         source="txt2img-sd-v1-5-512-muffin-0",
+        max_attempts=SLOW_TEST,
     ),
     TestCase(
         "upscale-swinir-x4-gfpgan-2048-muffin",
@@ -284,6 +294,7 @@ TEST_DATA = [
             "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0"
         ),
         source="txt2img-sd-v1-5-512-muffin-0",
+        max_attempts=SLOW_TEST,
     ),
     TestCase(
         "upscale-sd-x4-codeformer-2048-muffin",
@@ -305,18 +316,18 @@ TEST_DATA = [
     ),
     TestCase(
         "txt2img-panorama-1024x768-muffin",
-        "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true",
+        "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true",
         max_attempts=VERY_SLOW_TEST,
     ),
     TestCase(
         "img2img-panorama-1024x768-pumpkin",
-        "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true",
+        "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true",
         source="txt2img-panorama-1024x768-muffin-0",
         max_attempts=VERY_SLOW_TEST,
     ),
     TestCase(
         "txt2img-sd-v1-5-tall-muffin",
-        "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768",
+        "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768",
     ),
     TestCase(
         "upscale-resrgan-x4-tall-muffin",
@@ -325,6 +336,7 @@ TEST_DATA = [
             "&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0"
         ),
         source="txt2img-sd-v1-5-tall-muffin-0",
+        max_attempts=SLOW_TEST,
     ),
     # TODO: non-square controlnet
 ]
@@ -335,6 +347,39 @@ class TestError(Exception):
         return super().__str__()


+class TestResult:
+    error: Optional[str]
+    mse: Optional[float]
+    name: str
+    passed: bool
+
+    def __init__(self, name: str, error = None, passed = True, mse = None) -> None:
+        self.error = error
+        self.mse = mse
+        self.name = name
+        self.passed = passed
+
+    def __repr__(self) -> str:
+        if self.passed:
+            if self.mse is not None:
+                return f"{self.name} ({self.mse})"
+            else:
+                return self.name
+        else:
+            if self.mse is not None:
+                return f"{self.name}: {self.error} ({self.mse})"
+            else:
+                return f"{self.name}: {self.error}"
+
+    @classmethod
+    def passed(self, name: str, mse = None):
+        return TestResult(name, mse=mse)
+
+    @classmethod
+    def failed(self, name: str, error: str, mse = None):
+        return TestResult(name, error=error, mse=mse, passed=False)
+
+
 def parse_args(args: List[str]):
     parser = ArgumentParser(
         prog="onnx-web release tests",
@@ -441,14 +486,14 @@ def run_test(
     host: str,
     test: TestCase,
     mse_mult: float = 1.0,
-) -> bool:
+) -> TestResult:
     """
     Generate an image, wait for it to be ready, and calculate the MSE from the reference.
     """

     keys = generate_images(host, test)
     if keys is None:
-        raise ValueError("could not generate image")
+        return TestResult.failed(test.name, "could not generate image")

     ready = False
     for attempt in tqdm(range(test.max_attempts)):
@@ -461,13 +506,13 @@ def run_test(
         sleep(6)

     if not ready:
-        raise ValueError("image was not ready in time")
+        return TestResult.failed(test.name, "image was not ready in time")

     results = download_images(host, keys)
-    if results is None:
-        raise ValueError("could not download image")
+    if results is None or len(results) == 0:
+        return TestResult.failed(test.name, "could not download image")

-    passed = True
+    passed = False
     for i in range(len(results)):
         result = results[i]
         result.save(test_path(path.join("test-results", f"{test.name}-{i}.png")))
@@ -476,14 +521,19 @@ def run_test(
         ref = Image.open(ref_name) if path.exists(ref_name) else None

         mse = find_mse(result, ref)
+        threshold = test.mse_threshold * mse_mult

-        if mse < (test.mse_threshold * mse_mult):
-            logger.info("MSE within threshold: %.5f < %.5f", mse, test.mse_threshold)
+        if mse < threshold:
+            logger.info("MSE within threshold: %.5f < %.5f", mse, threshold)
+            passed = True
         else:
-            logger.warning("MSE above threshold: %.5f > %.5f", mse, test.mse_threshold)
-            passed = False
+            logger.warning("MSE above threshold: %.5f > %.5f", mse, threshold)
+            return TestResult.failed(test.name, error="MSE above threshold", mse=mse)

-    return passed
+    if passed:
+        return TestResult.passed(test.name)
+    else:
+        return TestResult.failed(test.name, "no images tested")


 def main():
@@ -504,24 +554,26 @@ def main():
     passed = []
     failed = []
     for test in tests:
-        test_passed = False
+        result = None
+
         for _i in range(3):
             try:
                 logger.info("starting test: %s", test.name)
-                if run_test(args.host, test, mse_mult=args.mse):
+                result = run_test(args.host, test, mse_mult=args.mse)
+                if result.passed:
                     logger.info("test passed: %s", test.name)
-                    test_passed = True
                     break
                 else:
                     logger.warning("test failed: %s", test.name)
             except Exception:
                 logger.exception("error running test for %s", test.name)
+                result = TestResult.failed(test.name, "TODO: exception message")

-        if test_passed:
-            passed.append(test.name)
-        else:
-            failed.append(test.name)
+        if result is not None:
+            if result.passed:
+                passed.append(result)
+            else:
+                failed.append(result)

     logger.info("%s of %s tests passed", len(passed), len(tests))
     failed = list(set(failed))
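Since run_test now returns a TestResult, the summary can report names, errors, and MSE values directly through __repr__; a sketch of how the failure list might be printed:

# failed holds TestResult objects, so repr carries the details,
# e.g. "outpaint-even-256: MSE above threshold (0.013)" (illustrative values)
for result in failed:
    logger.error("failed: %r", result)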
@@ -0,0 +1,26 @@
+import unittest
+
+from onnx_web.chain.pipeline import ChainProgress
+
+
+class ChainProgressTests(unittest.TestCase):
+    def test_accumulate_with_reset(self):
+        def parent(step, timestep, latents):
+            pass
+
+        progress = ChainProgress(parent)
+        progress(5, 1, None)
+        progress(0, 1, None)
+        progress(5, 1, None)
+
+        self.assertEqual(progress.get_total(), 10)
+
+    def test_start_value(self):
+        def parent(step, timestep, latents):
+            pass
+
+        progress = ChainProgress(parent, 5)
+        self.assertEqual(progress.get_total(), 5)
+
+        progress(10, 1, None)
+        self.assertEqual(progress.get_total(), 10)
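The behavior these tests pin down: steps accumulate across stage resets, so when a reported step drops below the previous one, the finished stage's steps are banked into the running total. A minimal re-implementation sketch of that logic, not the actual class:

class AccumulatingProgress:
    # illustrative only; mirrors the totals asserted in the tests above
    def __init__(self, parent, start=0):
        self.parent = parent
        self.step = start
        self.total = 0

    def __call__(self, step, timestep, latents):
        if step < self.step:
            # a new stage started: bank the completed stage's steps
            self.total += self.step

        self.step = step
        self.parent(step, timestep, latents)

    def get_total(self):
        return self.total + self.step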
@@ -0,0 +1,23 @@
+import unittest
+
+from PIL import Image
+
+from onnx_web.chain.blend_grid import BlendGridStage
+from onnx_web.chain.result import StageResult
+
+
+class BlendGridStageTests(unittest.TestCase):
+    def test_stage(self):
+        stage = BlendGridStage()
+        sources = StageResult(
+            images=[
+                Image.new("RGB", (64, 64), "black"),
+                Image.new("RGB", (64, 64), "white"),
+                Image.new("RGB", (64, 64), "black"),
+                Image.new("RGB", (64, 64), "white"),
+            ]
+        )
+        result = stage.run(None, None, None, None, sources, height=2, width=2)
+
+        self.assertEqual(len(result), 5)
+        self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
@@ -0,0 +1,47 @@
+import unittest
+
+from PIL import Image
+
+from onnx_web.chain.blend_img2img import BlendImg2ImgStage
+from onnx_web.chain.result import StageResult
+from onnx_web.params import ImageParams
+from onnx_web.server.context import ServerContext
+from onnx_web.worker.context import WorkerContext
+from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
+
+
+class BlendImg2ImgStageTests(unittest.TestCase):
+    @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
+    def test_stage(self):
+        stage = BlendImg2ImgStage()
+        params = ImageParams(
+            TEST_MODEL_DIFFUSION_SD15,
+            "txt2img",
+            "euler-a",
+            "an astronaut eating a hamburger",
+            3.0,
+            1,
+            1,
+        )
+        server = ServerContext(model_path="../models", output_path="../outputs")
+        worker = WorkerContext(
+            "test",
+            test_device(),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0,
+            0.1,
+        )
+        sources = StageResult(
+            images=[
+                Image.new("RGB", (64, 64), "black"),
+            ]
+        )
+        result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
+
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))
@@ -0,0 +1,23 @@
+import unittest
+
+from PIL import Image
+
+from onnx_web.chain.blend_linear import BlendLinearStage
+from onnx_web.chain.result import StageResult
+
+
+class BlendLinearStageTests(unittest.TestCase):
+    def test_stage(self):
+        stage = BlendLinearStage()
+        sources = StageResult(
+            images=[
+                Image.new("RGB", (64, 64), "black"),
+            ]
+        )
+        stage_source = Image.new("RGB", (64, 64), "white")
+        result = stage.run(
+            None, None, None, None, sources, alpha=0.5, stage_source=stage_source
+        )
+
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
@@ -0,0 +1,25 @@
+import unittest
+
+from PIL import Image
+
+from onnx_web.chain.blend_mask import BlendMaskStage
+from onnx_web.chain.result import StageResult
+from onnx_web.params import HighresParams, UpscaleParams
+
+
+class BlendMaskStageTests(unittest.TestCase):
+    def test_empty(self):
+        stage = BlendMaskStage()
+        sources = StageResult.empty()
+        result = stage.run(
+            None,
+            None,
+            None,
+            None,
+            sources,
+            highres=HighresParams(False, 1, 0, 0),
+            upscale=UpscaleParams(""),
+            stage_mask=Image.new("RGBA", (64, 64)),
+        )
+
+        self.assertEqual(len(result), 0)
@@ -0,0 +1,46 @@
+import unittest
+
+from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
+from onnx_web.chain.result import StageResult
+from onnx_web.params import HighresParams, UpscaleParams
+from onnx_web.server.context import ServerContext
+from onnx_web.server.hacks import apply_patches
+from onnx_web.worker.context import WorkerContext
+from tests.helpers import (
+    TEST_MODEL_CORRECTION_CODEFORMER,
+    test_device,
+    test_needs_models,
+)
+
+
+class CorrectCodeformerStageTests(unittest.TestCase):
+    @test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER])
+    def test_empty(self):
+        server = ServerContext(model_path="../models", output_path="../outputs")
+        apply_patches(server)
+
+        worker = WorkerContext(
+            "test",
+            test_device(),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0,
+            0.1,
+        )
+        stage = CorrectCodeformerStage()
+        sources = StageResult.empty()
+        result = stage.run(
+            worker,
+            None,
+            None,
+            None,
+            sources,
+            highres=HighresParams(False, 1, 0, 0),
+            upscale=UpscaleParams(""),
+        )
+
+        self.assertEqual(len(result), 0)
@@ -0,0 +1,44 @@
+import unittest
+
+from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
+from onnx_web.chain.result import StageResult
+from onnx_web.params import HighresParams, UpscaleParams
+from onnx_web.server.context import ServerContext
+from onnx_web.server.hacks import apply_patches
+from onnx_web.worker.context import WorkerContext
+from tests.helpers import test_device, test_needs_onnx_models
+
+TEST_MODEL = "../models/correction-gfpgan-v1-3"
+
+
+class CorrectGFPGANStageTests(unittest.TestCase):
+    @test_needs_onnx_models([TEST_MODEL])
+    def test_empty(self):
+        server = ServerContext(model_path="../models", output_path="../outputs")
+        apply_patches(server)
+
+        worker = WorkerContext(
+            "test",
+            test_device(),
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0,
+            0.1,
+        )
+        stage = CorrectGFPGANStage()
+        sources = StageResult.empty()
+        result = stage.run(
+            worker,
+            None,
+            None,
+            None,
+            sources,
+            highres=HighresParams(False, 1, 0, 0),
+            upscale=UpscaleParams(TEST_MODEL),
+        )
+
+        self.assertEqual(len(result), 0)
@@ -0,0 +1,24 @@
+import unittest
+
+from onnx_web.chain.reduce_crop import ReduceCropStage
+from onnx_web.chain.result import StageResult
+from onnx_web.params import HighresParams, Size, UpscaleParams
+
+
+class ReduceCropStageTests(unittest.TestCase):
+    def test_empty(self):
+        stage = ReduceCropStage()
+        sources = StageResult.empty()
+        result = stage.run(
+            None,
+            None,
+            None,
+            None,
+            sources,
+            highres=HighresParams(False, 1, 0, 0),
+            upscale=UpscaleParams(""),
+            origin=Size(0, 0),
+            size=Size(128, 128),
+        )
+
+        self.assertEqual(len(result), 0)
@@ -0,0 +1,28 @@
+import unittest
+
+from PIL import Image
+
+from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
+from onnx_web.chain.result import StageResult
+from onnx_web.params import HighresParams, Size, UpscaleParams
+
+
+class ReduceThumbnailStageTests(unittest.TestCase):
+    def test_empty(self):
+        stage_source = Image.new("RGB", (64, 64))
+        stage = ReduceThumbnailStage()
+        sources = StageResult.empty()
+        result = stage.run(
+            None,
+            None,
+            None,
+            None,
+            sources,
+            highres=HighresParams(False, 1, 0, 0),
+            upscale=UpscaleParams(""),
+            origin=Size(0, 0),
+            size=Size(128, 128),
+            stage_source=stage_source,
+        )
+
+        self.assertEqual(len(result), 0)
Some files were not shown because too many files have changed in this diff.