1
0
Fork 0

feat: add upscaling tab and endpoint

This commit is contained in:
Sean Sube 2023-01-16 23:45:54 -06:00
parent b7c85aa51b
commit 4aeee60b19
9 changed files with 211 additions and 36 deletions

View File

@ -6,7 +6,7 @@ from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from os import environ
from PIL import Image
from PIL import Image, ImageChops
from typing import Any
import numpy as np
@ -188,3 +188,20 @@ def run_inpaint_pipeline(
image.save(dest)
print('saved inpaint output: %s' % (dest))
def run_upscale_pipeline(
ctx: ServerContext,
_params: BaseParams,
_size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image,
strength: float,
):
image = upscale_resrgan(ctx, upscale, source_image)
image = ImageChops.blend(source_image, image, strength)
dest = safer_join(ctx.output_path, output)
image.save(dest)
print('saved img2img output: %s' % (dest))

View File

@ -39,6 +39,7 @@ from .pipeline import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .upscale import (
UpscaleParams,
@ -183,7 +184,8 @@ def upscale_from_request() -> UpscaleParams:
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models)
faces = request.args.get('faces', 'false') == 'true'
face_strength = get_and_clamp_float(request.args, 'faceStrength', 0.5, 1.0, 0.0)
face_strength = get_and_clamp_float(
request.args, 'faceStrength', 0.5, 1.0, 0.0)
return UpscaleParams(
upscaling,
@ -430,6 +432,38 @@ def inpaint():
})
@app.route('/api/upscale', methods=['POST'])
def upscale():
source_file = request.files.get('source')
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
params, size = pipeline_from_request()
upscale = upscale_from_request()
strength = get_and_clamp_float(
request.args,
'strength',
config_params.get('strength').get('default'),
config_params.get('strength').get('max'))
output = make_output_name(
'img2img',
params,
size,
extras=(strength))
print("img2img output: %s" % (output))
source_image.thumbnail((size.width, size.height))
executor.submit_stored(output, run_upscale_pipeline,
context, params, output, upscale, source_image, strength)
return jsonify({
'output': output,
'params': params.tojson(),
'size': upscale.resize(size).tojson(),
})
@app.route('/api/ready')
def ready():
output_file = request.args.get('output', None)

View File

