feat(api): parse metadata from input images
This commit is contained in:
parent
065df23902
commit
a5fe52d2a2
|
@ -5,7 +5,7 @@ from ..chain.blend_img2img import BlendImg2ImgStage
|
|||
from ..chain.edit_metadata import EditMetadataStage
|
||||
from ..chain.upscale import stage_upscale_correction
|
||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..params import HighresParams, ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .pipeline import ChainPipeline
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -70,7 +70,7 @@ def stage_highres(
|
|||
# add highres parameters to the image metadata
|
||||
chain.stage(
|
||||
EditMetadataStage(),
|
||||
stage.with_args(outscale=1),
|
||||
stage.with_args(outscale=1, tile_size=SizeChart.max),
|
||||
highres=highres,
|
||||
replace_params=params,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from json import dumps
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from re import compile
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
@ -10,10 +11,12 @@ from ..convert.utils import resolve_tensor
|
|||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||
from ..server.context import ServerContext
|
||||
from ..server.load import get_extra_hashes
|
||||
from ..utils import hash_file
|
||||
from ..utils import hash_file, load_config_str
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
|
||||
|
||||
|
||||
class NetworkMetadata:
|
||||
name: str
|
||||
|
@ -238,6 +241,70 @@ class ImageMetadata:
|
|||
|
||||
return json
|
||||
|
||||
@staticmethod
def from_exif(input: str) -> "ImageMetadata":
    """
    Build image metadata from a plain-text parameter block.

    The expected layout is:

    - the first line is the prompt
    - the second line may start with ``Negative prompt:``; when it does not,
      it is treated as the first line of parameters
    - all remaining lines are joined and split on commas into ``key: value``
      pairs, with keys compared case-insensitively

    Values made only of digits become ``int``, values that are entirely a
    float literal become ``float``, and everything else stays a string. The
    ``size`` key is parsed as ``<width>x<height>``.

    Raises:
        ValueError: if the input has fewer than two lines, or the size value
            is malformed.
        KeyError: if any of ``sampler``, ``cfg scale``, ``steps`` or ``seed``
            is missing from the parameters.
    """
    lines = input.splitlines()
    if len(lines) < 2:
        raise ValueError(
            "metadata must contain a prompt line followed by a parameter line"
        )

    prompt, maybe_negative, *rest = lines

    # process negative prompt or put that line back into rest
    negative_marker = "Negative prompt:"
    if maybe_negative.startswith(negative_marker):
        negative_prompt = maybe_negative[len(negative_marker) :].strip()
    else:
        rest.insert(0, maybe_negative)
        negative_prompt = None

    # process other params
    params = {}
    size = None
    for param in " ".join(rest).split(","):
        # split on the first colon only, so values may contain colons
        key, separator, value = param.partition(":")
        if not separator:
            # empty segment (e.g. trailing comma) or fragment without a key
            continue

        key = key.strip().lower()
        value = value.strip()

        if key == "size":
            width, height = value.split("x")
            size = Size(int(width.strip()), int(height.strip()))
        elif value.isdecimal():
            value = int(value)
        elif FLOAT_PATTERN.fullmatch(value) is not None:
            # fullmatch: values that merely start with digits, like hashes
            # or version strings, must stay strings
            value = float(value)

        params[key] = value

    image_params = ImageParams(
        "TODO",
        "txt2img",  # TODO: can this be detected?
        params["sampler"],
        prompt,
        params["cfg scale"],
        params["steps"],
        params["seed"],
        negative_prompt,
    )
    return ImageMetadata(image_params, size)
|
||||
|
||||
@staticmethod
def from_json(input: str) -> "ImageMetadata":
    """
    Build image metadata from a serialized configuration string.

    The string is decoded with `load_config_str`. The `params` and
    `input_size` entries are required; `upscale`, `border`, `highres`,
    `inversions`, `loras` and `models` are optional and default to None.

    Raises:
        KeyError: if `params` or `input_size` is missing.
    """
    # TODO: enforce schema
    config = load_config_str(input)

    # optional entries, in the positional order that ImageMetadata is called with
    optional_keys = (
        "upscale",
        "border",
        "highres",
        "inversions",
        "loras",
        "models",
    )

    return ImageMetadata(
        config["params"],
        config["input_size"],
        *(config.get(key, None) for key in optional_keys),
    )
|
||||
|
||||
|
||||
ERROR_NO_METADATA = "metadata must be provided"
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
|
|||
from ..chain.result import ImageMetadata, StageResult
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
from ..image import expand_image
|
||||
from ..output import save_image, save_result
|
||||
from ..output import read_metadata, save_image, save_result
|
||||
from ..params import (
|
||||
Border,
|
||||
HighresParams,
|
||||
|
@ -64,6 +64,31 @@ def add_safety_stage(
|
|||
)
|
||||
|
||||
|
||||
def add_thumbnail_output(
    server: ServerContext,
    images: StageResult,
    params: ImageParams,
) -> None:
    """
    Insert a thumbnail at the front of the results, if one was requested.

    A thumbnail is only added when `params.thumbnail` is set, there is at
    least one image, and the result is wider or taller than
    `server.thumbnail_size`. The thumbnail is a copy of the first image and
    reuses that image's metadata.

    TODO: This should really be a stage.
    """
    result_size = images.size()

    if not params.thumbnail:
        return

    if not len(images) > 0:
        return

    limit = server.thumbnail_size
    exceeds_limit = result_size.width > limit or result_size.height > limit
    if not exceeds_limit:
        return

    # work on a copy so the full-size first image is left untouched
    first_image = images.as_images()[0]
    preview = first_image.copy()
    preview.thumbnail((limit, limit))

    images.insert_image(0, preview, images.metadata[0])
|
||||
|
||||
|
||||
def run_txt2img_pipeline(
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
|
@ -131,22 +156,7 @@ def run_txt2img_pipeline(
|
|||
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||
)
|
||||
|
||||
# add a thumbnail, if requested
|
||||
result_size = images.size()
|
||||
if (
|
||||
params.thumbnail
|
||||
and len(images) > 0
|
||||
and (
|
||||
result_size.width > server.thumbnail_size
|
||||
or result_size.height > server.thumbnail_size
|
||||
)
|
||||
):
|
||||
cover = images.as_images()[0]
|
||||
thumbnail = cover.copy()
|
||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||
|
||||
images.insert_image(0, thumbnail, images.metadata[0])
|
||||
|
||||
add_thumbnail_output(server, images, params)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
|
@ -231,19 +241,24 @@ def run_img2img_pipeline(
|
|||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# prep inputs
|
||||
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
|
||||
input_result = StageResult(images=[source], metadata=[input_metadata])
|
||||
|
||||
# run and append the filtered source
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
input_result, # terrible naming, I know
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
if source_filter is not None and source_filter != "none":
|
||||
images.push_image(source, ImageMetadata.unknown_image())
|
||||
|
||||
add_thumbnail_output(server, images, params)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
|
@ -406,6 +421,10 @@ def run_inpaint_pipeline(
|
|||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# prep inputs
|
||||
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
|
||||
input_result = StageResult(images=[source], metadata=[input_metadata])
|
||||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
|
@ -413,13 +432,13 @@ def run_inpaint_pipeline(
|
|||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(
|
||||
images=[source], metadata=[ImageMetadata.unknown_image()]
|
||||
), # TODO: load metadata from source image
|
||||
input_result,
|
||||
callback=progress,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
add_thumbnail_output(server, images, params)
|
||||
|
||||
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
|
||||
if full_res_inpaint:
|
||||
if is_debug():
|
||||
|
@ -488,16 +507,21 @@ def run_upscale_pipeline(
|
|||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# prep inputs
|
||||
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
|
||||
input_result = StageResult(images=[source], metadata=[input_metadata])
|
||||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
||||
input_result,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
add_thumbnail_output(server, images, params)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
|
@ -543,18 +567,23 @@ def run_blend_pipeline(
|
|||
|
||||
add_safety_stage(server, chain)
|
||||
|
||||
# prep inputs
|
||||
input_metadata = [
|
||||
read_metadata(source) or ImageMetadata.unknown_image() for source in sources
|
||||
]
|
||||
input_result = StageResult(images=sources, metadata=input_metadata)
|
||||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback(reset=True)
|
||||
images = chain(
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
StageResult(
|
||||
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
|
||||
),
|
||||
input_result,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
add_thumbnail_output(server, images, params)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
|
|
|
@ -145,3 +145,19 @@ def save_metadata(
|
|||
f.write(dumps(json))
|
||||
logger.debug("saved image params to: %s", path)
|
||||
return path
|
||||
|
||||
|
||||
def read_metadata(
    image: Image.Image,
) -> Optional[ImageMetadata]:
    """
    Read generation metadata from the EXIF data of an input image.

    Images whose EXIF `Make` tag is "onnx-web" and which carry a `MakerNote`
    are parsed with `ImageMetadata.from_json`. Otherwise, a `UserComment`
    tag is parsed with `ImageMetadata.from_exif`.

    Returns None when the image has no usable metadata. This could return
    `ImageMetadata.unknown_image()`, but that would not indicate whether the
    input had metadata or not, so callers follow the call with
    `or ImageMetadata.unknown_image()`.
    """
    # not every image object provides _getexif, and it returns None when the
    # image has no EXIF data at all
    get_exif = getattr(image, "_getexif", None)
    exif_data = get_exif() if get_exif is not None else None
    if not exif_data:
        return None

    # fall through to the user comment if the maker note is missing
    if exif_data.get(ImageIFD.Make) == "onnx-web" and ExifIFD.MakerNote in exif_data:
        return ImageMetadata.from_json(exif_data[ExifIFD.MakerNote])

    if ExifIFD.UserComment in exif_data:
        # NOTE(review): the tag value may be bytes rather than str, while
        # from_exif expects str - confirm against real inputs
        return ImageMetadata.from_exif(exif_data[ExifIFD.UserComment])

    return None
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.result import ImageMetadata
|
||||
|
||||
|
||||
class ImageMetadataTests(unittest.TestCase):
    """Tests for parsing `ImageMetadata` from text parameters."""

    def _assert_expected_params(self, metadata):
        """Check the values shared by both from_exif test inputs."""
        parsed = metadata.params
        self.assertEqual(parsed.prompt, "test prompt")
        self.assertEqual(parsed.negative_prompt, "negative prompt")
        self.assertEqual(parsed.scheduler, "ddim")
        self.assertEqual(parsed.cfg, 4.0)
        self.assertEqual(parsed.steps, 30)
        self.assertEqual(parsed.seed, 5)

    def test_image_metadata(self):
        pass

    def test_from_exif_normal(self):
        # all of the parameters on a single line
        exif_data = """test prompt
Negative prompt: negative prompt
Sampler: ddim, CFG scale: 4.0, Steps: 30, Seed: 5
"""

        self._assert_expected_params(ImageMetadata.from_exif(exif_data))

    def test_from_exif_split(self):
        # the same parameters spread over several lines
        exif_data = """test prompt
Negative prompt: negative prompt
Sampler: ddim,
CFG scale: 4.0,
Steps: 30, Seed: 5
"""

        self._assert_expected_params(ImageMetadata.from_exif(exif_data))
|
||||
|
||||
|
||||
class StageResultTests(unittest.TestCase):
    """Placeholder test case for `StageResult`."""

    def test_stage_result(self):
        """Placeholder: no assertions yet."""
|
Loading…
Reference in New Issue