feat(api): parse metadata from input images
This commit is contained in:
parent
065df23902
commit
a5fe52d2a2
|
@ -5,7 +5,7 @@ from ..chain.blend_img2img import BlendImg2ImgStage
|
||||||
from ..chain.edit_metadata import EditMetadataStage
|
from ..chain.edit_metadata import EditMetadataStage
|
||||||
from ..chain.upscale import stage_upscale_correction
|
from ..chain.upscale import stage_upscale_correction
|
||||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
from ..params import HighresParams, ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from .pipeline import ChainPipeline
|
from .pipeline import ChainPipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -70,7 +70,7 @@ def stage_highres(
|
||||||
# add highres parameters to the image metadata
|
# add highres parameters to the image metadata
|
||||||
chain.stage(
|
chain.stage(
|
||||||
EditMetadataStage(),
|
EditMetadataStage(),
|
||||||
stage.with_args(outscale=1),
|
stage.with_args(outscale=1, tile_size=SizeChart.max),
|
||||||
highres=highres,
|
highres=highres,
|
||||||
replace_params=params,
|
replace_params=params,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
from re import compile
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -10,10 +11,12 @@ from ..convert.utils import resolve_tensor
|
||||||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||||
from ..server.context import ServerContext
|
from ..server.context import ServerContext
|
||||||
from ..server.load import get_extra_hashes
|
from ..server.load import get_extra_hashes
|
||||||
from ..utils import hash_file
|
from ..utils import hash_file, load_config_str
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
|
||||||
|
|
||||||
|
|
||||||
class NetworkMetadata:
|
class NetworkMetadata:
|
||||||
name: str
|
name: str
|
||||||
|
@ -238,6 +241,70 @@ class ImageMetadata:
|
||||||
|
|
||||||
return json
|
return json
|
||||||
|
|
||||||
|
@staticmethod
def from_exif(input: str) -> "ImageMetadata":
    """
    Build image metadata from a plain-text parameter string.

    The expected layout is:

    - line 1: the prompt
    - line 2 (optional): `Negative prompt: ...`
    - remaining lines: comma-separated `key: value` pairs, which may be spread
      over several lines

    Keys are lowercased before they are stored. `size` is parsed as
    `WIDTHxHEIGHT`; any other value that is entirely numeric is converted to an
    int or float, and everything else is kept as a string. The `sampler`,
    `cfg scale`, `steps`, and `seed` keys are required.

    Raises ValueError when the input is empty or the size is malformed, and
    KeyError when one of the required keys is missing.
    """
    lines = input.splitlines()
    if len(lines) == 0:
        raise ValueError("metadata string must contain at least a prompt line")

    prompt, *rest = lines

    # the negative prompt line is optional; when it is missing, the second line
    # already belongs to the key/value pairs and stays in rest
    negative_prompt = None
    if len(rest) > 0 and rest[0].startswith("Negative prompt:"):
        negative_prompt = rest.pop(0)[len("Negative prompt:") :].strip()

    # pairs may be split across lines, so rejoin them before splitting on commas
    other_params = " ".join(rest).split(",")

    # process other params
    params = {}
    size = None
    for param in other_params:
        # partition splits on the first colon only, so values that contain
        # further colons are kept intact instead of failing to unpack
        key, separator, value = param.partition(":")
        key = key.strip().lower()
        value = value.strip()

        if separator == "" or key == "":
            # blank fragments come from trailing commas or empty lines and are
            # expected; anything else is reported and skipped
            if param.strip() != "":
                logger.warning("skipping malformed metadata parameter: %s", param)
            continue

        if key == "size":
            width, height = value.split("x")
            width = int(width.strip())
            height = int(height.strip())
            size = Size(width, height)
        elif value.isdecimal():
            value = int(value)
        elif FLOAT_PATTERN.fullmatch(value) is not None:
            # fullmatch, not match: a value that merely starts with a digit
            # (a hex hash, for example) must stay a string
            value = float(value)

        params[key] = value

    params = ImageParams(
        "TODO",
        "txt2img",  # TODO: can this be detected?
        params["sampler"],
        prompt,
        params["cfg scale"],
        params["steps"],
        params["seed"],
        negative_prompt,
    )
    return ImageMetadata(params, size)
|
||||||
|
|
||||||
|
@staticmethod
def from_json(input: str) -> "ImageMetadata":
    """
    Build image metadata from a serialized config string.

    The string is decoded with `load_config_str`. The `params` and `input_size`
    entries are read by index and so must be present; `upscale`, `border`,
    `highres`, `inversions`, `loras`, and `models` fall back to None when they
    are missing. Values are passed to the constructor exactly as decoded.
    """
    data = load_config_str(input)
    # TODO: enforce schema

    required = (data["params"], data["input_size"])

    # same order as the positional arguments that follow params and input_size
    optional_keys = ("upscale", "border", "highres", "inversions", "loras", "models")
    optional = [data.get(key, None) for key in optional_keys]

    return ImageMetadata(*required, *optional)
|
||||||
|
|
||||||
|
|
||||||
ERROR_NO_METADATA = "metadata must be provided"
|
ERROR_NO_METADATA = "metadata must be provided"
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
|
||||||
from ..chain.result import ImageMetadata, StageResult
|
from ..chain.result import ImageMetadata, StageResult
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..image import expand_image
|
from ..image import expand_image
|
||||||
from ..output import save_image, save_result
|
from ..output import read_metadata, save_image, save_result
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
HighresParams,
|
HighresParams,
|
||||||
|
@ -64,6 +64,31 @@ def add_safety_stage(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_thumbnail_output(
    server: ServerContext,
    images: StageResult,
    params: ImageParams,
) -> None:
    """
    Prepend a thumbnail of the first result image, when one is requested and needed.

    A thumbnail is only added when `params.thumbnail` is set, the result holds at
    least one image, and the result size exceeds `server.thumbnail_size` in width
    or height. The thumbnail is inserted at index 0 and reuses the metadata of
    the first image.

    TODO: This should really be a stage.
    """
    # the size is queried unconditionally, even when no thumbnail ends up being added
    result_size = images.size()

    if not params.thumbnail or len(images) == 0:
        return

    oversized = (
        result_size.width > server.thumbnail_size
        or result_size.height > server.thumbnail_size
    )
    if not oversized:
        return

    # work on a copy so the first result image itself is left untouched
    thumbnail = images.as_images()[0].copy()
    thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))

    images.insert_image(0, thumbnail, images.metadata[0])
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -131,22 +156,7 @@ def run_txt2img_pipeline(
|
||||||
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||||
)
|
)
|
||||||
|
|
||||||
# add a thumbnail, if requested
|
add_thumbnail_output(server, images, params)
|
||||||
result_size = images.size()
|
|
||||||
if (
|
|
||||||
params.thumbnail
|
|
||||||
and len(images) > 0
|
|
||||||
and (
|
|
||||||
result_size.width > server.thumbnail_size
|
|
||||||
or result_size.height > server.thumbnail_size
|
|
||||||
)
|
|
||||||
):
|
|
||||||
cover = images.as_images()[0]
|
|
||||||
thumbnail = cover.copy()
|
|
||||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
|
||||||
|
|
||||||
images.insert_image(0, thumbnail, images.metadata[0])
|
|
||||||
|
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
|
@ -231,19 +241,24 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
add_safety_stage(server, chain)
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
|
# prep inputs
|
||||||
|
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
|
||||||
|
input_result = StageResult(images=[source], metadata=[input_metadata])
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
input_result, # terrible naming, I know
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
images.push_image(source, ImageMetadata.unknown_image())
|
images.push_image(source, ImageMetadata.unknown_image())
|
||||||
|
|
||||||
|
add_thumbnail_output(server, images, params)
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
|
@ -406,6 +421,10 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
add_safety_stage(server, chain)
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
|
# prep inputs
|
||||||
|
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
|
||||||
|
input_result = StageResult(images=[source], metadata=[input_metadata])
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
|
@ -413,13 +432,13 @@ def run_inpaint_pipeline(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
StageResult(
|
input_result,
|
||||||
images=[source], metadata=[ImageMetadata.unknown_image()]
|
|
||||||
), # TODO: load metadata from source image
|
|
||||||
callback=progress,
|
callback=progress,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_thumbnail_output(server, images, params)
|
||||||
|
|
||||||
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
|
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)):
|
||||||
if full_res_inpaint:
|
if full_res_inpaint:
|
||||||
if is_debug():
|
if is_debug():
|
||||||
|
@ -488,16 +507,21 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
add_safety_stage(server, chain)
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
|
# prep inputs
|
||||||
|
input_metadata = read_metadata(source) or ImageMetadata.unknown_image()
|
||||||
|
input_result = StageResult(images=[source], metadata=[input_metadata])
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]),
|
input_result,
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_thumbnail_output(server, images, params)
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
|
@ -543,18 +567,23 @@ def run_blend_pipeline(
|
||||||
|
|
||||||
add_safety_stage(server, chain)
|
add_safety_stage(server, chain)
|
||||||
|
|
||||||
|
# prep inputs
|
||||||
|
input_metadata = [
|
||||||
|
read_metadata(source) or ImageMetadata.unknown_image() for source in sources
|
||||||
|
]
|
||||||
|
input_result = StageResult(images=sources, metadata=input_metadata)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback(reset=True)
|
progress = worker.get_progress_callback(reset=True)
|
||||||
images = chain(
|
images = chain(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
StageResult(
|
input_result,
|
||||||
images=sources, metadata=[ImageMetadata.unknown_image()] * len(sources)
|
|
||||||
),
|
|
||||||
callback=progress,
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_thumbnail_output(server, images, params)
|
||||||
save_result(server, images, worker.job)
|
save_result(server, images, worker.job)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
|
|
|
@ -145,3 +145,19 @@ def save_metadata(
|
||||||
f.write(dumps(json))
|
f.write(dumps(json))
|
||||||
logger.debug("saved image params to: %s", path)
|
logger.debug("saved image params to: %s", path)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def read_metadata(
    image: Image.Image,
) -> Optional[ImageMetadata]:
    """
    Read generation metadata from the EXIF data of an input image.

    Images whose Make tag is "onnx-web" and that carry a MakerNote are parsed
    with `ImageMetadata.from_json`; otherwise a UserComment tag, when present,
    is parsed with `ImageMetadata.from_exif`.

    Returns None when the image has no EXIF data or neither tag is usable.
    """
    # NOTE(review): _getexif is a private Pillow helper; it does not appear to be
    # defined on every image class and returns None when there is no EXIF data,
    # so both cases are treated as "no metadata" instead of raising
    get_exif = getattr(image, "_getexif", None)
    exif_data = get_exif() if callable(get_exif) else None
    if not exif_data:
        return None

    # require the MakerNote as well, so an image that only has the Make tag can
    # still fall through to the UserComment branch
    if exif_data.get(ImageIFD.Make) == "onnx-web" and ExifIFD.MakerNote in exif_data:
        return ImageMetadata.from_json(exif_data[ExifIFD.MakerNote])

    # NOTE(review): from_exif expects a str; confirm the UserComment value is
    # already decoded text rather than raw bytes
    if ExifIFD.UserComment in exif_data:
        return ImageMetadata.from_exif(exif_data[ExifIFD.UserComment])

    # this could return ImageMetadata.unknown_image(), but that would not indicate whether the input
    # had metadata or not, so it's easier to return None and follow the call with `or ImageMetadata.unknown_image()`
    return None
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.result import ImageMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class ImageMetadataTests(unittest.TestCase):
    # Both from_exif fixtures describe the same image, so they share one set of
    # expectations.
    def _assert_parsed(self, exif_data):
        metadata = ImageMetadata.from_exif(exif_data)
        params = metadata.params
        self.assertEqual(params.prompt, "test prompt")
        self.assertEqual(params.negative_prompt, "negative prompt")
        self.assertEqual(params.scheduler, "ddim")
        self.assertEqual(params.cfg, 4.0)
        self.assertEqual(params.steps, 30)
        self.assertEqual(params.seed, 5)

    def test_image_metadata(self):
        # placeholder, nothing is asserted yet
        pass

    def test_from_exif_normal(self):
        # all key/value pairs on a single line
        self._assert_parsed(
            """test prompt
Negative prompt: negative prompt
Sampler: ddim, CFG scale: 4.0, Steps: 30, Seed: 5
"""
        )

    def test_from_exif_split(self):
        # the same pairs spread over several lines
        self._assert_parsed(
            """test prompt
Negative prompt: negative prompt
Sampler: ddim,
CFG scale: 4.0,
Steps: 30, Seed: 5
"""
        )
|
||||||
|
|
||||||
|
|
||||||
|
class StageResultTests(unittest.TestCase):
    # Placeholder suite: the single test below asserts nothing yet.
    def test_stage_result(self):
        return None
|
Loading…
Reference in New Issue