feat: add upscaling tab and endpoint
This commit is contained in:
parent
b7c85aa51b
commit
4aeee60b19
|
@ -6,7 +6,7 @@ from diffusers import (
|
|||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from os import environ
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageChops
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
@ -188,3 +188,20 @@ def run_inpaint_pipeline(
|
|||
image.save(dest)
|
||||
|
||||
print('saved inpaint output: %s' % (dest))
|
||||
|
||||
def run_upscale_pipeline(
|
||||
ctx: ServerContext,
|
||||
_params: BaseParams,
|
||||
_size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image,
|
||||
strength: float,
|
||||
):
|
||||
image = upscale_resrgan(ctx, upscale, source_image)
|
||||
image = ImageChops.blend(source_image, image, strength)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
||||
print('saved img2img output: %s' % (dest))
|
||||
|
|
|
@ -39,6 +39,7 @@ from .pipeline import (
|
|||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
run_upscale_pipeline,
|
||||
)
|
||||
from .upscale import (
|
||||
UpscaleParams,
|
||||
|
@ -183,7 +184,8 @@ def upscale_from_request() -> UpscaleParams:
|
|||
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
|
||||
correction = get_from_list(request.args, 'correction', correction_models)
|
||||
faces = request.args.get('faces', 'false') == 'true'
|
||||
face_strength = get_and_clamp_float(request.args, 'faceStrength', 0.5, 1.0, 0.0)
|
||||
face_strength = get_and_clamp_float(
|
||||
request.args, 'faceStrength', 0.5, 1.0, 0.0)
|
||||
|
||||
return UpscaleParams(
|
||||
upscaling,
|
||||
|
@ -430,6 +432,38 @@ def inpaint():
|
|||
})
|
||||
|
||||
|
||||
@app.route('/api/upscale', methods=['POST'])
|
||||
def upscale():
|
||||
source_file = request.files.get('source')
|
||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||
|
||||
params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
'strength',
|
||||
config_params.get('strength').get('default'),
|
||||
config_params.get('strength').get('max'))
|
||||
|
||||
output = make_output_name(
|
||||
'img2img',
|
||||
params,
|
||||
size,
|
||||
extras=(strength))
|
||||
print("img2img output: %s" % (output))
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
executor.submit_stored(output, run_upscale_pipeline,
|
||||
context, params, output, upscale, source_image, strength)
|
||||
|
||||
return jsonify({
|
||||
'output': output,
|
||||
'params': params.tojson(),
|
||||
'size': upscale.resize(size).tojson(),
|
||||
})
|
||||
|
||||
|
||||
@app.route('/api/ready')
|
||||
def ready():
|
||||
output_file = request.args.get('output', None)
|
||||
|
|
|
@ -76,7 +76,12 @@ export interface UpscaleParams {
|
|||
faceStrength: number;
|
||||
}
|
||||
|
||||
export interface ApiResponse {
|
||||
export interface UpscaleReqParams {
|
||||
source: Blob;
|
||||
strength: number;
|
||||
}
|
||||
|
||||
export interface ImageResponse {
|
||||
output: {
|
||||
key: string;
|
||||
url: string;
|
||||
|
@ -88,11 +93,11 @@ export interface ApiResponse {
|
|||
};
|
||||
}
|
||||
|
||||
export interface ApiReady {
|
||||
export interface ReadyResponse {
|
||||
ready: boolean;
|
||||
}
|
||||
|
||||
export interface ApiModels {
|
||||
export interface ModelsResponse {
|
||||
diffusion: Array<string>;
|
||||
correction: Array<string>;
|
||||
upscaling: Array<string>;
|
||||
|
@ -100,18 +105,19 @@ export interface ApiModels {
|
|||
|
||||
export interface ApiClient {
|
||||
masks(): Promise<Array<string>>;
|
||||
models(): Promise<ApiModels>;
|
||||
models(): Promise<ModelsResponse>;
|
||||
noises(): Promise<Array<string>>;
|
||||
params(): Promise<ConfigParams>;
|
||||
platforms(): Promise<Array<string>>;
|
||||
schedulers(): Promise<Array<string>>;
|
||||
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
|
||||
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponse>;
|
||||
|
||||
ready(params: ApiResponse): Promise<ApiReady>;
|
||||
ready(params: ImageResponse): Promise<ReadyResponse>;
|
||||
}
|
||||
|
||||
export const STATUS_SUCCESS = 200;
|
||||
|
@ -130,7 +136,7 @@ export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams
|
|||
export const FIXED_INTEGER = 0;
|
||||
export const FIXED_FLOAT = 2;
|
||||
|
||||
export function equalResponse(a: ApiResponse, b: ApiResponse): boolean {
|
||||
export function equalResponse(a: ImageResponse, b: ImageResponse): boolean {
|
||||
return a.output === b.output;
|
||||
}
|
||||
|
||||
|
@ -183,9 +189,9 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
|
|||
}
|
||||
|
||||
export function makeClient(root: string, f = fetch): ApiClient {
|
||||
let pending: Promise<ApiResponse> | undefined;
|
||||
let pending: Promise<ImageResponse> | undefined;
|
||||
|
||||
function throttleRequest(url: URL, options: RequestInit): Promise<ApiResponse> {
|
||||
function throttleRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
|
||||
return f(url, options).then((res) => parseApiResponse(root, res)).finally(() => {
|
||||
pending = undefined;
|
||||
});
|
||||
|
@ -197,10 +203,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async models(): Promise<ApiModels> {
|
||||
async models(): Promise<ModelsResponse> {
|
||||
const path = makeApiUrl(root, 'settings', 'models');
|
||||
const res = await f(path);
|
||||
return await res.json() as ApiModels;
|
||||
return await res.json() as ModelsResponse;
|
||||
},
|
||||
async noises(): Promise<Array<string>> {
|
||||
const path = makeApiUrl(root, 'settings', 'noises');
|
||||
|
@ -222,7 +228,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
@ -247,7 +253,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
@ -344,18 +350,41 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async ready(params: ApiResponse): Promise<ApiReady> {
|
||||
async upscale(model: ModelParams, params: UpscaleReqParams, upscale: UpscaleParams): Promise<ImageResponse> {
|
||||
if (doesExist(pending)) {
|
||||
return pending;
|
||||
}
|
||||
|
||||
const url = makeApiUrl(root, 'upscale');
|
||||
appendModelToURL(url, model);
|
||||
|
||||
if (doesExist(upscale)) {
|
||||
appendUpscaleToURL(url, upscale);
|
||||
}
|
||||
|
||||
const body = new FormData();
|
||||
body.append('source', params.source, 'source');
|
||||
|
||||
pending = throttleRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
|
||||
// eslint-disable-next-line no-return-await
|
||||
return await pending;
|
||||
},
|
||||
async ready(params: ImageResponse): Promise<ReadyResponse> {
|
||||
const path = makeApiUrl(root, 'ready');
|
||||
path.searchParams.append('output', params.output.key);
|
||||
|
||||
const res = await f(path);
|
||||
return await res.json() as ApiReady;
|
||||
return await res.json() as ReadyResponse;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
export async function parseApiResponse(root: string, res: Response): Promise<ApiResponse> {
|
||||
type LimitedResponse = Omit<ApiResponse, 'output'> & { output: string };
|
||||
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
|
||||
type LimitedResponse = Omit<ImageResponse, 'output'> & { output: string };
|
||||
|
||||
if (res.status === STATUS_SUCCESS) {
|
||||
const data = await res.json() as LimitedResponse;
|
||||
|
|
|
@ -5,13 +5,13 @@ import * as React from 'react';
|
|||
import { useContext } from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ApiResponse } from '../client.js';
|
||||
import { ImageResponse } from '../client.js';
|
||||
import { StateContext } from '../state.js';
|
||||
|
||||
export interface ImageCardProps {
|
||||
value: ApiResponse;
|
||||
value: ImageResponse;
|
||||
|
||||
onDelete?: (key: ApiResponse) => void;
|
||||
onDelete?: (key: ImageResponse) => void;
|
||||
}
|
||||
|
||||
export function GridItem(props: { xs: number; children: React.ReactNode }) {
|
||||
|
|
|
@ -5,12 +5,12 @@ import { useContext } from 'react';
|
|||
import { useQuery } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ApiResponse } from '../client.js';
|
||||
import { ImageResponse } from '../client.js';
|
||||
import { POLL_TIME } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
|
||||
export interface LoadingCardProps {
|
||||
loading: ApiResponse;
|
||||
loading: ImageResponse;
|
||||
}
|
||||
|
||||
export function LoadingCard(props: LoadingCardProps) {
|
||||
|
|
|
@ -8,6 +8,7 @@ import { Inpaint } from './Inpaint.js';
|
|||
import { ModelControl } from './ModelControl.js';
|
||||
import { Settings } from './Settings.js';
|
||||
import { Txt2Img } from './Txt2Img.js';
|
||||
import { Upscale } from './Upscale.js';
|
||||
|
||||
const { useState } = React;
|
||||
|
||||
|
@ -32,6 +33,7 @@ export function OnnxWeb() {
|
|||
<Tab label='txt2img' value='txt2img' />
|
||||
<Tab label='img2img' value='img2img' />
|
||||
<Tab label='inpaint' value='inpaint' />
|
||||
<Tab label='upscale' value='upscale' />
|
||||
<Tab label='settings' value='settings' />
|
||||
</TabList>
|
||||
</Box>
|
||||
|
@ -44,6 +46,9 @@ export function OnnxWeb() {
|
|||
<TabPanel value='inpaint'>
|
||||
<Inpaint />
|
||||
</TabPanel>
|
||||
<TabPanel value='upscale'>
|
||||
<Upscale />
|
||||
</TabPanel>
|
||||
<TabPanel value='settings'>
|
||||
<Settings />
|
||||
</TabPanel>
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { IMAGE_FILTER } from '../config.js';
|
||||
import { ClientContext, ConfigContext, StateContext } from '../state.js';
|
||||
import { ImageInput } from './ImageInput.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
import { UpscaleControl } from './UpscaleControl.js';
|
||||
|
||||
const { useContext } = React;
|
||||
|
||||
export function Upscale() {
|
||||
const config = mustExist(useContext(ConfigContext));
|
||||
|
||||
async function uploadSource() {
|
||||
const { model, upscale } = state.getState();
|
||||
|
||||
const output = await client.upscale(model, {
|
||||
...params,
|
||||
source: mustExist(params.source), // TODO: show an error if this doesn't exist
|
||||
}, upscale);
|
||||
|
||||
setLoading(output);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const query = useQueryClient();
|
||||
const upload = useMutation(uploadSource, {
|
||||
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
|
||||
});
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.upscaleTab);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setSource = useStore(state, (s) => s.setUpscaleTab);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setLoading = useStore(state, (s) => s.setLoading);
|
||||
|
||||
return <Box>
|
||||
<Stack spacing={2}>
|
||||
<ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
|
||||
setSource({
|
||||
source: file,
|
||||
});
|
||||
}} />
|
||||
<NumericField
|
||||
decimal
|
||||
label='Strength'
|
||||
min={config.strength.min}
|
||||
max={config.strength.max}
|
||||
step={config.strength.step}
|
||||
value={params.strength}
|
||||
onChange={(value) => {
|
||||
setSource({
|
||||
strength: value,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<UpscaleControl config={config} />
|
||||
<Button onClick={() => upload.mutate()}>Generate</Button>
|
||||
</Stack>
|
||||
</Box>;
|
||||
}
|
|
@ -92,7 +92,7 @@ export function UpscaleControl(props: UpscaleControlProps) {
|
|||
<NumericField
|
||||
label='Strength'
|
||||
decimal
|
||||
disabled={params.enabled === false}
|
||||
disabled={params.enabled === false && params.faces}
|
||||
min={config.faceStrength.min}
|
||||
max={config.faceStrength.max}
|
||||
step={config.faceStrength.step}
|
||||
|
|
|
@ -5,7 +5,7 @@ import { StateCreator, StoreApi } from 'zustand';
|
|||
|
||||
import {
|
||||
ApiClient,
|
||||
ApiResponse,
|
||||
ImageResponse,
|
||||
BaseImgParams,
|
||||
BrushParams,
|
||||
Img2ImgParams,
|
||||
|
@ -15,10 +15,11 @@ import {
|
|||
paramsFromConfig,
|
||||
Txt2ImgParams,
|
||||
UpscaleParams,
|
||||
UpscaleReqParams,
|
||||
} from './client.js';
|
||||
import { ConfigFiles, ConfigParams, ConfigState } from './config.js';
|
||||
|
||||
type TabState<TabParams extends BaseImgParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
||||
type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
||||
|
||||
interface Txt2ImgSlice {
|
||||
txt2img: TabState<Txt2ImgParams>;
|
||||
|
@ -42,14 +43,14 @@ interface InpaintSlice {
|
|||
}
|
||||
|
||||
interface HistorySlice {
|
||||
history: Array<ApiResponse>;
|
||||
history: Array<ImageResponse>;
|
||||
limit: number;
|
||||
loading: Maybe<ApiResponse>;
|
||||
loading: Maybe<ImageResponse>;
|
||||
|
||||
pushHistory(image: ApiResponse): void;
|
||||
removeHistory(image: ApiResponse): void;
|
||||
pushHistory(image: ImageResponse): void;
|
||||
removeHistory(image: ImageResponse): void;
|
||||
setLimit(limit: number): void;
|
||||
setLoading(image: Maybe<ApiResponse>): void;
|
||||
setLoading(image: Maybe<ImageResponse>): void;
|
||||
}
|
||||
|
||||
interface DefaultSlice {
|
||||
|
@ -72,8 +73,11 @@ interface BrushSlice {
|
|||
|
||||
interface UpscaleSlice {
|
||||
upscale: UpscaleParams;
|
||||
upscaleTab: TabState<UpscaleReqParams>;
|
||||
|
||||
setUpscale(upscale: Partial<UpscaleParams>): void;
|
||||
setUpscaleTab(params: Partial<UpscaleReqParams>): void;
|
||||
resetUpscaleTab(): void;
|
||||
}
|
||||
|
||||
interface ModelSlice {
|
||||
|
@ -252,14 +256,34 @@ export function createStateSlices(base: ConfigParams) {
|
|||
outscale: 1,
|
||||
faceStrength: 0.5,
|
||||
},
|
||||
upscaleTab: {
|
||||
source: null,
|
||||
strength: 1.0,
|
||||
},
|
||||
setUpscale(upscale) {
|
||||
set((prev) => ({
|
||||
upscale: {
|
||||
...prev.upscale,
|
||||
...upscale,
|
||||
}
|
||||
},
|
||||
}));
|
||||
},
|
||||
setUpscaleTab(source) {
|
||||
set((prev) => ({
|
||||
upscaleTab: {
|
||||
...prev.upscaleTab,
|
||||
...source,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetUpscaleTab() {
|
||||
set({
|
||||
upscaleTab: {
|
||||
source: null,
|
||||
strength: 1.0,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({
|
||||
|
|
Loading…
Reference in New Issue