From a5fe52d2a269d12a871d8e2446276a2532b67749 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 13 Jan 2024 10:01:50 -0600 Subject: [PATCH] feat(api): parse metadata from input images --- api/onnx_web/chain/highres.py | 4 +- api/onnx_web/chain/result.py | 69 ++++++++++++++++++++++++++++- api/onnx_web/diffusers/run.py | 79 +++++++++++++++++++++++----------- api/onnx_web/output.py | 16 +++++++ api/tests/chain/test_result.py | 43 ++++++++++++++++++ 5 files changed, 183 insertions(+), 28 deletions(-) create mode 100644 api/tests/chain/test_result.py diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index b962e503..7e47ff1f 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -5,7 +5,7 @@ from ..chain.blend_img2img import BlendImg2ImgStage from ..chain.edit_metadata import EditMetadataStage from ..chain.upscale import stage_upscale_correction from ..chain.upscale_simple import UpscaleSimpleStage -from ..params import HighresParams, ImageParams, StageParams, UpscaleParams +from ..params import HighresParams, ImageParams, SizeChart, StageParams, UpscaleParams from .pipeline import ChainPipeline logger = getLogger(__name__) @@ -70,7 +70,7 @@ def stage_highres( # add highres parameters to the image metadata chain.stage( EditMetadataStage(), - stage.with_args(outscale=1), + stage.with_args(outscale=1, tile_size=SizeChart.max), highres=highres, replace_params=params, ) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index d1277177..4e357e57 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -1,6 +1,7 @@ from json import dumps from logging import getLogger from os import path +from re import compile from typing import Any, List, Optional, Tuple import numpy as np @@ -10,10 +11,12 @@ from ..convert.utils import resolve_tensor from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams from ..server.context import ServerContext from ..server.load import 
# --- ImageMetadata alternate constructors (api/onnx_web/chain/result.py) ---
# Both functions are static methods of ImageMetadata; indent them one level when
# placing them inside the class body.


@staticmethod
def from_exif(input: str) -> "ImageMetadata":
    """
    Build image metadata from a plain-text parameter block.

    The expected layout is the prompt on the first line, an optional
    ``Negative prompt: ...`` line, and then one or more lines of comma-separated
    ``Key: value`` pairs, for example::

        test prompt
        Negative prompt: negative prompt
        Sampler: ddim, CFG scale: 4.0, Steps: 30, Seed: 5

    Keys are matched case-insensitively. A ``Size: WxH`` pair becomes the input
    size; values made only of digits become ints and other numeric values
    become floats. Segments without a ``Key:`` part are ignored.

    Raises:
        ValueError: if the text is empty or the ``Size`` value is malformed.
        KeyError: if the sampler, CFG scale, steps, or seed are missing.
    """
    lines = input.splitlines()
    if not lines:
        raise ValueError("image metadata text is empty")

    prompt, *rest = lines

    # the negative prompt is optional, so only consume the second line when it
    # really is one; otherwise it is the first line of the other parameters
    negative_prompt = None
    if rest and rest[0].startswith("Negative prompt:"):
        negative_prompt = rest.pop(0)[len("Negative prompt:") :].strip()

    # the remaining parameters may be wrapped over several lines
    other_params = " ".join(rest).split(",")

    params = {}
    size = None
    for param in other_params:
        # split on the first colon only, so values may contain colons, and skip
        # fragments with no key (blank text, trailing commas)
        key, separator, value = param.partition(":")
        if not separator:
            continue

        key = key.strip().lower()
        value = value.strip()

        if key == "size":
            width, height = value.split("x")
            size = Size(int(width.strip()), int(height.strip()))
        elif value.isdecimal():
            value = int(value)
        elif FLOAT_PATTERN.fullmatch(value) is not None:
            # fullmatch, not match: a value that merely starts with a digit
            # (such as a hex hash) must stay a string
            value = float(value)

        params[key] = value

    image_params = ImageParams(
        "TODO",
        "txt2img",  # TODO: can this be detected?
        params["sampler"],
        prompt,
        params["cfg scale"],
        params["steps"],
        params["seed"],
        negative_prompt,
    )
    return ImageMetadata(image_params, size)


@staticmethod
def from_json(input: str) -> "ImageMetadata":
    """
    Build image metadata from a serialized metadata document.

    The text is parsed with ``load_config_str``. The ``params`` and
    ``input_size`` entries are required; ``upscale``, ``border``, ``highres``,
    ``inversions``, ``loras``, and ``models`` default to None when absent.

    Raises:
        KeyError: if ``params`` or ``input_size`` is missing.
    """
    data = load_config_str(input)
    # TODO: enforce schema
    # NOTE(review): values are handed to the constructor exactly as loaded;
    # presumably they need converting to the params/size types - confirm

    optional_keys = ("upscale", "border", "highres", "inversions", "loras", "models")
    optional_values = [data.get(key, None) for key in optional_keys]

    return ImageMetadata(data["params"], data["input_size"], *optional_values)


# --- thumbnail helper (api/onnx_web/diffusers/run.py) ---


def add_thumbnail_output(
    server: ServerContext,
    images: StageResult,
    params: ImageParams,
) -> None:
    """
    Insert a thumbnail of the first result image at the front of the results.

    Nothing is added unless a thumbnail was requested, there is at least one
    result, and the result is wider or taller than the server's thumbnail size.
    The thumbnail reuses the metadata of the first result.

    TODO: This should really be a stage.
    """
    # measured first, before any of the checks below
    result_size = images.size()

    if not params.thumbnail or len(images) == 0:
        return

    limit = server.thumbnail_size
    if result_size.width <= limit and result_size.height <= limit:
        # already small enough, a thumbnail would just duplicate the image
        return

    # shrink a copy so the full-size result is left untouched
    thumbnail = images.as_images()[0].copy()
    thumbnail.thumbnail((limit, limit))

    images.insert_image(0, thumbnail, images.metadata[0])
input_result = StageResult(images=[source], metadata=[input_metadata]) + # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback(reset=True) @@ -413,13 +432,13 @@ def run_inpaint_pipeline( worker, server, params, - StageResult( - images=[source], metadata=[ImageMetadata.unknown_image()] - ), # TODO: load metadata from source image + input_result, callback=progress, latents=latents, ) + add_thumbnail_output(server, images, params) + for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)): if full_res_inpaint: if is_debug(): @@ -488,16 +507,21 @@ def run_upscale_pipeline( add_safety_stage(server, chain) + # prep inputs + input_metadata = read_metadata(source) or ImageMetadata.unknown_image() + input_result = StageResult(images=[source], metadata=[input_metadata]) + # run and save progress = worker.get_progress_callback(reset=True) images = chain( worker, server, params, - StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + input_result, callback=progress, ) + add_thumbnail_output(server, images, params) save_result(server, images, worker.job) # clean up @@ -543,18 +567,23 @@ def run_blend_pipeline( add_safety_stage(server, chain) + # prep inputs + input_metadata = [ + read_metadata(source) or ImageMetadata.unknown_image() for source in sources + ] + input_result = StageResult(images=sources, metadata=input_metadata) + # run and save progress = worker.get_progress_callback(reset=True) images = chain( worker, server, params, - StageResult( - images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources) - ), + input_result, callback=progress, ) + add_thumbnail_output(server, images, params) save_result(server, images, worker.job) # clean up diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 9404df16..aca4b92a 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -145,3 +145,19 @@ def save_metadata( 
def _decode_user_comment(raw):
    """Return the text of an EXIF UserComment value, which may arrive as raw bytes."""
    if isinstance(raw, str):
        return raw

    data = bytes(raw)
    # EXIF stores an 8-byte character code identifier in front of the comment
    prefix, body = data[:8], data[8:]
    if prefix == b"UNICODE\x00":
        # NOTE(review): assumes big-endian UTF-16, which is what common writers
        # emit - confirm against the tools whose images should be supported
        return body.decode("utf-16-be", errors="replace")
    if prefix == b"JIS\x00\x00\x00\x00\x00":
        return body.decode("shift_jis", errors="replace")
    if prefix in (b"ASCII\x00\x00\x00", b"\x00" * 8):
        return body.decode("utf-8", errors="replace")

    # no recognised identifier, treat the whole value as text
    return data.decode("utf-8", errors="replace")


def read_metadata(
    image: "Image.Image",
) -> "Optional[ImageMetadata]":
    """
    Read generation metadata from the EXIF data of an input image.

    Images whose Make tag is "onnx-web" are parsed from the JSON stored in the
    MakerNote tag; any other image with a UserComment tag is parsed as a
    plain-text parameter block.

    Returns:
        The parsed metadata, or None when the image has no EXIF data, has none
        of the supported tags, or the stored metadata cannot be parsed.
    """
    # _getexif is not available on every image object and returns None when
    # the image carries no EXIF data, so neither case may reach the tag checks
    get_exif = getattr(image, "_getexif", None)
    exif_data = get_exif() if callable(get_exif) else None
    if not exif_data:
        return None

    try:
        if (
            exif_data.get(ImageIFD.Make) == "onnx-web"
            and ExifIFD.MakerNote in exif_data
        ):
            return ImageMetadata.from_json(exif_data[ExifIFD.MakerNote])

        if ExifIFD.UserComment in exif_data:
            comment = _decode_user_comment(exif_data[ExifIFD.UserComment])
            return ImageMetadata.from_exif(comment)
    except (KeyError, TypeError, ValueError) as err:
        # the tags come from user-supplied files and may hold anything, so
        # unreadable metadata is treated the same as missing metadata
        logger.warning("failed to parse metadata from input image: %s", err)
        return None

    # this could return ImageMetadata.unknown_image(), but that would not indicate whether the input
    # had metadata or not, so it's easier to return None and follow the call with `or ImageMetadata.unknown_image()`
    return None
self.assertEqual(metadata.params.steps, 30) + self.assertEqual(metadata.params.seed, 5) + + +class StageResultTests(unittest.TestCase): + def test_stage_result(self): + pass