1
0
Fork 0

feat: add menu for source image filters

This commit is contained in:
Sean Sube 2023-04-13 20:59:26 -05:00
parent 80d00e4477
commit 4df28a5ce7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 75 additions and 26 deletions

View File

@ -108,7 +108,17 @@ def list_extra_strings(server: ServerContext):
return jsonify(get_extra_strings()) return jsonify(get_extra_strings())
def list_filters(server: ServerContext):
mask_filters = list(get_mask_filters().keys())
source_filters = list(get_source_filters().keys())
return jsonify({
"mask": mask_filters,
"source": source_filters,
})
def list_mask_filters(server: ServerContext): def list_mask_filters(server: ServerContext):
logger.info("dedicated list endpoint for mask filters is deprecated")
return jsonify(list(get_mask_filters().keys())) return jsonify(list(get_mask_filters().keys()))
@ -502,6 +512,7 @@ def status(server: ServerContext, pool: DevicePoolExecutor):
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor): def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [ return [
app.route("/api")(wrap_route(introspect, server, app=app)), app.route("/api")(wrap_route(introspect, server, app=app)),
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)), app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
app.route("/api/settings/models")(wrap_route(list_models, server)), app.route("/api/settings/models")(wrap_route(list_models, server)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)), app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),

View File

@ -67,6 +67,8 @@ export interface Txt2ImgParams extends BaseImgParams {
*/ */
export interface Img2ImgParams extends BaseImgParams { export interface Img2ImgParams extends BaseImgParams {
source: Blob; source: Blob;
sourceFilter?: string;
strength: number; strength: number;
} }
@ -201,10 +203,15 @@ export interface NetworkModel {
// TODO: add layer/token count // TODO: add layer/token count
} }
export interface FilterResponse {
mask: Array<string>;
source: Array<string>;
}
/** /**
* List of available models. * List of available models.
*/ */
export interface ModelsResponse { export interface ModelResponse {
correction: Array<string>; correction: Array<string>;
diffusion: Array<string>; diffusion: Array<string>;
networks: Array<NetworkModel>; networks: Array<NetworkModel>;
@ -253,12 +260,12 @@ export interface ApiClient {
/** /**
* List the available filter masks for inpaint. * List the available filter masks for inpaint.
*/ */
masks(): Promise<Array<string>>; filters(): Promise<FilterResponse>;
/** /**
* List the available models. * List the available models.
*/ */
models(): Promise<ModelsResponse>; models(): Promise<ModelResponse>;
/** /**
* List the available noise sources for inpaint. * List the available noise sources for inpaint.
@ -433,15 +440,15 @@ export function makeClient(root: string, f = fetch): ApiClient {
} }
return { return {
async masks(): Promise<Array<string>> { async filters(): Promise<FilterResponse> {
const path = makeApiUrl(root, 'settings', 'masks'); const path = makeApiUrl(root, 'settings', 'filters');
const res = await f(path); const res = await f(path);
return await res.json() as Array<string>; return await res.json() as FilterResponse;
}, },
async models(): Promise<ModelsResponse> { async models(): Promise<ModelResponse> {
const path = makeApiUrl(root, 'settings', 'models'); const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path); const res = await f(path);
return await res.json() as ModelsResponse; return await res.json() as ModelResponse;
}, },
async noises(): Promise<Array<string>> { async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises'); const path = makeApiUrl(root, 'settings', 'noises');
@ -483,6 +490,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
if (doesExist(params.sourceFilter)) {
url.searchParams.append('sourceFilter', params.sourceFilter);
}
if (doesExist(upscale)) { if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale); appendUpscaleToURL(url, upscale);
} }

View File

@ -11,7 +11,7 @@ export class NoServerError extends BaseError {
* @TODO client-side inference with https://www.npmjs.com/package/onnxruntime-web * @TODO client-side inference with https://www.npmjs.com/package/onnxruntime-web
*/ */
export const LOCAL_CLIENT = { export const LOCAL_CLIENT = {
async masks() { async filters() {
throw new NoServerError(); throw new NoServerError();
}, },
async blend(model, params, upscale) { async blend(model, params, upscale) {

View File

@ -3,15 +3,16 @@ import { Box, Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useMutation, useQueryClient } from '@tanstack/react-query'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { IMAGE_FILTER } from '../../config.js'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js'; import { ClientContext, ConfigContext, StateContext } from '../../state.js';
import { ImageControl } from '../control/ImageControl.js'; import { ImageControl } from '../control/ImageControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js'; import { ImageInput } from '../input/ImageInput.js';
import { NumericField } from '../input/NumericField.js'; import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';
export function Img2Img() { export function Img2Img() {
const { params } = mustExist(useContext(ConfigContext)); const { params } = mustExist(useContext(ConfigContext));
@ -32,8 +33,14 @@ export function Img2Img() {
onSuccess: () => query.invalidateQueries(['ready']), onSuccess: () => query.invalidateQueries(['ready']),
}); });
const filters = useQuery(['filters'], async () => client.filters(), {
staleTime: STALE_TIME,
});
const state = mustExist(useContext(StateContext)); const state = mustExist(useContext(StateContext));
const source = useStore(state, (s) => s.img2img.source); const source = useStore(state, (s) => s.img2img.source);
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
const strength = useStore(state, (s) => s.img2img.strength); const strength = useStore(state, (s) => s.img2img.strength);
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img); const setImg2Img = useStore(state, (s) => s.setImg2Img);
@ -49,6 +56,22 @@ export function Img2Img() {
}); });
}} /> }} />
<ImageControl selector={(s) => s.img2img} onChange={setImg2Img} /> <ImageControl selector={(s) => s.img2img} onChange={setImg2Img} />
<Stack direction='row' spacing={2}>
<QueryList
id='sources'
labelKey={'sourceFilter'}
name={t('parameter.sourceFilter')}
query={{
result: filters,
selector: (f) => f.source,
}}
value={sourceFilter}
onChange={(newFilter) => {
setImg2Img({
sourceFilter: newFilter,
});
}}
/>
<NumericField <NumericField
decimal decimal
label={t('parameter.strength')} label={t('parameter.strength')}
@ -62,6 +85,7 @@ export function Img2Img() {
}); });
}} }}
/> />
</Stack>
<UpscaleControl /> <UpscaleControl />
<Button <Button
disabled={doesExist(source) === false} disabled={doesExist(source) === false}

