1
0
Fork 0

feat: save a thumbnail for every image

This commit is contained in:
Sean Sube 2024-01-28 19:45:37 -06:00
parent 8a71934d74
commit 8c8be8fc08
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 167 additions and 110 deletions

View File

@ -358,6 +358,10 @@ class StageResult:
images: Optional[List[Image.Image]] images: Optional[List[Image.Image]]
metadata: List[ImageMetadata] metadata: List[ImageMetadata]
# output paths, filled in when the result is saved
outputs: Optional[List[str]]
thumbnails: Optional[List[str]]
@staticmethod @staticmethod
def empty(): def empty():
return StageResult(images=[]) return StageResult(images=[])
@ -385,6 +389,9 @@ class StageResult:
elif data_provided == 0: elif data_provided == 0:
raise ValueError("results must contain some data") raise ValueError("results must contain some data")
self.outputs = None
self.thumbnails = None
if source is not None: if source is not None:
self.arrays = source.arrays self.arrays = source.arrays
self.images = source.images self.images = source.images

View File

@ -15,7 +15,7 @@ from ..chain.highres import stage_highres
from ..chain.result import ImageMetadata, StageResult from ..chain.result import ImageMetadata, StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image from ..image import expand_image
from ..output import read_metadata, save_image, save_result from ..output import make_output_names, read_metadata, save_image, save_result
from ..params import ( from ..params import (
Border, Border,
HighresParams, HighresParams,
@ -63,47 +63,6 @@ def add_safety_stage(
) )
def add_thumbnail_output(
server: ServerContext,
images: StageResult,
params: ImageParams,
) -> None:
"""
Add a thumbnail image to the output, if requested.
TODO: This should really be a stage.
"""
result_size = images.size()
if (
params.thumbnail
and len(images) > 0
and (
result_size.width > server.thumbnail_size
or result_size.height > server.thumbnail_size
)
):
cover = images.as_images()[0]
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
metadata = images.metadata[0]
metadata = metadata.with_args(
size=Size(server.thumbnail_size, server.thumbnail_size)
)
if metadata.highres is not None:
metadata.highres = metadata.highres.with_args(
scale=1,
)
if metadata.upscale is not None:
metadata.upscale = metadata.upscale.with_args(
scale=1,
outscale=1,
)
images.insert_image(0, thumbnail, metadata)
def run_txt2img_pipeline( def run_txt2img_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
@ -165,8 +124,7 @@ def run_txt2img_pipeline(
worker, server, params, StageResult.empty(), callback=progress, latents=latents worker, server, params, StageResult.empty(), callback=progress, latents=latents
) )
add_thumbnail_output(server, images, params) save_result(server, images, worker.job, save_thumbnails=params.thumbnail)
save_result(server, images, worker.job)
# clean up # clean up
run_gc([worker.get_device()]) run_gc([worker.get_device()])
@ -267,8 +225,7 @@ def run_img2img_pipeline(
if source_filter is not None and source_filter != "none": if source_filter is not None and source_filter != "none":
images.push_image(source, ImageMetadata.unknown_image()) images.push_image(source, ImageMetadata.unknown_image())
add_thumbnail_output(server, images, params) save_result(server, images, worker.job, save_thumbnails=params.thumbnail)
save_result(server, images, worker.job)
# clean up # clean up
run_gc([worker.get_device()]) run_gc([worker.get_device()])
@ -446,9 +403,13 @@ def run_inpaint_pipeline(
latents=latents, latents=latents,
) )
add_thumbnail_output(server, images, params) # custom version of save for full-res inpainting
output_names = make_output_names(server, worker.job, len(images))
outputs = []
for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)): for image, metadata, output in zip(
images.as_images(), images.metadata, output_names
):
if full_res_inpaint: if full_res_inpaint:
if is_debug(): if is_debug():
save_image(server, "adjusted-output.png", image) save_image(server, "adjusted-output.png", image)
@ -457,13 +418,34 @@ def run_inpaint_pipeline(
image = original_source image = original_source
image.paste(mini_image, box=adj_mask_border) image.paste(mini_image, box=adj_mask_border)
save_image( outputs.append(
server, save_image(
f"{worker.job}_{i}.{server.image_format}", server,
image, output,
metadata, image,
metadata,
)
) )
thumbnails = None
if params.thumbnail:
thumbnail_names = make_output_names(
server, worker.job, len(images), suffix="thumbnail"
)
thumbnails = []
for image, thumbnail in zip(images.as_images(), thumbnail_names):
thumbnails.append(
save_image(
server,
thumbnail,
image,
)
)
images.outputs = outputs
images.thumbnails = thumbnails
# clean up # clean up
run_gc([worker.get_device()]) run_gc([worker.get_device()])
@ -530,8 +512,7 @@ def run_upscale_pipeline(
callback=progress, callback=progress,
) )
add_thumbnail_output(server, images, params) save_result(server, images, worker.job, save_thumbnails=params.thumbnail)
save_result(server, images, worker.job)
# clean up # clean up
run_gc([worker.get_device()]) run_gc([worker.get_device()])
@ -592,8 +573,7 @@ def run_blend_pipeline(
callback=progress, callback=progress,
) )
add_thumbnail_output(server, images, params) save_result(server, images, worker.job, save_thumbnails=params.thumbnail)
save_result(server, images, worker.job)
# clean up # clean up
run_gc([worker.get_device()]) run_gc([worker.get_device()])

