add networks to metadata
This commit is contained in:
parent
46098960d8
commit
19c91f70f5
|
@ -41,7 +41,9 @@ class PersistDiskStage(BaseStage):
|
||||||
upscale=metadata.upscale,
|
upscale=metadata.upscale,
|
||||||
border=metadata.border,
|
border=metadata.border,
|
||||||
highres=metadata.highres,
|
highres=metadata.highres,
|
||||||
) # TODO: inversions and loras
|
inversions=metadata.inversions,
|
||||||
|
loras=metadata.loras,
|
||||||
|
)
|
||||||
logger.info("saved image to %s", dest)
|
logger.info("saved image to %s", dest)
|
||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -13,6 +13,8 @@ class ImageMetadata:
|
||||||
params: ImageParams
|
params: ImageParams
|
||||||
size: Size
|
size: Size
|
||||||
upscale: UpscaleParams
|
upscale: UpscaleParams
|
||||||
|
inversions: Optional[List[Tuple[str, float]]]
|
||||||
|
loras: Optional[List[Tuple[str, float]]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -21,21 +23,28 @@ class ImageMetadata:
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
|
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||||
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.params = params
|
self.params = params
|
||||||
self.size = size
|
self.size = size
|
||||||
self.upscale = upscale
|
self.upscale = upscale
|
||||||
self.border = border
|
self.border = border
|
||||||
self.highres = highres
|
self.highres = highres
|
||||||
|
self.inversions = inversions
|
||||||
|
self.loras = loras
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self, server, outputs):
|
||||||
return json_params(
|
return json_params(
|
||||||
[],
|
server,
|
||||||
|
outputs,
|
||||||
self.params,
|
self.params,
|
||||||
self.size,
|
self.size,
|
||||||
upscale=self.upscale,
|
upscale=self.upscale,
|
||||||
border=self.border,
|
border=self.border,
|
||||||
highres=self.highres,
|
highres=self.highres,
|
||||||
|
inversions=self.inversions,
|
||||||
|
loras=self.loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -157,7 +157,9 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
|
|
||||||
result = StageResult(source=sources)
|
result = StageResult(source=sources)
|
||||||
for image in output.images:
|
for image in output.images:
|
||||||
result.push_image(image, ImageMetadata(params, size))
|
result.push_image(
|
||||||
|
image, ImageMetadata(params, size, inversions=inversions, loras=loras)
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("produced %s outputs", len(result))
|
logger.debug("produced %s outputs", len(result))
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -50,18 +50,23 @@ def hash_value(sha, param: Optional[Param]):
|
||||||
|
|
||||||
|
|
||||||
def json_params(
|
def json_params(
|
||||||
|
server: ServerContext,
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
|
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||||
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
parent: Optional[Dict] = None,
|
parent: Optional[Dict] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
json = {
|
json = {
|
||||||
"input_size": size.tojson(),
|
"input_size": size.tojson(),
|
||||||
"outputs": outputs,
|
"outputs": outputs,
|
||||||
"params": params.tojson(),
|
"params": params.tojson(),
|
||||||
|
"inversions": {},
|
||||||
|
"loras": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
json["params"]["model"] = path.basename(params.model)
|
json["params"]["model"] = path.basename(params.model)
|
||||||
|
@ -83,6 +88,20 @@ def json_params(
|
||||||
|
|
||||||
json["size"] = output_size.tojson()
|
json["size"] = output_size.tojson()
|
||||||
|
|
||||||
|
if inversions is not None:
|
||||||
|
for name, weight in inversions:
|
||||||
|
hash = hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||||
|
).upper()
|
||||||
|
json["inversions"][name] = {"weight": weight, "hash": hash}
|
||||||
|
|
||||||
|
if loras is not None:
|
||||||
|
for name, weight in loras:
|
||||||
|
hash = hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||||
|
).upper()
|
||||||
|
json["loras"][name] = {"weight": weight, "hash": hash}
|
||||||
|
|
||||||
return json
|
return json
|
||||||
|
|
||||||
|
|
||||||
|
@ -210,6 +229,7 @@ def save_image(
|
||||||
"maker note",
|
"maker note",
|
||||||
dumps(
|
dumps(
|
||||||
json_params(
|
json_params(
|
||||||
|
server,
|
||||||
[output],
|
[output],
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
|
@ -233,6 +253,7 @@ def save_image(
|
||||||
ExifIFD.MakerNote: UserComment.dump(
|
ExifIFD.MakerNote: UserComment.dump(
|
||||||
dumps(
|
dumps(
|
||||||
json_params(
|
json_params(
|
||||||
|
server,
|
||||||
[output],
|
[output],
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
|
@ -282,7 +303,7 @@ def save_params(
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(server.output_path, f"{output}.json")
|
path = base_join(server.output_path, f"{output}.json")
|
||||||
json = json_params(
|
json = json_params(
|
||||||
output, params, size, upscale=upscale, border=border, highres=highres
|
server, output, params, size, upscale=upscale, border=border, highres=highres
|
||||||
)
|
)
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(dumps(json))
|
f.write(dumps(json))
|
||||||
|
|
|
@ -218,7 +218,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("img2img job queued for: %s", job_name)
|
logger.info("img2img job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
return jsonify(
|
||||||
|
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -245,7 +247,9 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("txt2img job queued for: %s", job_name)
|
logger.info("txt2img job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
return jsonify(
|
||||||
|
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -334,7 +338,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
json_params(
|
json_params(
|
||||||
output, params, size, upscale=upscale, border=expand, highres=highres
|
server,
|
||||||
|
output,
|
||||||
|
params,
|
||||||
|
size,
|
||||||
|
upscale=upscale,
|
||||||
|
border=expand,
|
||||||
|
highres=highres,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -370,7 +380,9 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
return jsonify(
|
||||||
|
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# keys that are specially parsed by params and should not show up in with_args
|
# keys that are specially parsed by params and should not show up in with_args
|
||||||
|
@ -484,7 +496,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||||
return jsonify(json_params(output, step_params, base_size))
|
return jsonify(json_params(server, output, step_params, base_size))
|
||||||
|
|
||||||
|
|
||||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -526,7 +538,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(server, output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -546,7 +558,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify(json_params(output, params, size))
|
return jsonify(json_params(server, output, params, size))
|
||||||
|
|
||||||
|
|
||||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
Loading…
Reference in New Issue