diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index eac0f36c..d5cecdf7 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -14,7 +14,7 @@ def persist_disk( _job: WorkerContext, server: ServerContext, _stage: StageParams, - _params: ImageParams, + params: ImageParams, source: Image.Image, *, output: str, @@ -23,6 +23,6 @@ def persist_disk( ) -> Image.Image: source = stage_source or source - dest = save_image(server, output, source) + dest = save_image(server, output, source, params=params) logger.info("saved image to %s", dest) return source diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 4d217812..882050c6 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -314,8 +314,7 @@ def run_txt2img_pipeline( callback=progress, ) - dest = save_image(server, output, image) - save_params(server, output, params, size, upscale=upscale, highres=highres) + dest = save_image(server, output, image, params, size, upscale=upscale, highres=highres) run_gc([job.get_device()]) show_system_toast(f"finished txt2img job: {dest}") @@ -413,11 +412,12 @@ def run_img2img_pipeline( loras, ) + size = Size(*source.size) image = run_highres( job, server, params, - Size(source.width, source.height), + size, upscale, highres, image, @@ -436,9 +436,7 @@ def run_img2img_pipeline( callback=progress, ) - dest = save_image(server, output, image) - size = Size(*source.size) - save_params(server, output, params, size, upscale=upscale) + dest = save_image(server, output, image, params, size, upscale=upscale, highres=highres) run_gc([job.get_device()]) show_system_toast(f"finished img2img job: {dest}") @@ -504,12 +502,11 @@ def run_inpaint_pipeline( callback=progress, ) - dest = save_image(server, outputs[0], image) - save_params(server, outputs[0], params, size, upscale=upscale, border=border) + dest = save_image(server, outputs[0], image, params, size, upscale=upscale, border=border) del image - run_gc([job.get_device()]) + show_system_toast(f"finished inpaint job: {dest}") logger.info("finished inpaint job: %s", dest) @@ -547,12 +544,11 @@ def run_upscale_pipeline( loras, ) - dest = save_image(server, outputs[0], image) - save_params(server, outputs[0], params, size, upscale=upscale) + dest = save_image(server, outputs[0], image, params, size, upscale=upscale) del image - run_gc([job.get_device()]) + show_system_toast(f"finished upscale job: {dest}") logger.info("finished upscale job: %s", dest) @@ -586,11 +582,10 @@ def run_blend_pipeline( job, server, stage, params, image, upscale=upscale, callback=progress ) - dest = save_image(server, outputs[0], image) - save_params(server, outputs[0], params, size, upscale=upscale) + dest = save_image(server, outputs[0], image, params, size, upscale=upscale) del image - run_gc([job.get_device()]) + show_system_toast(f"finished blend job: {dest}") logger.info("finished blend job: %s", dest) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 892ffd6a..ec293d66 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -6,7 +6,9 @@ from struct import pack from time import time from typing import Any, List, Optional -from PIL import Image +from piexif import ExifIFD, ImageIFD, dump +from piexif.helper import UserComment +from PIL import Image, PngImagePlugin from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams from .server import ServerContext @@ -46,23 +48,37 @@ def json_params( json["params"]["model"] = path.basename(params.model) json["params"]["scheduler"] = params.scheduler + output_size = size if border is not None: json["border"] = border.tojson() - size = size.add_border(border) + output_size = output_size.add_border(border) if highres is not None: json["highres"] = highres.tojson() - size = highres.resize(size) + output_size = highres.resize(output_size) if upscale is not None: json["upscale"] = upscale.tojson() - size = upscale.resize(size) + output_size = upscale.resize(output_size) - json["size"] = size.tojson() + json["input_size"] = size.tojson() + json["size"] = output_size.tojson() return json +def str_params( + params: ImageParams, + size: Size, +) -> str: + return ( + f"{params.input_prompt}. Negative prompt: {params.input_negative_prompt}." + f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, " + f"Seed: {params.seed}, Size: {size.width}x{size.height}, Model hash: TODO, Model: {params.model}, " + f"Version: TODO, Tool: onnx-web" + ) + + def make_output_name( server: ServerContext, mode: str, @@ -99,9 +115,54 @@ def make_output_name( ] -def save_image(server: ServerContext, output: str, image: Image.Image) -> str: +def save_image( + server: ServerContext, + output: str, + image: Image.Image, + params: Optional[ImageParams] = None, + size: Optional[Size] = None, + upscale: Optional[UpscaleParams] = None, + border: Optional[Border] = None, + highres: Optional[HighresParams] = None, +) -> str: path = base_join(server.output_path, output) - image.save(path, format=server.image_format) + + if server.image_format == "png": + exif = PngImagePlugin.PngInfo() + + if params is not None: + exif.add_text("Parameters", str_params([output], params, size)) + exif.add_text( + "JSON Parameters", + json_params( + [output], + params, + size, + upscale=upscale, + border=border, + highres=highres, + ), + ) + + image.save(path, format=server.image_format, pnginfo=exif) + else: + exif = dump( + { + "0th": { + ExifIFD.UserComment: UserComment.dump( + str_params([output], params, size), encoding="unicode" + ), + ImageIFD.Make: "onnx-web", + ImageIFD.Model: "TODO", + # TODO: add JSON params + } + } + ) + image.save(path, format=server.image_format, exif=exif) + + if params is not None: + save_params(server, output, params, size, upscale=upscale, border=border, highres=highres) + logger.debug("saved output image to: %s", path) return path diff --git a/api/requirements/base.txt b/api/requirements/base.txt index 3135f90c..dfebf787 100644 --- a/api/requirements/base.txt +++ b/api/requirements/base.txt @@ -28,6 +28,7 @@ boto3==1.26.69 flask==2.2.2 flask-cors==3.0.10 jsonschema==4.17.3 +piexif==1.1.3 pyyaml==6.0 setproctitle==1.3.2 waitress==2.1.2 \ No newline at end of file