add networks to metadata
This commit is contained in:
parent
46098960d8
commit
19c91f70f5
|
@ -41,7 +41,9 @@ class PersistDiskStage(BaseStage):
|
|||
upscale=metadata.upscale,
|
||||
border=metadata.border,
|
||||
highres=metadata.highres,
|
||||
) # TODO: inversions and loras
|
||||
inversions=metadata.inversions,
|
||||
loras=metadata.loras,
|
||||
)
|
||||
logger.info("saved image to %s", dest)
|
||||
|
||||
return sources
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -13,6 +13,8 @@ class ImageMetadata:
|
|||
params: ImageParams
|
||||
size: Size
|
||||
upscale: UpscaleParams
|
||||
inversions: Optional[List[Tuple[str, float]]]
|
||||
loras: Optional[List[Tuple[str, float]]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -21,21 +23,28 @@ class ImageMetadata:
|
|||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
) -> None:
|
||||
self.params = params
|
||||
self.size = size
|
||||
self.upscale = upscale
|
||||
self.border = border
|
||||
self.highres = highres
|
||||
self.inversions = inversions
|
||||
self.loras = loras
|
||||
|
||||
def tojson(self):
|
||||
def tojson(self, server, outputs):
|
||||
return json_params(
|
||||
[],
|
||||
server,
|
||||
outputs,
|
||||
self.params,
|
||||
self.size,
|
||||
upscale=self.upscale,
|
||||
border=self.border,
|
||||
highres=self.highres,
|
||||
inversions=self.inversions,
|
||||
loras=self.loras,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -157,7 +157,9 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
|
||||
result = StageResult(source=sources)
|
||||
for image in output.images:
|
||||
result.push_image(image, ImageMetadata(params, size))
|
||||
result.push_image(
|
||||
image, ImageMetadata(params, size, inversions=inversions, loras=loras)
|
||||
)
|
||||
|
||||
logger.debug("produced %s outputs", len(result))
|
||||
return result
|
||||
|
|
|
@ -50,18 +50,23 @@ def hash_value(sha, param: Optional[Param]):
|
|||
|
||||
|
||||
def json_params(
|
||||
server: ServerContext,
|
||||
outputs: List[str],
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
parent: Optional[Dict] = None,
|
||||
) -> Any:
|
||||
json = {
|
||||
"input_size": size.tojson(),
|
||||
"outputs": outputs,
|
||||
"params": params.tojson(),
|
||||
"inversions": {},
|
||||
"loras": {},
|
||||
}
|
||||
|
||||
json["params"]["model"] = path.basename(params.model)
|
||||
|
@ -83,6 +88,20 @@ def json_params(
|
|||
|
||||
json["size"] = output_size.tojson()
|
||||
|
||||
if inversions is not None:
|
||||
for name, weight in inversions:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper()
|
||||
json["inversions"][name] = {"weight": weight, "hash": hash}
|
||||
|
||||
if loras is not None:
|
||||
for name, weight in loras:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper()
|
||||
json["loras"][name] = {"weight": weight, "hash": hash}
|
||||
|
||||
return json
|
||||
|
||||
|
||||
|
@ -210,6 +229,7 @@ def save_image(
|
|||
"maker note",
|
||||
dumps(
|
||||
json_params(
|
||||
server,
|
||||
[output],
|
||||
params,
|
||||
size,
|
||||
|
@ -233,6 +253,7 @@ def save_image(
|
|||
ExifIFD.MakerNote: UserComment.dump(
|
||||
dumps(
|
||||
json_params(
|
||||
server,
|
||||
[output],
|
||||
params,
|
||||
size,
|
||||
|
@ -282,7 +303,7 @@ def save_params(
|
|||
) -> str:
|
||||
path = base_join(server.output_path, f"{output}.json")
|
||||
json = json_params(
|
||||
output, params, size, upscale=upscale, border=border, highres=highres
|
||||
server, output, params, size, upscale=upscale, border=border, highres=highres
|
||||
)
|
||||
with open(path, "w") as f:
|
||||
f.write(dumps(json))
|
||||
|
|
|
@ -218,7 +218,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||
return jsonify(
|
||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
)
|
||||
|
||||
|
||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -245,7 +247,9 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("txt2img job queued for: %s", job_name)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||
return jsonify(
|
||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
)
|
||||
|
||||
|
||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -334,7 +338,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
return jsonify(
|
||||
json_params(
|
||||
output, params, size, upscale=upscale, border=expand, highres=highres
|
||||
server,
|
||||
output,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
border=expand,
|
||||
highres=highres,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -370,7 +380,9 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||
return jsonify(
|
||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
)
|
||||
|
||||
|
||||
# keys that are specially parsed by params and should not show up in with_args
|
||||
|
@ -484,7 +496,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
|
||||
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||
return jsonify(json_params(output, step_params, base_size))
|
||||
return jsonify(json_params(server, output, step_params, base_size))
|
||||
|
||||
|
||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -526,7 +538,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
return jsonify(json_params(server, output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -546,7 +558,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
|||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size))
|
||||
return jsonify(json_params(server, output, params, size))
|
||||
|
||||
|
||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
|
Loading…
Reference in New Issue