
add networks to metadata

Sean Sube 2024-01-02 22:14:21 -06:00
parent 46098960d8
commit 19c91f70f5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 59 additions and 13 deletions

View File

@@ -41,7 +41,9 @@ class PersistDiskStage(BaseStage):
                 upscale=metadata.upscale,
                 border=metadata.border,
                 highres=metadata.highres,
-            )  # TODO: inversions and loras
+                inversions=metadata.inversions,
+                loras=metadata.loras,
+            )
             logger.info("saved image to %s", dest)

         return sources
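
Note: the two fields forwarded here are lists of (name, weight) pairs, matching the ImageMetadata annotations added below. A minimal sketch of the shape, with hypothetical network names and weights:

    # hypothetical values illustrating the shape PersistDiskStage now forwards
    metadata.inversions = [("style-embedding", 1.0)]
    metadata.loras = [("detail-lora", 0.6)]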

View File

@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple

 import numpy as np
 from PIL import Image
@@ -13,6 +13,8 @@ class ImageMetadata:
     params: ImageParams
     size: Size
     upscale: UpscaleParams
+    inversions: Optional[List[Tuple[str, float]]]
+    loras: Optional[List[Tuple[str, float]]]

     def __init__(
         self,
@@ -21,21 +23,28 @@ class ImageMetadata:
         upscale: Optional[UpscaleParams] = None,
         border: Optional[Border] = None,
         highres: Optional[HighresParams] = None,
+        inversions: Optional[List[Tuple[str, float]]] = None,
+        loras: Optional[List[Tuple[str, float]]] = None,
     ) -> None:
         self.params = params
         self.size = size
         self.upscale = upscale
         self.border = border
         self.highres = highres
+        self.inversions = inversions
+        self.loras = loras

-    def tojson(self):
+    def tojson(self, server, outputs):
         return json_params(
-            [],
+            server,
+            outputs,
             self.params,
             self.size,
             upscale=self.upscale,
             border=self.border,
             highres=self.highres,
+            inversions=self.inversions,
+            loras=self.loras,
         )
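
A minimal usage sketch of the updated class, assuming params, size, and a ServerContext named server are already in scope; the network names, weights, and output filename are invented:

    metadata = ImageMetadata(
        params,
        size,
        inversions=[("style-embedding", 1.0)],  # hypothetical (name, weight) pairs
        loras=[("detail-lora", 0.6)],
    )

    # tojson now takes the ServerContext (used to hash network files) and the
    # output names, instead of always passing an empty output list
    data = metadata.tojson(server, ["output-0.png"])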

View File

@@ -157,7 +157,9 @@ class SourceTxt2ImgStage(BaseStage):
         result = StageResult(source=sources)

         for image in output.images:
-            result.push_image(image, ImageMetadata(params, size))
+            result.push_image(
+                image, ImageMetadata(params, size, inversions=inversions, loras=loras)
+            )

         logger.debug("produced %s outputs", len(result))
         return result
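
The inversions and loras passed here are assumed to be the network lists the stage already parsed from the prompt; a hedged sketch of the call shape with invented values:

    # invented (name, weight) lists standing in for the stage's parsed networks
    inversions = [("negative-embedding", 1.0)]
    loras = [("detail-lora", 0.75)]

    result = StageResult(source=sources)
    for image in output.images:
        result.push_image(
            image, ImageMetadata(params, size, inversions=inversions, loras=loras)
        )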

View File

@@ -50,18 +50,23 @@ def hash_value(sha, param: Optional[Param]):


 def json_params(
+    server: ServerContext,
     outputs: List[str],
     params: ImageParams,
     size: Size,
     upscale: Optional[UpscaleParams] = None,
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
+    inversions: Optional[List[Tuple[str, float]]] = None,
+    loras: Optional[List[Tuple[str, float]]] = None,
     parent: Optional[Dict] = None,
 ) -> Any:
     json = {
         "input_size": size.tojson(),
         "outputs": outputs,
         "params": params.tojson(),
+        "inversions": {},
+        "loras": {},
     }

     json["params"]["model"] = path.basename(params.model)
@@ -83,6 +88,20 @@ def json_params(
         json["size"] = output_size.tojson()

+    if inversions is not None:
+        for name, weight in inversions:
+            hash = hash_file(
+                resolve_tensor(path.join(server.model_path, "inversion", name))
+            ).upper()
+            json["inversions"][name] = {"weight": weight, "hash": hash}
+
+    if loras is not None:
+        for name, weight in loras:
+            hash = hash_file(
+                resolve_tensor(path.join(server.model_path, "lora", name))
+            ).upper()
+            json["loras"][name] = {"weight": weight, "hash": hash}
+
     return json
@@ -210,6 +229,7 @@ def save_image(
                 "maker note",
                 dumps(
                     json_params(
+                        server,
                         [output],
                         params,
                         size,
@@ -233,6 +253,7 @@ def save_image(
                 ExifIFD.MakerNote: UserComment.dump(
                     dumps(
                         json_params(
+                            server,
                             [output],
                             params,
                             size,
@@ -282,7 +303,7 @@ def save_params(
 ) -> str:
     path = base_join(server.output_path, f"{output}.json")
     json = json_params(
-        output, params, size, upscale=upscale, border=border, highres=highres
+        server, output, params, size, upscale=upscale, border=border, highres=highres
     )
     with open(path, "w") as f:
         f.write(dumps(json))
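
For illustration, a hedged sketch of the metadata JSON that json_params can now produce when networks are passed; only the inversions/loras structure follows the code above, and every name, weight, hash, and other field value is invented:

    example = {
        "outputs": ["output-0.png"],
        "params": {"model": "stable-diffusion-v1-5"},
        "inversions": {
            "style-embedding": {"weight": 1.0, "hash": "9F86D081884C7D65"},
        },
        "loras": {
            "detail-lora": {"weight": 0.6, "hash": "2CF24DBA5FB0A30E"},
        },
    }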

View File

@@ -218,7 +218,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("img2img job queued for: %s", job_name)

-    return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
+    return jsonify(
+        json_params(server, output, params, size, upscale=upscale, highres=highres)
+    )


 def txt2img(server: ServerContext, pool: DevicePoolExecutor):
@@ -245,7 +247,9 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("txt2img job queued for: %s", job_name)

-    return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
+    return jsonify(
+        json_params(server, output, params, size, upscale=upscale, highres=highres)
+    )


 def inpaint(server: ServerContext, pool: DevicePoolExecutor):
@@ -334,7 +338,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):

     return jsonify(
         json_params(
-            output, params, size, upscale=upscale, border=expand, highres=highres
+            server,
+            output,
+            params,
+            size,
+            upscale=upscale,
+            border=expand,
+            highres=highres,
         )
     )

@@ -370,7 +380,9 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("upscale job queued for: %s", job_name)

-    return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
+    return jsonify(
+        json_params(server, output, params, size, upscale=upscale, highres=highres)
+    )


 # keys that are specially parsed by params and should not show up in with_args
@@ -484,7 +496,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
     )

     step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
-    return jsonify(json_params(output, step_params, base_size))
+    return jsonify(json_params(server, output, step_params, base_size))


 def blend(server: ServerContext, pool: DevicePoolExecutor):
@@ -526,7 +538,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("upscale job queued for: %s", job_name)

-    return jsonify(json_params(output, params, size, upscale=upscale))
+    return jsonify(json_params(server, output, params, size, upscale=upscale))


 def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
@@ -546,7 +558,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
         needs_device=device,
     )

-    return jsonify(json_params(output, params, size))
+    return jsonify(json_params(server, output, params, size))


 def cancel(server: ServerContext, pool: DevicePoolExecutor):
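
One consequence of the json_params change: the output dict now always contains "inversions" and "loras" keys, so the queue-time responses built by these endpoints include both sections even though no networks are passed here. A minimal hedged sketch of that invariant, assuming the same objects the endpoints already have in scope:

    body = json_params(server, output, params, size, upscale=upscale, highres=highres)
    assert body["inversions"] == {} and body["loras"] == {}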