From 19c91f70f536043b5fe4c317f5e82b4c3fbaf508 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 2 Jan 2024 22:14:21 -0600 Subject: [PATCH] add networks to metadata --- api/onnx_web/chain/persist_disk.py | 4 +++- api/onnx_web/chain/result.py | 15 ++++++++++++--- api/onnx_web/chain/source_txt2img.py | 4 +++- api/onnx_web/output.py | 23 ++++++++++++++++++++++- api/onnx_web/server/api.py | 26 +++++++++++++++++++------- 5 files changed, 59 insertions(+), 13 deletions(-) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index e4496a2f..af023a68 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -41,7 +41,9 @@ class PersistDiskStage(BaseStage): upscale=metadata.upscale, border=metadata.border, highres=metadata.highres, - ) # TODO: inversions and loras + inversions=metadata.inversions, + loras=metadata.loras, + ) logger.info("saved image to %s", dest) return sources diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 0b8dc171..2e60877f 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import numpy as np from PIL import Image @@ -13,6 +13,8 @@ class ImageMetadata: params: ImageParams size: Size upscale: UpscaleParams + inversions: Optional[List[Tuple[str, float]]] + loras: Optional[List[Tuple[str, float]]] def __init__( self, @@ -21,21 +23,28 @@ class ImageMetadata: upscale: Optional[UpscaleParams] = None, border: Optional[Border] = None, highres: Optional[HighresParams] = None, + inversions: Optional[List[Tuple[str, float]]] = None, + loras: Optional[List[Tuple[str, float]]] = None, ) -> None: self.params = params self.size = size self.upscale = upscale self.border = border self.highres = highres + self.inversions = inversions + self.loras = loras - def tojson(self): + def tojson(self, server, outputs): return json_params( - [], + server, + outputs, self.params, self.size, upscale=self.upscale, border=self.border, highres=self.highres, + inversions=self.inversions, + loras=self.loras, ) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index cdf54a57..891e4776 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -157,7 +157,9 @@ class SourceTxt2ImgStage(BaseStage): result = StageResult(source=sources) for image in output.images: - result.push_image(image, ImageMetadata(params, size)) + result.push_image( + image, ImageMetadata(params, size, inversions=inversions, loras=loras) + ) logger.debug("produced %s outputs", len(result)) return result diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 08704ffe..816d2184 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -50,18 +50,23 @@ def hash_value(sha, param: Optional[Param]): def json_params( + server: ServerContext, outputs: List[str], params: ImageParams, size: Size, upscale: Optional[UpscaleParams] = None, border: Optional[Border] = None, highres: Optional[HighresParams] = None, + inversions: Optional[List[Tuple[str, float]]] = None, + loras: Optional[List[Tuple[str, float]]] = None, parent: Optional[Dict] = None, ) -> Any: json = { "input_size": size.tojson(), "outputs": outputs, "params": params.tojson(), + "inversions": {}, + "loras": {}, } json["params"]["model"] = path.basename(params.model) @@ -83,6 +88,20 @@ def json_params( json["size"] = output_size.tojson() + if inversions is not None: + for name, weight in inversions: + hash = hash_file( + resolve_tensor(path.join(server.model_path, "inversion", name)) + ).upper() + json["inversions"][name] = {"weight": weight, "hash": hash} + + if loras is not None: + for name, weight in loras: + hash = hash_file( + resolve_tensor(path.join(server.model_path, "lora", name)) + ).upper() + json["loras"][name] = {"weight": weight, "hash": hash} + return json @@ -210,6 +229,7 @@ def save_image( "maker note", dumps( json_params( + server, [output], params, size, @@ -233,6 +253,7 @@ def save_image( ExifIFD.MakerNote: UserComment.dump( dumps( json_params( + server, [output], params, size, @@ -282,7 +303,7 @@ def save_params( ) -> str: path = base_join(server.output_path, f"{output}.json") json = json_params( - output, params, size, upscale=upscale, border=border, highres=highres + server, output, params, size, upscale=upscale, border=border, highres=highres ) with open(path, "w") as f: f.write(dumps(json)) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index ebccbfe5..078bbf62 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -218,7 +218,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): logger.info("img2img job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) + return jsonify( + json_params(server, output, params, size, upscale=upscale, highres=highres) + ) def txt2img(server: ServerContext, pool: DevicePoolExecutor): @@ -245,7 +247,9 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor): logger.info("txt2img job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) + return jsonify( + json_params(server, output, params, size, upscale=upscale, highres=highres) + ) def inpaint(server: ServerContext, pool: DevicePoolExecutor): @@ -334,7 +338,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): return jsonify( json_params( - output, params, size, upscale=upscale, border=expand, highres=highres + server, + output, + params, + size, + upscale=upscale, + border=expand, + highres=highres, ) ) @@ -370,7 +380,9 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): logger.info("upscale job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) + return jsonify( + json_params(server, output, params, size, upscale=upscale, highres=highres) + ) # keys that are specially parsed by params and should not show up in with_args @@ -484,7 +496,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): ) step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size)) - return jsonify(json_params(output, step_params, base_size)) + return jsonify(json_params(server, output, step_params, base_size)) def blend(server: ServerContext, pool: DevicePoolExecutor): @@ -526,7 +538,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): logger.info("upscale job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale)) + return jsonify(json_params(server, output, params, size, upscale=upscale)) def txt2txt(server: ServerContext, pool: DevicePoolExecutor): @@ -546,7 +558,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor): needs_device=device, ) - return jsonify(json_params(output, params, size)) + return jsonify(json_params(server, output, params, size)) def cancel(server: ServerContext, pool: DevicePoolExecutor):