1
0
Fork 0

feat: add menu for source image filters

This commit is contained in:
Sean Sube 2023-04-13 20:59:26 -05:00
parent 80d00e4477
commit 4df28a5ce7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 75 additions and 26 deletions

View File

@ -108,7 +108,17 @@ def list_extra_strings(server: ServerContext):
return jsonify(get_extra_strings())
def list_filters(server: ServerContext):
mask_filters = list(get_mask_filters().keys())
source_filters = list(get_source_filters().keys())
return jsonify({
"mask": mask_filters,
"source": source_filters,
})
def list_mask_filters(server: ServerContext):
logger.info("dedicated list endpoint for mask filters is deprecated")
return jsonify(list(get_mask_filters().keys()))
@ -502,6 +512,7 @@ def status(server: ServerContext, pool: DevicePoolExecutor):
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, server, app=app)),
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
app.route("/api/settings/models")(wrap_route(list_models, server)),
app.route("/api/settings/noises")(wrap_route(list_noise_sources, server)),

View File

@ -67,6 +67,8 @@ export interface Txt2ImgParams extends BaseImgParams {
*/
export interface Img2ImgParams extends BaseImgParams {
source: Blob;
sourceFilter?: string;
strength: number;
}
@ -201,10 +203,15 @@ export interface NetworkModel {
// TODO: add layer/token count
}
export interface FilterResponse {
mask: Array<string>;
source: Array<string>;
}
/**
* List of available models.
*/
export interface ModelsResponse {
export interface ModelResponse {
correction: Array<string>;
diffusion: Array<string>;
networks: Array<NetworkModel>;
@ -253,12 +260,12 @@ export interface ApiClient {
/**
* List the available filter masks for inpaint.
*/
masks(): Promise<Array<string>>;
filters(): Promise<FilterResponse>;
/**
* List the available models.
*/
models(): Promise<ModelsResponse>;
models(): Promise<ModelResponse>;
/**
* List the available noise sources for inpaint.
@ -433,15 +440,15 @@ export function makeClient(root: string, f = fetch): ApiClient {
}
return {
async masks(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'masks');
async filters(): Promise<FilterResponse> {
const path = makeApiUrl(root, 'settings', 'filters');
const res = await f(path);
return await res.json() as Array<string>;
return await res.json() as FilterResponse;
},
async models(): Promise<ModelsResponse> {
async models(): Promise<ModelResponse> {
const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path);
return await res.json() as ModelsResponse;
return await res.json() as ModelResponse;
},
async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises');
@ -483,6 +490,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
if (doesExist(params.sourceFilter)) {
url.searchParams.append('sourceFilter', params.sourceFilter);
}
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}

View File

@ -11,7 +11,7 @@ export class NoServerError extends BaseError {
* @TODO client-side inference with https://www.npmjs.com/package/onnxruntime-web
*/
export const LOCAL_CLIENT = {
async masks() {
async filters() {
throw new NoServerError();
},
async blend(model, params, upscale) {

View File

@ -3,15 +3,16 @@ import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { useStore } from 'zustand';
import { IMAGE_FILTER } from '../../config.js';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';
import { ImageControl } from '../control/ImageControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';
export function Img2Img() {
const { params } = mustExist(useContext(ConfigContext));
@ -32,8 +33,14 @@ export function Img2Img() {
onSuccess: () => query.invalidateQueries(['ready']),
});
const filters = useQuery(['filters'], async () => client.filters(), {
staleTime: STALE_TIME,
});
const state = mustExist(useContext(StateContext));
const source = useStore(state, (s) => s.img2img.source);
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
const strength = useStore(state, (s) => s.img2img.strength);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
@ -49,19 +56,36 @@ export function Img2Img() {
});
}} />
<ImageControl selector={(s) => s.img2img} onChange={setImg2Img} />
<NumericField
decimal
label={t('parameter.strength')}
min={params.strength.min}
max={params.strength.max}
step={params.strength.step}
value={strength}
onChange={(value) => {
setImg2Img({
strength: value,
});
}}
/>
<Stack direction='row' spacing={2}>
<QueryList
id='sources'
labelKey={'sourceFilter'}
name={t('parameter.sourceFilter')}
query={{
result: filters,
selector: (f) => f.source,
}}
value={sourceFilter}
onChange={(newFilter) => {
setImg2Img({
sourceFilter: newFilter,
});
}}
/>
<NumericField
decimal
label={t('parameter.strength')}
min={params.strength.min}
max={params.strength.max}
step={params.strength.step}
value={strength}
onChange={(value) => {
setImg2Img({
strength: value,
});
}}
/>
</Stack>
<UpscaleControl />
<Button
disabled={doesExist(source) === false}

View File

@ -20,7 +20,7 @@ export function Inpaint() {
const { params } = mustExist(useContext(ConfigContext));
const client = mustExist(useContext(ClientContext));
const masks = useQuery(['masks'], async () => client.masks(), {
const filters = useQuery(['filters'], async () => client.filters(), {
staleTime: STALE_TIME,
});
const noises = useQuery(['noises'], async () => client.noises(), {
@ -146,7 +146,8 @@ export function Inpaint() {
labelKey={'maskFilter'}
name={t('parameter.maskFilter')}
query={{
result: masks,
result: filters,
selector: (f) => f.mask,
}}
value={filter}
onChange={(newFilter) => {

View File

@ -253,6 +253,7 @@ export function createStateSlices(server: ServerParams) {
img2img: {
...base,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
setImg2Img(params) {
@ -268,6 +269,7 @@ export function createStateSlices(server: ServerParams) {
img2img: {
...base,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
});