From 8c8be8fc0802d1da56459b32abb2b4a249e262d9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 28 Jan 2024 19:45:37 -0600 Subject: [PATCH] feat: save a thumbnail for every image --- api/onnx_web/chain/result.py | 7 ++ api/onnx_web/diffusers/run.py | 94 +++++++++++---------------- api/onnx_web/output.py | 43 ++++++++++-- api/onnx_web/server/api.py | 25 +++++-- gui/src/client/api.ts | 7 ++ gui/src/client/base.ts | 11 ++-- gui/src/client/local.ts | 3 + gui/src/components/card/ImageCard.tsx | 33 ++++++---- gui/src/strings/de.ts | 3 + gui/src/strings/en.ts | 3 + gui/src/strings/es.ts | 3 + gui/src/strings/fr.ts | 3 + gui/src/types/api-v2.ts | 42 +++++------- 13 files changed, 167 insertions(+), 110 deletions(-) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 269ecd5a..918391a9 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -358,6 +358,10 @@ class StageResult: images: Optional[List[Image.Image]] metadata: List[ImageMetadata] + # output paths, filled in when the result is saved + outputs: Optional[List[str]] + thumbnails: Optional[List[str]] + @staticmethod def empty(): return StageResult(images=[]) @@ -385,6 +389,9 @@ class StageResult: elif data_provided == 0: raise ValueError("results must contain some data") + self.outputs = None + self.thumbnails = None + if source is not None: self.arrays = source.arrays self.images = source.images diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 6606be69..ea5cde79 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -15,7 +15,7 @@ from ..chain.highres import stage_highres from ..chain.result import ImageMetadata, StageResult from ..chain.upscale import split_upscale, stage_upscale_correction from ..image import expand_image -from ..output import read_metadata, save_image, save_result +from ..output import make_output_names, read_metadata, save_image, save_result from ..params import ( Border, HighresParams, @@ -63,47 +63,6 @@ def add_safety_stage( ) -def add_thumbnail_output( - server: ServerContext, - images: StageResult, - params: ImageParams, -) -> None: - """ - Add a thumbnail image to the output, if requested. - TODO: This should really be a stage. - """ - result_size = images.size() - if ( - params.thumbnail - and len(images) > 0 - and ( - result_size.width > server.thumbnail_size - or result_size.height > server.thumbnail_size - ) - ): - cover = images.as_images()[0] - thumbnail = cover.copy() - thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) - - metadata = images.metadata[0] - metadata = metadata.with_args( - size=Size(server.thumbnail_size, server.thumbnail_size) - ) - - if metadata.highres is not None: - metadata.highres = metadata.highres.with_args( - scale=1, - ) - - if metadata.upscale is not None: - metadata.upscale = metadata.upscale.with_args( - scale=1, - outscale=1, - ) - - images.insert_image(0, thumbnail, metadata) - - def run_txt2img_pipeline( worker: WorkerContext, server: ServerContext, @@ -165,8 +124,7 @@ def run_txt2img_pipeline( worker, server, params, StageResult.empty(), callback=progress, latents=latents ) - add_thumbnail_output(server, images, params) - save_result(server, images, worker.job) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up run_gc([worker.get_device()]) @@ -267,8 +225,7 @@ def run_img2img_pipeline( if source_filter is not None and source_filter != "none": images.push_image(source, ImageMetadata.unknown_image()) - add_thumbnail_output(server, images, params) - save_result(server, images, worker.job) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up run_gc([worker.get_device()]) @@ -446,9 +403,13 @@ def run_inpaint_pipeline( latents=latents, ) - add_thumbnail_output(server, images, params) + # custom version of save for full-res inpainting + output_names = make_output_names(server, worker.job, len(images)) + outputs = [] - for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)): + for image, metadata, output in zip( + images.as_images(), images.metadata, output_names + ): if full_res_inpaint: if is_debug(): save_image(server, "adjusted-output.png", image) @@ -457,13 +418,34 @@ def run_inpaint_pipeline( image = original_source image.paste(mini_image, box=adj_mask_border) - save_image( - server, - f"{worker.job}_{i}.{server.image_format}", - image, - metadata, + outputs.append( + save_image( + server, + output, + image, + metadata, + ) ) + thumbnails = None + if params.thumbnail: + thumbnail_names = make_output_names( + server, worker.job, len(images), suffix="thumbnail" + ) + thumbnails = [] + + for image, thumbnail in zip(images.as_images(), thumbnail_names): + thumbnails.append( + save_image( + server, + thumbnail, + image, + ) + ) + + images.outputs = outputs + images.thumbnails = thumbnails + # clean up run_gc([worker.get_device()]) @@ -530,8 +512,7 @@ def run_upscale_pipeline( callback=progress, ) - add_thumbnail_output(server, images, params) - save_result(server, images, worker.job) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up run_gc([worker.get_device()]) @@ -592,8 +573,7 @@ def run_blend_pipeline( callback=progress, ) - add_thumbnail_output(server, images, params) - save_result(server, images, worker.job) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up run_gc([worker.get_device()]) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 8146d546..44b1a414 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -22,7 +22,11 @@ def make_output_names( count: int = 1, offset: int = 0, extension: Optional[str] = None, + suffix: Optional[str] = None, ) -> List[str]: + if suffix is not None: + job_name = f"{job_name}_{suffix}" + return [ f"{job_name}_{i}.{extension or server.image_format}" for i in range(offset, count + offset) @@ -63,14 +67,15 @@ def save_result( server: ServerContext, result: StageResult, base_name: str, + save_thumbnails: bool = False, ) -> List[str]: images = result.as_images() - outputs = make_output_names(server, base_name, len(images)) - logger.debug("saving %s images: %s", len(images), outputs) + output_names = make_output_names(server, base_name, len(images)) + logger.debug("saving %s images: %s", len(images), output_names) - results = [] - for image, metadata, filename in zip(images, result.metadata, outputs): - results.append( + outputs = [] + for image, metadata, filename in zip(images, result.metadata, output_names): + outputs.append( save_image( server, filename, @@ -79,7 +84,33 @@ def save_result( ) ) - return results + result.outputs = outputs + + if save_thumbnails: + thumbnail_names = make_output_names( + server, + base_name, + len(images), + suffix="thumbnail", + ) + logger.debug("saving %s thumbnails: %s", len(images), thumbnail_names) + + thumbnails = [] + for image, filename in zip(images, thumbnail_names): + thumbnail = image.copy() + thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) + + thumbnails.append( + save_image( + server, + filename, + image, + ) + ) + + result.thumbnails = thumbnails + + return outputs def save_image( diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 3b8a8814..f269435d 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -18,7 +18,7 @@ from ..diffusers.run import ( run_upscale_pipeline, ) from ..diffusers.utils import replace_wildcards -from ..output import make_job_name, make_output_names +from ..output import make_job_name from ..params import Size, StageParams, TileOrder from ..transformers.run import run_txt2txt_pipeline from ..utils import ( @@ -118,8 +118,9 @@ def image_reply( stages: Progress = None, steps: Progress = None, tiles: Progress = None, - outputs: Optional[List[str]] = None, metadata: Optional[List[ImageMetadata]] = None, + outputs: Optional[List[str]] = None, + thumbnails: Optional[List[str]] = None, reason: Optional[str] = None, ) -> Dict[str, Any]: if queue is None: @@ -158,6 +159,13 @@ def image_reply( data["metadata"] = [m.tojson(server, [o]) for m, o in zip(metadata, outputs)] data["outputs"] = outputs + if thumbnails is not None: + if len(thumbnails) != len(outputs): + logger.error("thumbnails and outputs must be the same length") + return error_reply("thumbnails and outputs must be the same length") + + data["thumbnails"] = thumbnails + return data @@ -692,12 +700,14 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): status, progress, queue = pool.status(job_name) if progress is not None: - outputs = None metadata = None - if progress.result is not None and len(progress.result) > 0: - # TODO: the names should be attached to the result when it is saved rather than recomputing them - outputs = make_output_names(server, job_name, len(progress.result)) + outputs = None + thumbnails = None + + if progress.result is not None: metadata = progress.result.metadata + outputs = progress.result.outputs + thumbnails = progress.result.thumbnails records.append( image_reply( @@ -707,8 +717,9 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): stages=progress.stages, steps=progress.steps, tiles=progress.tiles, - outputs=outputs, metadata=metadata, + outputs=outputs, + thumbnails=thumbnails, reason=progress.reason, ) ) diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index c1297df0..17efd478 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -500,6 +500,13 @@ export function makeClient(root: string, batchInterval: number, token: Maybe { + if (doesExist(image.thumbnails) && doesExist(image.thumbnails[index])) { + return new URL(joinPath('output', image.thumbnails[index]), root).toString(); + } + + return undefined; + }, }; const batchStatus = batcher({ diff --git a/gui/src/client/base.ts b/gui/src/client/base.ts index 0c1df7b6..ad76f5b9 100644 --- a/gui/src/client/base.ts +++ b/gui/src/client/base.ts @@ -1,9 +1,10 @@ +import { Maybe } from '@apextoaster/js-utils'; import { ServerParams } from '../config.js'; -import { ExtrasFile } from '../types/model.js'; -import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js'; -import { ChainPipeline } from '../types/chain.js'; -import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js'; import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js'; +import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js'; +import { ChainPipeline } from '../types/chain.js'; +import { ExtrasFile } from '../types/model.js'; +import { BlendParams, HighresParams, Img2ImgParams, InpaintParams, ModelParams, OutpaintParams, Txt2ImgParams, UpscaleParams, UpscaleReqParams } from '../types/params.js'; export interface ApiClient { /** @@ -124,4 +125,6 @@ export interface ApiClient { workers(): Promise>; outputURL(image: SuccessJobResponse, index: number): string; + + thumbnailURL(image: SuccessJobResponse, index: number): Maybe; } diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 64804876..eefbaa12 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -84,4 +84,7 @@ export const LOCAL_CLIENT = { outputURL(image, index) { throw new NoServerError(); }, + thumbnailURL(image, index) { + throw new NoServerError(); + }, } as ApiClient; diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx index 8a431cb2..82c6f5b1 100644 --- a/gui/src/components/card/ImageCard.tsx +++ b/gui/src/components/card/ImageCard.tsx @@ -1,6 +1,6 @@ import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'; import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material'; -import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material'; +import { Box, Card, CardActionArea, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip, Typography } from '@mui/material'; import * as React from 'react'; import { useContext, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -27,7 +27,7 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) { export function ImageCard(props: ImageCardProps) { const { image } = props; - const { metadata, outputs } = image; + const { metadata, outputs, thumbnails } = image; const [_hash, setHash] = useHash(); const [blendAnchor, setBlendAnchor] = useState>(); @@ -39,7 +39,7 @@ export function ImageCard(props: ImageCardProps) { const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow); async function loadSource() { - const req = await fetch(url); + const req = await fetch(outputURL); return req.blob(); } @@ -85,12 +85,12 @@ export function ImageCard(props: ImageCardProps) { } function downloadImage() { - window.open(url, '_blank'); + window.open(outputURL, '_blank'); close(); } function downloadMetadata() { - window.open(url + '.json', '_blank'); + window.open(outputURL + '.json', '_blank'); close(); } @@ -107,17 +107,21 @@ export function ImageCard(props: ImageCardProps) { return mustDefault(t(`${key}.${name}`), name); } - const url = useMemo(() => client.outputURL(image, index), [image, index]); + const outputURL = useMemo(() => client.outputURL(image, index), [image, index]); + const thumbnailURL = useMemo(() => client.thumbnailURL(image, index), [image, index]); + const previewURL = thumbnailURL ?? outputURL; const model = getLabel('model', metadata[index].models[0].name); const scheduler = getLabel('scheduler', metadata[index].params.scheduler); return - + + + @@ -136,7 +140,8 @@ export function ImageCard(props: ImageCardProps) { - {visibleIndex(index)} of {outputs.length} + {visibleIndex(index)} of {outputs.length} + {hasThumbnail(image, index) && ({t('image.thumbnail')})} @@ -240,3 +245,7 @@ export function selectActions(state: OnnxState) { setUpscale: state.setUpscale, }; } + +export function hasThumbnail(job: SuccessJobResponse, index: number) { + return doesExist(job.thumbnails) && doesExist(job.thumbnails[index]); +} diff --git a/gui/src/strings/de.ts b/gui/src/strings/de.ts index 73ce3853..74d65955 100644 --- a/gui/src/strings/de.ts +++ b/gui/src/strings/de.ts @@ -73,6 +73,9 @@ export const I18N_STRINGS_DE = { }, }, }, + image: { + thumbnail: 'Miniaturansicht', + }, loading: { cancel: 'Stornieren', progress: '{{steps.current}} von {{steps.total}} Schritten, {{tiles.current}} of {{tiles.total}} Kacheln, {{stages.current}} of {{stages.total}} Abschnitten', diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index 387f67e7..9f8f0cdb 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -48,6 +48,9 @@ export const I18N_STRINGS_EN = { history: { empty: 'No recent history. Press Generate to create an image.', }, + image: { + thumbnail: 'Thumbnail', + }, input: { image: { empty: 'Please select an image.', diff --git a/gui/src/strings/es.ts b/gui/src/strings/es.ts index 1011ec85..9d6cada9 100644 --- a/gui/src/strings/es.ts +++ b/gui/src/strings/es.ts @@ -53,6 +53,9 @@ export const I18N_STRINGS_ES = { lanczos: 'Lanczos', upscale: '', }, + image: { + thumbnail: 'Miniatura', + }, input: { image: { empty: 'Por favor, seleccione una imagen.', diff --git a/gui/src/strings/fr.ts b/gui/src/strings/fr.ts index affdb9eb..344d9235 100644 --- a/gui/src/strings/fr.ts +++ b/gui/src/strings/fr.ts @@ -53,6 +53,9 @@ export const I18N_STRINGS_FR = { history: { empty: 'pas d\'histoire récente. appuyez sur générer pour créer une image.', }, + image: { + thumbnail: 'vignette', + }, input: { image: { empty: 'veuillez sélectionner une image', diff --git a/gui/src/types/api-v2.ts b/gui/src/types/api-v2.ts index 2545877b..5692983a 100644 --- a/gui/src/types/api-v2.ts +++ b/gui/src/types/api-v2.ts @@ -106,57 +106,51 @@ export interface RunningJobResponse extends BaseJobResponse { status: JobStatus.RUNNING; } +export interface BaseSuccessJobResponse extends BaseJobResponse { + status: JobStatus.SUCCESS; + outputs: Array; + thumbnails?: Array; +} + /** * Successful txt2img image job with output keys and metadata. */ -export interface SuccessTxt2ImgJobResponse extends BaseJobResponse { - status: JobStatus.SUCCESS; - outputs: Array; +export interface Txt2ImgSuccessJobResponse extends BaseSuccessJobResponse { metadata: Array>; } /** * Successful img2img job with output keys and metadata. */ -export interface SuccessImg2ImgJobResponse extends BaseJobResponse { - status: JobStatus.SUCCESS; - outputs: Array; +export interface Img2ImgSuccessJobResponse extends BaseSuccessJobResponse { metadata: Array>; } /** * Successful inpaint job with output keys and metadata. */ -export interface SuccessInpaintJobResponse extends BaseJobResponse { - status: JobStatus.SUCCESS; - outputs: Array; +export interface InpaintSuccessJobResponse extends BaseSuccessJobResponse { metadata: Array>; } /** * Successful upscale job with output keys and metadata. */ -export interface SuccessUpscaleJobResponse extends BaseJobResponse { - status: JobStatus.SUCCESS; - outputs: Array; +export interface UpscaleSuccessJobResponse extends BaseSuccessJobResponse { metadata: Array>; } /** * Successful blend job with output keys and metadata. */ -export interface SuccessBlendJobResponse extends BaseJobResponse { - status: JobStatus.SUCCESS; - outputs: Array; +export interface BlendSuccessJobResponse extends BaseSuccessJobResponse { metadata: Array>; } /** * Successful chain pipeline job with output keys and metadata. */ -export interface SuccessChainJobResponse extends BaseJobResponse { - status: JobStatus.SUCCESS; - outputs: Array; +export interface ChainSuccessJobResponse extends BaseSuccessJobResponse { metadata: Array; } @@ -171,12 +165,12 @@ export interface UnknownJobResponse extends BaseJobResponse { * All successful job types. */ export type SuccessJobResponse - = SuccessTxt2ImgJobResponse - | SuccessImg2ImgJobResponse - | SuccessInpaintJobResponse - | SuccessUpscaleJobResponse - | SuccessBlendJobResponse - | SuccessChainJobResponse; + = Txt2ImgSuccessJobResponse + | Img2ImgSuccessJobResponse + | InpaintSuccessJobResponse + | UpscaleSuccessJobResponse + | BlendSuccessJobResponse + | ChainSuccessJobResponse; /** * All job types.