1
0
Fork 0

feat(api): parse metadata from input images

This commit is contained in:
Sean Sube 2024-01-13 10:01:50 -06:00
parent 065df23902
commit a5fe52d2a2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 183 additions and 28 deletions

View File

@ -5,7 +5,7 @@ from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.edit_metadata import EditMetadataStage from ..chain.edit_metadata import EditMetadataStage
from ..chain.upscale import stage_upscale_correction from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, SizeChart, StageParams, UpscaleParams
from .pipeline import ChainPipeline from .pipeline import ChainPipeline
logger = getLogger(__name__) logger = getLogger(__name__)
@ -70,7 +70,7 @@ def stage_highres(
# add highres parameters to the image metadata # add highres parameters to the image metadata
chain.stage( chain.stage(
EditMetadataStage(), EditMetadataStage(),
stage.with_args(outscale=1), stage.with_args(outscale=1, tile_size=SizeChart.max),
highres=highres, highres=highres,
replace_params=params, replace_params=params,
) )

View File

@ -1,6 +1,7 @@
from json import dumps from json import dumps
from logging import getLogger from logging import getLogger
from os import path from os import path
from re import compile
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
import numpy as np import numpy as np
@ -10,10 +11,12 @@ from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext from ..server.context import ServerContext
from ..server.load import get_extra_hashes from ..server.load import get_extra_hashes
from ..utils import hash_file from ..utils import hash_file, load_config_str
logger = getLogger(__name__) logger = getLogger(__name__)
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
class NetworkMetadata: class NetworkMetadata:
name: str name: str
@ -238,6 +241,70 @@ class ImageMetadata:
return json return json
@staticmethod
def from_exif(input: str) -> "ImageMetadata":
lines = input.splitlines()
prompt, maybe_negative, *rest = lines
# process negative prompt or put that line back into rest
if maybe_negative.startswith("Negative prompt:"):
negative_prompt = maybe_negative[len("Negative prompt:") :]
negative_prompt = negative_prompt.strip()
else:
rest.insert(0, maybe_negative)
negative_prompt = None
rest = " ".join(rest)
other_params = rest.split(",")
# process other params
params = {}
size = None
for param in other_params:
key, value = param.split(":")
key = key.strip().lower()
value = value.strip()
if key == "size":
width, height = value.split("x")
width = int(width.strip())
height = int(height.strip())
size = Size(width, height)
elif value.isdecimal():
value = int(value)
elif FLOAT_PATTERN.match(value) is not None:
value = float(value)
params[key] = value
params = ImageParams(
"TODO",
"txt2img", # TODO: can this be detected?
params["sampler"],
prompt,
params["cfg scale"],
params["steps"],
params["seed"],
negative_prompt,
)
return ImageMetadata(params, size)
@staticmethod
def from_json(input: str) -> "ImageMetadata":
data = load_config_str(input)
# TODO: enforce schema
return ImageMetadata(
data["params"],
data["input_size"],
data.get("upscale", None),
data.get("border", None),
data.get("highres", None),
data.get("inversions", None),
data.get("loras", None),
data.get("models", None),
)
ERROR_NO_METADATA = "metadata must be provided" ERROR_NO_METADATA = "metadata must be provided"

View File

@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
from ..chain.result import ImageMetadata, StageResult from ..chain.result import ImageMetadata, StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image from ..image import expand_image
from ..output import save_image, save_result from ..output import read_metadata, save_image, save_result
from ..params import ( from ..params import (
Border, Border,
HighresParams, HighresParams,
@ -64,6 +64,31 @@ def add_safety_stage(
) )
def add_thumbnail_output(
server: ServerContext,
images: StageResult,
params: ImageParams,
) -> None:
"""
Add a thumbnail image to the output, if requested.
TODO: This should really be a stage.
"""
result_size = images.size()
if (
params.thumbnail
and len(images) > 0
and (
result_size.width > server.thumbnail_size
or result_size.height > server.thumbnail_size
)
):
cover = images.as_images()[0]
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert_image(0, thumbnail, images.metadata[0])
def run_txt2img_pipeline( def run_txt2img_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
@ -131,22 +156,7 @@ def run_txt2img_pipeline(
worker, server, params, StageResult.empty(), callback=progress, latents=latents worker, server, params, StageResult.empty(), callback=progress, latents=latents
) )
# add a thumbnail, if requested add_thumbnail_output(server, images, params)
result_size = images.size()
if (
params.thumbnail
and len(images) > 0
and (
result_size.width > server.thumbnail_size
or result_size.height > server.thumbnail_size
)
):
cover = images.as_images()[0]
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert_image(0, thumbnail, images.metadata[0])
save_result(server, images, worker.job) save_result(server, images, worker.job)
# clean up # clean up
@ -231,19 +241,24 @@ def run_img2img_pipeline(
add_safety_stage(server, chain) add_safety_stage(server, chain)
# prep inputs
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
input_result = StageResult(images=[source], metadata=[input_metadata])
# run and append the filtered source # run and append the filtered source
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, worker,
server, server,
params, params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), input_result, # terrible naming, I know
callback=progress, callback=progress,
) )
if source_filter is not None and source_filter != "none": if source_filter is not None and source_filter != "none":
images.push_image(source, ImageMetadata.unknown_image()) images.push_image(source, ImageMetadata.unknown_image())
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job) save_result(server, images, worker.job)
# clean up # clean up
@ -406,6 +421,10 @@ def run_inpaint_pipeline(
add_safety_stage(server, chain) add_safety_stage(server, chain)
# prep inputs
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
input_result = StageResult(images=[source], metadata=[input_metadata])
# run and save # run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
@ -413,13 +432,13 @@ def run_inpaint_pipeline(
worker, worker,
server, server,
params, params,
StageResult( input_result,
images=[source], metadata=[ImageMetadata.unknown_image()]
), # TODO: load metadata from source image
callback=progress, callback=progress,
latents=latents, latents=latents,
) )
add_thumbnail_output(server, images, params)
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)): for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
if full_res_inpaint: if full_res_inpaint:
if is_debug(): if is_debug():
@ -488,16 +507,21 @@ def run_upscale_pipeline(
add_safety_stage(server, chain) add_safety_stage(server, chain)
# prep inputs
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
input_result = StageResult(images=[source], metadata=[input_metadata])
# run and save # run and save
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, worker,
server, server,
params, params,
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), input_result,
callback=progress, callback=progress,
) )
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job) save_result(server, images, worker.job)
# clean up # clean up
@ -543,18 +567,23 @@ def run_blend_pipeline(
add_safety_stage(server, chain) add_safety_stage(server, chain)
# prep inputs
input_metadata = [
read_metadata(source) or ImageMetadata.unknown_image() for source in sources
]
input_result = StageResult(images=sources, metadata=input_metadata)
# run and save # run and save
progress = worker.get_progress_callback(reset=True) progress = worker.get_progress_callback(reset=True)
images = chain( images = chain(
worker, worker,
server, server,
params, params,
StageResult( input_result,
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
),
callback=progress, callback=progress,
) )
add_thumbnail_output(server, images, params)
save_result(server, images, worker.job) save_result(server, images, worker.job)
# clean up # clean up

View File

@ -145,3 +145,19 @@ def save_metadata(
f.write(dumps(json)) f.write(dumps(json))
logger.debug("saved image params to: %s", path) logger.debug("saved image params to: %s", path)
return path return path
def read_metadata(
image: Image.Image,
) -> Optional[ImageMetadata]:
exif_data = image._getexif()
if ImageIFD.Make in exif_data and exif_data[ImageIFD.Make] == "onnx-web":
return ImageMetadata.from_json(exif_data[ExifIFD.MakerNote])
if ExifIFD.UserComment in exif_data:
return ImageMetadata.from_exif(exif_data[ExifIFD.UserComment])
# this could return ImageMetadata.unknown_image(), but that would not indicate whether the input
# had metadata or not, so it's easier to return None and follow the call with `or ImageMetadata.unknown_image()`
return None

View File

@ -0,0 +1,43 @@
import unittest
from onnx_web.chain.result import ImageMetadata
class ImageMetadataTests(unittest.TestCase):
def test_image_metadata(self):
pass
def test_from_exif_normal(self):
exif_data = """test prompt
Negative prompt: negative prompt
Sampler: ddim, CFG scale: 4.0, Steps: 30, Seed: 5
"""
metadata = ImageMetadata.from_exif(exif_data)
self.assertEqual(metadata.params.prompt, "test prompt")
self.assertEqual(metadata.params.negative_prompt, "negative prompt")
self.assertEqual(metadata.params.scheduler, "ddim")
self.assertEqual(metadata.params.cfg, 4.0)
self.assertEqual(metadata.params.steps, 30)
self.assertEqual(metadata.params.seed, 5)
def test_from_exif_split(self):
exif_data = """test prompt
Negative prompt: negative prompt
Sampler: ddim,
CFG scale: 4.0,
Steps: 30, Seed: 5
"""
metadata = ImageMetadata.from_exif(exif_data)
self.assertEqual(metadata.params.prompt, "test prompt")
self.assertEqual(metadata.params.negative_prompt, "negative prompt")
self.assertEqual(metadata.params.scheduler, "ddim")
self.assertEqual(metadata.params.cfg, 4.0)
self.assertEqual(metadata.params.steps, 30)
self.assertEqual(metadata.params.seed, 5)
class StageResultTests(unittest.TestCase):
def test_stage_result(self):
pass