1
0
Fork 0

feat(api): parse metadata from input images

This commit is contained in:
Sean Sube 2024-01-13 10:01:50 -06:00
parent 065df23902
commit a5fe52d2a2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 183 additions and 28 deletions

View File

@ -5,7 +5,7 @@ from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.edit_metadata import EditMetadataStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..params import HighresParams, ImageParams, SizeChart, StageParams, UpscaleParams
from .pipeline import ChainPipeline
logger = getLogger(__name__)
@ -70,7 +70,7 @@ def stage_highres(
# add highres parameters to the image metadata
chain.stage(
EditMetadataStage(),
stage.with_args(outscale=1),
stage.with_args(outscale=1, tile_size=SizeChart.max),
highres=highres,
replace_params=params,
)

View File

@ -1,6 +1,7 @@
from json import dumps
from logging import getLogger
from os import path
from re import compile
from typing import Any, List, Optional, Tuple
import numpy as np
@ -10,10 +11,12 @@ from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext
from ..server.load import get_extra_hashes
from ..utils import hash_file
from ..utils import hash_file, load_config_str
logger = getLogger(__name__)
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
class NetworkMetadata:
name: str
@ -238,6 +241,70 @@ class ImageMetadata:
return json
@staticmethod
def from_exif(input: str) -> "ImageMetadata":
lines = input.splitlines()
prompt, maybe_negative, *rest = lines
# process negative prompt or put that line back into rest
if maybe_negative.startswith("Negative prompt:"):
negative_prompt = maybe_negative[len("Negative prompt:") :]
negative_prompt = negative_prompt.strip()
else:
rest.insert(0, maybe_negative)
negative_prompt = None
rest = " ".join(rest)
other_params = rest.split(",")
# process other params
params = {}
size = None
for param in other_params:
key, value = param.split(":")
key = key.strip().lower()
value = value.strip()
if key == "size":
width, height = value.split("x")
width = int(width.strip())
height = int(height.strip())
size = Size(width, height)
elif value.isdecimal():
value = int(value)
elif FLOAT_PATTERN.match(value) is not None:
value = float(value)
params[key] = value
params = ImageParams(
"TODO",
"txt2img", # TODO: can this be detected?
params["sampler"],
prompt,
params["cfg scale"],
params["steps"],
params["seed"],
negative_prompt,
)
return ImageMetadata(params, size)
@staticmethod
def from_json(input: str) -> "ImageMetadata":
data = load_config_str(input)
# TODO: enforce schema
return ImageMetadata(
data["params"],
data["input_size"],
data.get("upscale", None),
data.get("border", None),
data.get("highres", None),
data.get("inversions", None),
data.get("loras", None),
data.get("models", None),
)
ERROR_NO_METADATA = "metadata must be provided"

View File

@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
from ..chain.result import ImageMetadata, StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image, save_result
from ..output import read_metadata, save_image, save_result
from ..params import (
Border,
HighresParams,
@ -64,6 +64,31 @@ def add_safety_stage(
)
def add_thumbnail_output(
server: ServerContext,
images: StageResult,
params: ImageParams,
) -> None:
"""
Add a thumbnail image to the output, if requested.
TODO: This should really be a stage.
"""
result_size = images.size()
if (
params.thumbnail
and len(images) > 0
and (
result_size.width > server.thumbnail_size
or result_size.height > server.thumbnail_size
)
):
cover = images.as_images()[0]
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert_image(0, thumbnail, images.metadata[0])
def run_txt2img_pipeline(
worker: WorkerContext,
server: ServerContext,
@ -131,22 +156,7 @@ def run_txt2img_pipeline(
worker, server, params, StageResult.empty(), callback=progress, latents=latents
)
# add a thumbnail, if requested
result_size = images.size()
if (
params.thumbnail
and len(images) > 0
and (
result_size.width > server.thumbnail_size
or result_size.height > server.thumbnail_size
)
):
cover = images.as_images()[0]
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert_image(0, thumbnail, images.metadata[0])
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job)
# clean up
@ -231,19 +241,24 @@ def run_img2img_pipeline(
add_safety_stage(server, chain)
# prep inputs
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
input_result = StageResult(images=[source], metadata=[input_metadata])
# run and append the filtered source
progress = worker.get_progress_callback(reset=True)
images = chain(
worker,
server,
params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
input_result, # terrible naming, I know
callback=progress,
)
if source_filter is not None and source_filter != "none":
images.push_image(source, ImageMetadata.unknown_image())
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job)
# clean up
@ -406,6 +421,10 @@ def run_inpaint_pipeline(
add_safety_stage(server, chain)
# prep inputs
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
input_result = StageResult(images=[source], metadata=[input_metadata])
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback(reset=True)
@ -413,13 +432,13 @@ def run_inpaint_pipeline(
worker,
server,
params,
StageResult(
images=[source], metadata=[ImageMetadata.unknown_image()]
), # TODO: load metadata from source image
input_result,
callback=progress,
latents=latents,
)
add_thumbnail_output(server, images, params)
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
if full_res_inpaint:
if is_debug():
@ -488,16 +507,21 @@ def run_upscale_pipeline(
add_safety_stage(server, chain)
# prep inputs
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
input_result = StageResult(images=[source], metadata=[input_metadata])
# run and save
progress = worker.get_progress_callback(reset=True)
images = chain(
worker,
server,
params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
input_result,
callback=progress,
)
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job)
# clean up
@ -543,18 +567,23 @@ def run_blend_pipeline(
add_safety_stage(server, chain)
# prep inputs
input_metadata = [
read_metadata(source) or ImageMetadata.unknown_image() for source in sources
]
input_result = StageResult(images=sources, metadata=input_metadata)
# run and save
progress = worker.get_progress_callback(reset=True)
images = chain(
worker,
server,
params,
StageResult(
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
),
input_result,
callback=progress,
)
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job)
# clean up

View File

@ -145,3 +145,19 @@ def save_metadata(
f.write(dumps(json))
logger.debug("saved image params to: %s", path)
return path
def read_metadata(
image: Image.Image,
) -> Optional[ImageMetadata]:
exif_data = image._getexif()
if ImageIFD.Make in exif_data and exif_data[ImageIFD.Make] == "onnx-web":
return ImageMetadata.from_json(exif_data[ExifIFD.MakerNote])
if ExifIFD.UserComment in exif_data:
return ImageMetadata.from_exif(exif_data[ExifIFD.UserComment])
# this could return ImageMetadata.unknown_image(), but that would not indicate whether the input
# had metadata or not, so it's easier to return None and follow the call with `or ImageMetadata.unknown_image()`
return None

View File

@ -0,0 +1,43 @@
import unittest
from onnx_web.chain.result import ImageMetadata
class ImageMetadataTests(unittest.TestCase):
def test_image_metadata(self):
pass
def test_from_exif_normal(self):
exif_data = """test prompt
Negative prompt: negative prompt
Sampler: ddim, CFG scale: 4.0, Steps: 30, Seed: 5
"""
metadata = ImageMetadata.from_exif(exif_data)
self.assertEqual(metadata.params.prompt, "test prompt")
self.assertEqual(metadata.params.negative_prompt, "negative prompt")
self.assertEqual(metadata.params.scheduler, "ddim")
self.assertEqual(metadata.params.cfg, 4.0)
self.assertEqual(metadata.params.steps, 30)
self.assertEqual(metadata.params.seed, 5)
def test_from_exif_split(self):
exif_data = """test prompt
Negative prompt: negative prompt
Sampler: ddim,
CFG scale: 4.0,
Steps: 30, Seed: 5
"""
metadata = ImageMetadata.from_exif(exif_data)
self.assertEqual(metadata.params.prompt, "test prompt")
self.assertEqual(metadata.params.negative_prompt, "negative prompt")
self.assertEqual(metadata.params.scheduler, "ddim")
self.assertEqual(metadata.params.cfg, 4.0)
self.assertEqual(metadata.params.steps, 30)
self.assertEqual(metadata.params.seed, 5)
class StageResultTests(unittest.TestCase):
def test_stage_result(self):
pass