View File

@ -22,7 +22,11 @@ def make_output_names(
count: int = 1, count: int = 1,
offset: int = 0, offset: int = 0,
extension: Optional[str] = None, extension: Optional[str] = None,
suffix: Optional[str] = None,
) -> List[str]: ) -> List[str]:
if suffix is not None:
job_name = f"{job_name}_{suffix}"
return [ return [
f"{job_name}_{i}.{extension or server.image_format}" f"{job_name}_{i}.{extension or server.image_format}"
for i in range(offset, count + offset) for i in range(offset, count + offset)
@ -63,14 +67,15 @@ def save_result(
server: ServerContext, server: ServerContext,
result: StageResult, result: StageResult,
base_name: str, base_name: str,
save_thumbnails: bool = False,
) -> List[str]: ) -> List[str]:
images = result.as_images() images = result.as_images()
outputs = make_output_names(server, base_name, len(images)) output_names = make_output_names(server, base_name, len(images))
logger.debug("saving %s images: %s", len(images), outputs) logger.debug("saving %s images: %s", len(images), output_names)
results = [] outputs = []
for image, metadata, filename in zip(images, result.metadata, outputs): for image, metadata, filename in zip(images, result.metadata, output_names):
results.append( outputs.append(
save_image( save_image(
server, server,
filename, filename,
@ -79,7 +84,33 @@ def save_result(
) )
) )
return results result.outputs = outputs
if save_thumbnails:
thumbnail_names = make_output_names(
server,
base_name,
len(images),
suffix="thumbnail",
)
logger.debug("saving %s thumbnails: %s", len(images), thumbnail_names)
thumbnails = []
for image, filename in zip(images, thumbnail_names):
thumbnail = image.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
thumbnails.append(
save_image(
server,
filename,
image,
)
)
result.thumbnails = thumbnails
return outputs
def save_image( def save_image(

View File

@ -18,7 +18,7 @@ from ..diffusers.run import (
run_upscale_pipeline, run_upscale_pipeline,
) )
from ..diffusers.utils import replace_wildcards from ..diffusers.utils import replace_wildcards
from ..output import make_job_name, make_output_names from ..output import make_job_name
from ..params import Size, StageParams, TileOrder from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline from ..transformers.run import run_txt2txt_pipeline
from ..utils import ( from ..utils import (
@ -118,8 +118,9 @@ def image_reply(
stages: Progress = None, stages: Progress = None,
steps: Progress = None, steps: Progress = None,
tiles: Progress = None, tiles: Progress = None,
outputs: Optional[List[str]] = None,
metadata: Optional[List[ImageMetadata]] = None, metadata: Optional[List[ImageMetadata]] = None,
outputs: Optional[List[str]] = None,
thumbnails: Optional[List[str]] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if queue is None: if queue is None:
@ -158,6 +159,13 @@ def image_reply(
data["metadata"] = [m.tojson(server, [o]) for m, o in zip(metadata, outputs)] data["metadata"] = [m.tojson(server, [o]) for m, o in zip(metadata, outputs)]
data["outputs"] = outputs data["outputs"] = outputs
if thumbnails is not None:
if len(thumbnails) != len(outputs):
logger.error("thumbnails and outputs must be the same length")
return error_reply("thumbnails and outputs must be the same length")
data["thumbnails"] = thumbnails
return data return data
@ -692,12 +700,14 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
status, progress, queue = pool.status(job_name) status, progress, queue = pool.status(job_name)
if progress is not None: if progress is not None:
outputs = None
metadata = None metadata = None
if progress.result is not None and len(progress.result) > 0: outputs = None
# TODO: the names should be attached to the result when it is saved rather than recomputing them thumbnails = None
outputs = make_output_names(server, job_name, len(progress.result))
if progress.result is not None:
metadata = progress.result.metadata metadata = progress.result.metadata
outputs = progress.result.outputs
thumbnails = progress.result.thumbnails
records.append( records.append(
image_reply( image_reply(
@ -707,8 +717,9 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
stages=progress.stages, stages=progress.stages,
steps=progress.steps, steps=progress.steps,
tiles=progress.tiles, tiles=progress.tiles,
outputs=outputs,
metadata=metadata, metadata=metadata,
outputs=outputs,
thumbnails=thumbnails,
reason=progress.reason, reason=progress.reason,
) )
) )

View File

@ -500,6 +500,13 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
outputURL(image: SuccessJobResponse, index: number): string { outputURL(image: SuccessJobResponse, index: number): string {
return new URL(joinPath('output', image.outputs[index]), root).toString(); return new URL(joinPath('output', image.outputs[index]), root).toString();
}, },
thumbnailURL(image: SuccessJobResponse, index: number): Maybe<string> {
if (doesExist(image.thumbnails) && doesExist(image.thumbnails[index])) {
return new URL(joinPath('output', image.thumbnails[index]), root).toString();
}
return undefined;
},
}; };
const batchStatus = batcher({ const batchStatus = batcher({

View File

@ -1,9 +1,10 @@
import { Maybe } from '@apextoaster/js-utils';
import { ServerParams } from '../config.js'; import { ServerParams } from '../config.js';
import { ExtrasFile } from '../types/model.js';
import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js';
import { ChainPipeline } from '../types/chain.js';
import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js'; import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js';
import { ChainPipeline } from '../types/chain.js';
import { ExtrasFile } from '../types/model.js';
import { BlendParams, HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams } from '../types/params.js';
export interface ApiClient { export interface ApiClient {
/** /**
@ -124,4 +125,6 @@ export interface ApiClient {
workers(): Promise<Array<unknown>>; workers(): Promise<Array<unknown>>;
outputURL(image: SuccessJobResponse, index: number): string; outputURL(image: SuccessJobResponse, index: number): string;
thumbnailURL(image: SuccessJobResponse, index: number): Maybe<string>;
} }

View File

@ -84,4 +84,7 @@ export const LOCAL_CLIENT = {
outputURL(image, index) { outputURL(image, index) {
throw new NoServerError(); throw new NoServerError();
}, },
thumbnailURL(image, index) {
throw new NoServerError();
},
} as ApiClient; } as ApiClient;

View File

@ -1,6 +1,6 @@
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'; import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material'; import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material'; import { Box, Card, CardActionArea, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip, Typography } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext, useMemo, useState } from 'react'; import { useContext, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -27,7 +27,7 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) {
export function ImageCard(props: ImageCardProps) { export function ImageCard(props: ImageCardProps) {
const { image } = props; const { image } = props;
const { metadata, outputs } = image; const { metadata, outputs, thumbnails } = image;
const [_hash, setHash] = useHash(); const [_hash, setHash] = useHash();
const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>(); const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
@ -39,7 +39,7 @@ export function ImageCard(props: ImageCardProps) {
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow); const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
async function loadSource() { async function loadSource() {
const req = await fetch(url); const req = await fetch(outputURL);
return req.blob(); return req.blob();
} }
@ -85,12 +85,12 @@ export function ImageCard(props: ImageCardProps) {
} }
function downloadImage() { function downloadImage() {
window.open(url, '_blank'); window.open(outputURL, '_blank');
close(); close();
} }
function downloadMetadata() { function downloadMetadata() {
window.open(url + '.json', '_blank'); window.open(outputURL + '.json', '_blank');
close(); close();
} }
@ -107,17 +107,21 @@ export function ImageCard(props: ImageCardProps) {
return mustDefault(t(`${key}.${name}`), name); return mustDefault(t(`${key}.${name}`), name);
} }
const url = useMemo(() => client.outputURL(image, index), [image, index]); const outputURL = useMemo(() => client.outputURL(image, index), [image, index]);
const thumbnailURL = useMemo(() => client.thumbnailURL(image, index), [image, index]);
const previewURL = thumbnailURL ?? outputURL;
const model = getLabel('model', metadata[index].models[0].name); const model = getLabel('model', metadata[index].models[0].name);
const scheduler = getLabel('scheduler', metadata[index].params.scheduler); const scheduler = getLabel('scheduler', metadata[index].params.scheduler);
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}> return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
<CardMedia sx={{ height: config.params.height.default }} <CardActionArea onClick={downloadImage}>
component='img' <CardMedia sx={{ height: config.params.height.default }}
image={url} component='img'
title={metadata[index].params.prompt} image={previewURL}
/> title={metadata[index].params.prompt}
/>
</CardActionArea>
<CardContent> <CardContent>
<Box textAlign='center'> <Box textAlign='center'>
<Grid container spacing={STANDARD_SPACING}> <Grid container spacing={STANDARD_SPACING}>
@ -136,7 +140,8 @@ export function ImageCard(props: ImageCardProps) {
</Tooltip> </Tooltip>
</GridItem> </GridItem>
<GridItem xs={4}> <GridItem xs={4}>
{visibleIndex(index)} of {outputs.length} <Typography>{visibleIndex(index)} of {outputs.length}</Typography>
{hasThumbnail(image, index) && <Typography>({t('image.thumbnail')})</Typography>}
</GridItem> </GridItem>
<GridItem xs={4}> <GridItem xs={4}>
<Tooltip title={t('tooltip.next')}> <Tooltip title={t('tooltip.next')}>
@ -240,3 +245,7 @@ export function selectActions(state: OnnxState) {
setUpscale: state.setUpscale, setUpscale: state.setUpscale,
}; };
} }
export function hasThumbnail(job: SuccessJobResponse, index: number) {
return doesExist(job.thumbnails) && doesExist(job.thumbnails[index]);
}

View File

@ -73,6 +73,9 @@ export const I18N_STRINGS_DE = {
}, },
}, },
}, },
image: {
thumbnail: 'Miniaturansicht',
},
loading: { loading: {
cancel: 'Stornieren', cancel: 'Stornieren',
progress: '{{steps.current}} von {{steps.total}} Schritten, {{tiles.current}} of {{tiles.total}} Kacheln, {{stages.current}} of {{stages.total}} Abschnitten', progress: '{{steps.current}} von {{steps.total}} Schritten, {{tiles.current}} of {{tiles.total}} Kacheln, {{stages.current}} of {{stages.total}} Abschnitten',

View File

@ -48,6 +48,9 @@ export const I18N_STRINGS_EN = {
history: { history: {
empty: 'No recent history. Press Generate to create an image.', empty: 'No recent history. Press Generate to create an image.',
}, },
image: {
thumbnail: 'Thumbnail',
},
input: { input: {
image: { image: {
empty: 'Please select an image.', empty: 'Please select an image.',

View File

@ -53,6 +53,9 @@ export const I18N_STRINGS_ES = {
lanczos: 'Lanczos', lanczos: 'Lanczos',
upscale: '', upscale: '',
}, },
image: {
thumbnail: 'Miniatura',
},
input: { input: {
image: { image: {
empty: 'Por favor, seleccione una imagen.', empty: 'Por favor, seleccione una imagen.',

View File

@ -53,6 +53,9 @@ export const I18N_STRINGS_FR = {
history: { history: {
empty: 'pas d\'histoire récente. appuyez sur générer pour créer une image.', empty: 'pas d\'histoire récente. appuyez sur générer pour créer une image.',
}, },
image: {
thumbnail: 'vignette',
},
input: { input: {
image: { image: {
empty: 'veuillez sélectionner une image', empty: 'veuillez sélectionner une image',

View File

@ -106,57 +106,51 @@ export interface RunningJobResponse extends BaseJobResponse {
status: JobStatus.RUNNING; status: JobStatus.RUNNING;
} }
export interface BaseSuccessJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
thumbnails?: Array<string>;
}
/** /**
* Successful txt2img image job with output keys and metadata. * Successful txt2img image job with output keys and metadata.
*/ */
export interface SuccessTxt2ImgJobResponse extends BaseJobResponse { export interface Txt2ImgSuccessJobResponse extends BaseSuccessJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>; metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>;
} }
/** /**
* Successful img2img job with output keys and metadata. * Successful img2img job with output keys and metadata.
*/ */
export interface SuccessImg2ImgJobResponse extends BaseJobResponse { export interface Img2ImgSuccessJobResponse extends BaseSuccessJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>; metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>;
} }
/** /**
* Successful inpaint job with output keys and metadata. * Successful inpaint job with output keys and metadata.
*/ */
export interface SuccessInpaintJobResponse extends BaseJobResponse { export interface InpaintSuccessJobResponse extends BaseSuccessJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>; metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>;
} }
/** /**
* Successful upscale job with output keys and metadata. * Successful upscale job with output keys and metadata.
*/ */
export interface SuccessUpscaleJobResponse extends BaseJobResponse { export interface UpscaleSuccessJobResponse extends BaseSuccessJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>; metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>;
} }
/** /**
* Successful blend job with output keys and metadata. * Successful blend job with output keys and metadata.
*/ */
export interface SuccessBlendJobResponse extends BaseJobResponse { export interface BlendSuccessJobResponse extends BaseSuccessJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>; metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>;
} }
/** /**
* Successful chain pipeline job with output keys and metadata. * Successful chain pipeline job with output keys and metadata.
*/ */
export interface SuccessChainJobResponse extends BaseJobResponse { export interface ChainSuccessJobResponse extends BaseSuccessJobResponse {
status: JobStatus.SUCCESS;
outputs: Array<string>;
metadata: Array<AnyImageMetadata>; metadata: Array<AnyImageMetadata>;
} }
@ -171,12 +165,12 @@ export interface UnknownJobResponse extends BaseJobResponse {
* All successful job types. * All successful job types.
*/ */
export type SuccessJobResponse export type SuccessJobResponse
= SuccessTxt2ImgJobResponse = Txt2ImgSuccessJobResponse
| SuccessImg2ImgJobResponse | Img2ImgSuccessJobResponse
| SuccessInpaintJobResponse | InpaintSuccessJobResponse
| SuccessUpscaleJobResponse | UpscaleSuccessJobResponse
| SuccessBlendJobResponse | BlendSuccessJobResponse
| SuccessChainJobResponse; | ChainSuccessJobResponse;
/** /**
* All job types. * All job types.