1
0
Fork 0

make response types match

This commit is contained in:
Sean Sube 2024-01-03 22:15:50 -06:00
parent c4b831fe5c
commit 46bcd5af86
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 43 additions and 25 deletions

View File

@ -1,13 +1,14 @@
from json import dumps from json import dumps
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Any, List, Optional from typing import Any, List, Optional, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ..convert.utils import resolve_tensor from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext
from ..server.load import get_extra_hashes from ..server.load import get_extra_hashes
from ..utils import hash_file from ..utils import hash_file
@ -55,8 +56,8 @@ class ImageMetadata:
self.loras = loras self.loras = loras
self.models = models self.models = models
def to_auto1111(self, server, outputs) -> str: def get_model_hash(self, model: Optional[str] = None) -> Tuple[str, str]:
model_name = path.basename(path.normpath(self.params.model)) model_name = path.basename(path.normpath(model or self.params.model))
logger.debug("getting model hash for %s", model_name) logger.debug("getting model hash for %s", model_name)
model_hash = get_extra_hashes().get(model_name, None) model_hash = get_extra_hashes().get(model_name, None)
@ -66,7 +67,10 @@ class ImageMetadata:
with open(model_hash_path, "r") as f: with open(model_hash_path, "r") as f:
model_hash = f.readline().rstrip(",. \n\t\r") model_hash = f.readline().rstrip(",. \n\t\r")
model_hash = model_hash or "unknown" return (model_name, model_hash or "unknown")
def to_exif(self, server) -> str:
model_name, model_hash = self.get_model_hash()
hash_map = { hash_map = {
model_name: model_hash, model_name: model_hash,
} }
@ -112,15 +116,17 @@ class ImageMetadata:
f"Hashes: {dumps(hash_map)}" f"Hashes: {dumps(hash_map)}"
) )
def tojson(self, server, outputs): def tojson(self, server: ServerContext, output: List[str]):
json = { json = {
"input_size": self.size.tojson(), "input_size": self.size.tojson(),
"outputs": outputs, "outputs": output,
"params": self.params.tojson(), "params": self.params.tojson(),
"inversions": {}, "inversions": [],
"loras": {}, "loras": [],
"models": [],
} }
# fix up some fields
json["params"]["model"] = path.basename(self.params.model) json["params"]["model"] = path.basename(self.params.model)
json["params"]["scheduler"] = self.params.scheduler # TODO: why tho? json["params"]["scheduler"] = self.params.scheduler # TODO: why tho?
@ -145,14 +151,21 @@ class ImageMetadata:
hash = hash_file( hash = hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name)) resolve_tensor(path.join(server.model_path, "inversion", name))
).upper() ).upper()
json["inversions"][name] = {"weight": weight, "hash": hash} json["inversions"].append(
{"name": name, "weight": weight, "hash": hash}
)
if self.loras is not None: if self.loras is not None:
for name, weight in self.loras: for name, weight in self.loras:
hash = hash_file( hash = hash_file(
resolve_tensor(path.join(server.model_path, "lora", name)) resolve_tensor(path.join(server.model_path, "lora", name))
).upper() ).upper()
json["loras"][name] = {"weight": weight, "hash": hash} json["loras"].append({"name": name, "weight": weight, "hash": hash})
if self.models is not None:
for name, weight in self.models:
name, hash = self.get_model_hash()
json["models"].append({"name": name, "weight": weight, "hash": hash})
return json return json

View File

@ -98,7 +98,7 @@ def save_image(
exif.add_text("model", server.server_version) exif.add_text("model", server.server_version)
exif.add_text( exif.add_text(
"parameters", "parameters",
metadata.to_auto1111(server, [output]), metadata.to_exif(server, [output]),
) )
image.save(path, format=server.image_format, pnginfo=exif) image.save(path, format=server.image_format, pnginfo=exif)
@ -111,7 +111,7 @@ def save_image(
encoding="unicode", encoding="unicode",
), ),
ExifIFD.UserComment: UserComment.dump( ExifIFD.UserComment: UserComment.dump(
metadata.to_auto1111(server, [output]), metadata.to_exif(server, [output]),
encoding="unicode", encoding="unicode",
), ),
ImageIFD.Make: "onnx-web", ImageIFD.Make: "onnx-web",

View File

@ -12,7 +12,6 @@ import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../sta
import { range, visibleIndex } from '../../utils.js'; import { range, visibleIndex } from '../../utils.js';
import { BLEND_SOURCES } from '../../constants.js'; import { BLEND_SOURCES } from '../../constants.js';
import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js'; import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js';
import { getApiRoot } from '../../config.js';
export interface ImageCardProps { export interface ImageCardProps {
image: SuccessJobResponse; image: SuccessJobResponse;
@ -110,8 +109,8 @@ export function ImageCard(props: ImageCardProps) {
const url = useMemo(() => client.outputURL(image, index), [image, index]); const url = useMemo(() => client.outputURL(image, index), [image, index]);
const model = getLabel('model', metadata[index].model); const model = getLabel('model', metadata[index].models[0].name);
const scheduler = getLabel('scheduler', metadata[index].scheduler); const scheduler = getLabel('scheduler', metadata[index].params.scheduler);
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}> return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
<CardMedia sx={{ height: config.params.height.default }} <CardMedia sx={{ height: config.params.height.default }}

View File

@ -36,10 +36,6 @@ export function LoadingCard(props: LoadingCardProps) {
refetchInterval: POLL_TIME, refetchInterval: POLL_TIME,
}); });
function getReady() {
return doesExist(ready.data) && ready.data[0].status === JobStatus.SUCCESS;
}
function renderProgress() { function renderProgress() {
const progress = getProgress(ready.data); const progress = getProgress(ready.data);
const total = getTotal(ready.data); const total = getTotal(ready.data);
@ -57,10 +53,10 @@ export function LoadingCard(props: LoadingCardProps) {
}, [cancel.status]); }, [cancel.status]);
useEffect(() => { useEffect(() => {
if (ready.status === 'success' && getReady()) { if (ready.status === 'success') {
setReady(ready.data[0]); setReady(ready.data[0]);
} }
}, [ready.status, getReady(), getProgress(ready.data)]); }, [ready.status, getStatus(ready.data), getProgress(ready.data)]);
return <Card sx={{ maxWidth: params.width.default }}> return <Card sx={{ maxWidth: params.width.default }}>
<CardContent sx={{ height: params.height.default }}> <CardContent sx={{ height: params.height.default }}>
@ -134,3 +130,11 @@ export function getTotal(data: Maybe<Array<JobResponse>>) {
return 0; return 0;
} }
function getStatus(data: Maybe<Array<JobResponse>>) {
if (doesExist(data)) {
return data[0].status;
}
return JobStatus.PENDING;
}

View File

@ -19,16 +19,18 @@ export interface NetworkMetadata {
export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> { export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> {
input_size: Size; input_size: Size;
size: Size;
outputs: Array<string>; outputs: Array<string>;
params: TParams; params: TParams;
inversions: Array<NetworkMetadata>; inversions: Array<NetworkMetadata>;
loras: Array<NetworkMetadata>; loras: Array<NetworkMetadata>;
model: string; models: Array<NetworkMetadata>;
scheduler: string;
border: unknown; border: unknown; // TODO: type
highres: HighresParams; highres: HighresParams;
upscale: UpscaleParams; upscale: UpscaleParams;
size: Size;
type: TType; type: TType;
} }