diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index c34e8b54..d0d94cfc 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -108,7 +108,17 @@ def list_extra_strings(server: ServerContext): return jsonify(get_extra_strings()) +def list_filters(server: ServerContext): + mask_filters = list(get_mask_filters().keys()) + source_filters = list(get_source_filters().keys()) + return jsonify({ + "mask": mask_filters, + "source": source_filters, + }) + + def list_mask_filters(server: ServerContext): + logger.info("dedicated list endpoint for mask filters is deprecated") return jsonify(list(get_mask_filters().keys())) @@ -502,6 +512,7 @@ def status(server: ServerContext, pool: DevicePoolExecutor): def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor): return [ app.route("/api")(wrap_route(introspect, server, app=app)), + app.route("/api/settings/filters")(wrap_route(list_filters, server)), app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)), app.route("/api/settings/models")(wrap_route(list_models, server)), app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)), diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 4b1ab699..08030df2 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -67,6 +67,8 @@ export interface Txt2ImgParams extends BaseImgParams { */ export interface Img2ImgParams extends BaseImgParams { source: Blob; + + sourceFilter?: string; strength: number; } @@ -201,10 +203,15 @@ export interface NetworkModel { // TODO: add layer/token count } +export interface FilterResponse { + mask: Array; + source: Array; +} + /** * List of available models. */ -export interface ModelsResponse { +export interface ModelResponse { correction: Array; diffusion: Array; networks: Array; @@ -253,12 +260,12 @@ export interface ApiClient { /** * List the available filter masks for inpaint. */ - masks(): Promise>; + filters(): Promise; /** * List the available models. */ - models(): Promise; + models(): Promise; /** * List the available noise sources for inpaint. @@ -433,15 +440,15 @@ export function makeClient(root: string, f = fetch): ApiClient { } return { - async masks(): Promise> { - const path = makeApiUrl(root, 'settings', 'masks'); + async filters(): Promise { + const path = makeApiUrl(root, 'settings', 'filters'); const res = await f(path); - return await res.json() as Array; + return await res.json() as FilterResponse; }, - async models(): Promise { + async models(): Promise { const path = makeApiUrl(root, 'settings', 'models'); const res = await f(path); - return await res.json() as ModelsResponse; + return await res.json() as ModelResponse; }, async noises(): Promise> { const path = makeApiUrl(root, 'settings', 'noises'); @@ -483,6 +490,10 @@ export function makeClient(root: string, f = fetch): ApiClient { url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); + if (doesExist(params.sourceFilter)) { + url.searchParams.append('sourceFilter', params.sourceFilter); + } + if (doesExist(upscale)) { appendUpscaleToURL(url, upscale); } diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 452a457d..5a0eb837 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -11,7 +11,7 @@ export class NoServerError extends BaseError { * @TODO client-side inference with https://www.npmjs.com/package/onnxruntime-web */ export const LOCAL_CLIENT = { - async masks() { + async filters() { throw new NoServerError(); }, async blend(model, params, upscale) { diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index a0010718..5d238a31 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -3,15 +3,16 @@ import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; -import { useMutation, useQueryClient } from '@tanstack/react-query'; +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { useStore } from 'zustand'; -import { IMAGE_FILTER } from '../../config.js'; +import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; import { ClientContext, ConfigContext, StateContext } from '../../state.js'; import { ImageControl } from '../control/ImageControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; import { ImageInput } from '../input/ImageInput.js'; import { NumericField } from '../input/NumericField.js'; +import { QueryList } from '../input/QueryList.js'; export function Img2Img() { const { params } = mustExist(useContext(ConfigContext)); @@ -32,8 +33,14 @@ export function Img2Img() { onSuccess: () => query.invalidateQueries(['ready']), }); + const filters = useQuery(['filters'], async () => client.filters(), { + staleTime: STALE_TIME, + }); + + const state = mustExist(useContext(StateContext)); const source = useStore(state, (s) => s.img2img.source); + const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter); const strength = useStore(state, (s) => s.img2img.strength); // eslint-disable-next-line @typescript-eslint/unbound-method const setImg2Img = useStore(state, (s) => s.setImg2Img); @@ -49,19 +56,36 @@ export function Img2Img() { }); }} /> s.img2img} onChange={setImg2Img} /> - { - setImg2Img({ - strength: value, - }); - }} - /> + + f.source, + }} + value={sourceFilter} + onChange={(newFilter) => { + setImg2Img({ + sourceFilter: newFilter, + }); + }} + /> + { + setImg2Img({ + strength: value, + }); + }} + /> +