@ -76,7 +76,12 @@ export interface UpscaleParams {
faceStrength: number;
}
export interface ApiResponse {
export interface UpscaleReqParams {
source: Blob;
strength: number;
}
export interface ImageResponse {
output: {
key: string;
url: string;
@ -88,11 +93,11 @@ export interface ApiResponse {
};
}
export interface ApiReady {
export interface ReadyResponse {
ready: boolean;
}
export interface ApiModels {
export interface ModelsResponse {
diffusion: Array<string>;
correction: Array<string>;
upscaling: Array<string>;
@ -100,18 +105,19 @@ export interface ApiModels {
export interface ApiClient {
masks(): Promise<Array<string>>;
models(): Promise<ApiModels>;
models(): Promise<ModelsResponse>;
noises(): Promise<Array<string>>;
params(): Promise<ConfigParams>;
platforms(): Promise<Array<string>>;
schedulers(): Promise<Array<string>>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ApiResponse>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise<ImageResponse>;
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise<ImageResponse>;
ready(params: ApiResponse): Promise<ApiReady>;
ready(params: ImageResponse): Promise<ReadyResponse>;
}
export const STATUS_SUCCESS = 200;
@ -130,7 +136,7 @@ export function paramsFromConfig(defaults: ConfigParams): Required<BaseImgParams
export const FIXED_INTEGER = 0;
export const FIXED_FLOAT = 2;
export function equalResponse(a: ApiResponse, b: ApiResponse): boolean {
export function equalResponse(a: ImageResponse, b: ImageResponse): boolean {
return a.output === b.output;
}
@ -183,9 +189,9 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
}
export function makeClient(root: string, f = fetch): ApiClient {
let pending: Promise<ApiResponse> | undefined;
let pending: Promise<ImageResponse> | undefined;
function throttleRequest(url: URL, options: RequestInit): Promise<ApiResponse> {
function throttleRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
return f(url, options).then((res) => parseApiResponse(root, res)).finally(() => {
pending = undefined;
});
@ -197,10 +203,10 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async models(): Promise<ApiModels> {
async models(): Promise<ModelsResponse> {
const path = makeApiUrl(root, 'settings', 'models');
const res = await f(path);
return await res.json() as ApiModels;
return await res.json() as ModelsResponse;
},
async noises(): Promise<Array<string>> {
const path = makeApiUrl(root, 'settings', 'noises');
@ -222,7 +228,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
if (doesExist(pending)) {
return pending;
}
@ -247,7 +253,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ApiResponse> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise<ImageResponse> {
if (doesExist(pending)) {
return pending;
}
@ -344,18 +350,41 @@ export function makeClient(root: string, f = fetch): ApiClient {
// eslint-disable-next-line no-return-await
return await pending;
},
async ready(params: ApiResponse): Promise<ApiReady> {
async upscale(model: ModelParams, params: UpscaleReqParams, upscale: UpscaleParams): Promise<ImageResponse> {
if (doesExist(pending)) {
return pending;
}
const url = makeApiUrl(root, 'upscale');
appendModelToURL(url, model);
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
const body = new FormData();
body.append('source', params.source, 'source');
pending = throttleRequest(url, {
body,
method: 'POST',
});
// eslint-disable-next-line no-return-await
return await pending;
},
async ready(params: ImageResponse): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
path.searchParams.append('output', params.output.key);
const res = await f(path);
return await res.json() as ApiReady;
return await res.json() as ReadyResponse;
}
};
}
export async function parseApiResponse(root: string, res: Response): Promise<ApiResponse> {
type LimitedResponse = Omit<ApiResponse, 'output'> & { output: string };
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
type LimitedResponse = Omit<ImageResponse, 'output'> & { output: string };
if (res.status === STATUS_SUCCESS) {
const data = await res.json() as LimitedResponse;

View File

@ -5,13 +5,13 @@ import * as React from 'react';
import { useContext } from 'react';
import { useStore } from 'zustand';
import { ApiResponse } from '../client.js';
import { ImageResponse } from '../client.js';
import { StateContext } from '../state.js';
export interface ImageCardProps {
value: ApiResponse;
value: ImageResponse;
onDelete?: (key: ApiResponse) => void;
onDelete?: (key: ImageResponse) => void;
}
export function GridItem(props: { xs: number; children: React.ReactNode }) {

View File

@ -5,12 +5,12 @@ import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';
import { ApiResponse } from '../client.js';
import { ImageResponse } from '../client.js';
import { POLL_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';
export interface LoadingCardProps {
loading: ApiResponse;
loading: ImageResponse;
}
export function LoadingCard(props: LoadingCardProps) {

View File

@ -8,6 +8,7 @@ import { Inpaint } from './Inpaint.js';
import { ModelControl } from './ModelControl.js';
import { Settings } from './Settings.js';
import { Txt2Img } from './Txt2Img.js';
import { Upscale } from './Upscale.js';
const { useState } = React;
@ -32,6 +33,7 @@ export function OnnxWeb() {
<Tab label='txt2img' value='txt2img' />
<Tab label='img2img' value='img2img' />
<Tab label='inpaint' value='inpaint' />
<Tab label='upscale' value='upscale' />
<Tab label='settings' value='settings' />
</TabList>
</Box>
@ -44,6 +46,9 @@ export function OnnxWeb() {
<TabPanel value='inpaint'>
<Inpaint />
</TabPanel>
<TabPanel value='upscale'>
<Upscale />
</TabPanel>
<TabPanel value='settings'>
<Settings />
</TabPanel>

View File

@ -0,0 +1,66 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { IMAGE_FILTER } from '../config.js';
import { ClientContext, ConfigContext, StateContext } from '../state.js';
import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js';
import { UpscaleControl } from './UpscaleControl.js';
const { useContext } = React;
export function Upscale() {
const config = mustExist(useContext(ConfigContext));
async function uploadSource() {
const { model, upscale } = state.getState();
const output = await client.upscale(model, {
...params,
source: mustExist(params.source), // TODO: show an error if this doesn't exist
}, upscale);
setLoading(output);
}
const client = mustExist(useContext(ClientContext));
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
});
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.upscaleTab);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setSource = useStore(state, (s) => s.setUpscaleTab);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
return <Box>
<Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
setSource({
source: file,
});
}} />
<NumericField
decimal
label='Strength'
min={config.strength.min}
max={config.strength.max}
step={config.strength.step}
value={params.strength}
onChange={(value) => {
setSource({
strength: value,
});
}}
/>
<UpscaleControl config={config} />
<Button onClick={() => upload.mutate()}>Generate</Button>
</Stack>
</Box>;
}

View File

@ -92,7 +92,7 @@ export function UpscaleControl(props: UpscaleControlProps) {
<NumericField
label='Strength'
decimal
disabled={params.enabled === false}
disabled={params.enabled === false && params.faces}
min={config.faceStrength.min}
max={config.faceStrength.max}
step={config.faceStrength.step}

View File

@ -5,7 +5,7 @@ import { StateCreator, StoreApi } from 'zustand';
import {
ApiClient,
ApiResponse,
ImageResponse,
BaseImgParams,
BrushParams,
Img2ImgParams,
@ -15,10 +15,11 @@ import {
paramsFromConfig,
Txt2ImgParams,
UpscaleParams,
UpscaleReqParams,
} from './client.js';
import { ConfigFiles, ConfigParams, ConfigState } from './config.js';
type TabState<TabParams extends BaseImgParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>;
@ -42,14 +43,14 @@ interface InpaintSlice {
}
interface HistorySlice {
history: Array<ApiResponse>;
history: Array<ImageResponse>;
limit: number;
loading: Maybe<ApiResponse>;
loading: Maybe<ImageResponse>;
pushHistory(image: ApiResponse): void;
removeHistory(image: ApiResponse): void;
pushHistory(image: ImageResponse): void;
removeHistory(image: ImageResponse): void;
setLimit(limit: number): void;
setLoading(image: Maybe<ApiResponse>): void;
setLoading(image: Maybe<ImageResponse>): void;
}
interface DefaultSlice {
@ -72,8 +73,11 @@ interface BrushSlice {
interface UpscaleSlice {
upscale: UpscaleParams;
upscaleTab: TabState<UpscaleReqParams>;
setUpscale(upscale: Partial<UpscaleParams>): void;
setUpscaleTab(params: Partial<UpscaleReqParams>): void;
resetUpscaleTab(): void;
}
interface ModelSlice {
@ -252,14 +256,34 @@ export function createStateSlices(base: ConfigParams) {
outscale: 1,
faceStrength: 0.5,
},
upscaleTab: {
source: null,
strength: 1.0,
},
setUpscale(upscale) {
set((prev) => ({
upscale: {
...prev.upscale,
...upscale,
}
},
}));
},
setUpscaleTab(source) {
set((prev) => ({
upscaleTab: {
...prev.upscaleTab,
...source,
},
}));
},
resetUpscaleTab() {
set({
upscaleTab: {
source: null,
strength: 1.0,
},
});
},
});
const createDefaultSlice: StateCreator<OnnxState, [], [], DefaultSlice> = (set) => ({