View File

@ -20,7 +20,7 @@ export function Inpaint() {
const { params } = mustExist(useContext(ConfigContext)); const { params } = mustExist(useContext(ConfigContext));
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
const masks = useQuery(['masks'], async () => client.masks(), { const filters = useQuery(['filters'], async () => client.filters(), {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
const noises = useQuery(['noises'], async () => client.noises(), { const noises = useQuery(['noises'], async () => client.noises(), {
@ -146,7 +146,8 @@ export function Inpaint() {
labelKey={'maskFilter'} labelKey={'maskFilter'}
name={t('parameter.maskFilter')} name={t('parameter.maskFilter')}
query={{ query={{
result: masks, result: filters,
selector: (f) => f.mask,
}} }}
value={filter} value={filter}
onChange={(newFilter) => { onChange={(newFilter) => {

View File

@ -253,6 +253,7 @@ export function createStateSlices(server: ServerParams) {
img2img: { img2img: {
...base, ...base,
source: null, source: null,
sourceFilter: '',
strength: server.strength.default, strength: server.strength.default,
}, },
setImg2Img(params) { setImg2Img(params) {
@ -268,6 +269,7 @@ export function createStateSlices(server: ServerParams) {
img2img: { img2img: {
...base, ...base,
source: null, source: null,
sourceFilter: '',
strength: server.strength.default, strength: server.strength.default,
}, },
}); });