diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py index ba0da1cc..de8d1427 100644 --- a/api/onnx_web/pipeline.py +++ b/api/onnx_web/pipeline.py @@ -6,7 +6,7 @@ from diffusers import ( OnnxStableDiffusionInpaintPipeline, ) from os import environ -from PIL import Image +from PIL import Image, ImageChops from typing import Any import numpy as np @@ -188,3 +188,20 @@ def run_inpaint_pipeline( image.save(dest) print('saved inpaint output: %s' % (dest)) + +def run_upscale_pipeline( + ctx: ServerContext, + _params: BaseParams, + _size: Size, + output: str, + upscale: UpscaleParams, + source_image: Image, + strength: float, +): + image = upscale_resrgan(ctx, upscale, source_image) + image = ImageChops.blend(source_image, image, strength) + + dest = safer_join(ctx.output_path, output) + image.save(dest) + + print('saved img2img output: %s' % (dest)) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 568f356b..e88c01ab 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -39,6 +39,7 @@ from .pipeline import ( run_img2img_pipeline, run_inpaint_pipeline, run_txt2img_pipeline, + run_upscale_pipeline, ) from .upscale import ( UpscaleParams, @@ -183,7 +184,8 @@ def upscale_from_request() -> UpscaleParams: upscaling = get_from_list(request.args, 'upscaling', upscaling_models) correction = get_from_list(request.args, 'correction', correction_models) faces = request.args.get('faces', 'false') == 'true' - face_strength = get_and_clamp_float(request.args, 'faceStrength', 0.5, 1.0, 0.0) + face_strength = get_and_clamp_float( + request.args, 'faceStrength', 0.5, 1.0, 0.0) return UpscaleParams( upscaling, @@ -430,6 +432,38 @@ def inpaint(): }) +@app.route('/api/upscale', methods=['POST']) +def upscale(): + source_file = request.files.get('source') + source_image = Image.open(BytesIO(source_file.read())).convert('RGB') + + params, size = pipeline_from_request() + upscale = upscale_from_request() + + strength = get_and_clamp_float( + request.args, + 'strength', + config_params.get('strength').get('default'), + config_params.get('strength').get('max')) + + output = make_output_name( + 'img2img', + params, + size, + extras=(strength)) + print("img2img output: %s" % (output)) + + source_image.thumbnail((size.width, size.height)) + executor.submit_stored(output, run_upscale_pipeline, + context, params, output, upscale, source_image, strength) + + return jsonify({ + 'output': output, + 'params': params.tojson(), + 'size': upscale.resize(size).tojson(), + }) + + @app.route('/api/ready') def ready(): output_file = request.args.get('output', None) diff --git a/gui/src/client.ts b/gui/src/client.ts index b432145c..61d4a6c3 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -76,7 +76,12 @@ export interface UpscaleParams { faceStrength: number; } -export interface ApiResponse { +export interface UpscaleReqParams { + source: Blob; + strength: number; +} + +export interface ImageResponse { output: { key: string; url: string; @@ -88,11 +93,11 @@ export interface ApiResponse { }; } -export interface ApiReady { +export interface ReadyResponse { ready: boolean; } -export interface ApiModels { +export interface ModelsResponse { diffusion: Array; correction: Array; upscaling: Array; @@ -100,18 +105,19 @@ export interface ApiModels { export interface ApiClient { masks(): Promise>; - models(): Promise; + models(): Promise; noises(): Promise>; params(): Promise; platforms(): Promise>; schedulers(): Promise>; - img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise; - txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise; - inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise; - outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise; + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise; + txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise; + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise; + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise; + upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise; - ready(params: ApiResponse): Promise; + ready(params: ImageResponse): Promise; } export const STATUS_SUCCESS = 200; @@ -130,7 +136,7 @@ export function paramsFromConfig(defaults: ConfigParams): Required | undefined; + let pending: Promise | undefined; - function throttleRequest(url: URL, options: RequestInit): Promise { + function throttleRequest(url: URL, options: RequestInit): Promise { return f(url, options).then((res) => parseApiResponse(root, res)).finally(() => { pending = undefined; }); @@ -197,10 +203,10 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, - async models(): Promise { + async models(): Promise { const path = makeApiUrl(root, 'settings', 'models'); const res = await f(path); - return await res.json() as ApiModels; + return await res.json() as ModelsResponse; }, async noises(): Promise> { const path = makeApiUrl(root, 'settings', 'noises'); @@ -222,7 +228,7 @@ export function makeClient(root: string, f = fetch): ApiClient { const res = await f(path); return await res.json() as Array; }, - async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { + async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { if (doesExist(pending)) { return pending; } @@ -247,7 +253,7 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise { + async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams): Promise { if (doesExist(pending)) { return pending; } @@ -344,18 +350,41 @@ export function makeClient(root: string, f = fetch): ApiClient { // eslint-disable-next-line no-return-await return await pending; }, - async ready(params: ApiResponse): Promise { + async upscale(model: ModelParams, params: UpscaleReqParams, upscale: UpscaleParams): Promise { + if (doesExist(pending)) { + return pending; + } + + const url = makeApiUrl(root, 'upscale'); + appendModelToURL(url, model); + + if (doesExist(upscale)) { + appendUpscaleToURL(url, upscale); + } + + const body = new FormData(); + body.append('source', params.source, 'source'); + + pending = throttleRequest(url, { + body, + method: 'POST', + }); + + // eslint-disable-next-line no-return-await + return await pending; + }, + async ready(params: ImageResponse): Promise { const path = makeApiUrl(root, 'ready'); path.searchParams.append('output', params.output.key); const res = await f(path); - return await res.json() as ApiReady; + return await res.json() as ReadyResponse; } }; } -export async function parseApiResponse(root: string, res: Response): Promise { - type LimitedResponse = Omit & { output: string }; +export async function parseApiResponse(root: string, res: Response): Promise { + type LimitedResponse = Omit & { output: string }; if (res.status === STATUS_SUCCESS) { const data = await res.json() as LimitedResponse; diff --git a/gui/src/components/ImageCard.tsx b/gui/src/components/ImageCard.tsx index b29ff395..521e1c95 100644 --- a/gui/src/components/ImageCard.tsx +++ b/gui/src/components/ImageCard.tsx @@ -5,13 +5,13 @@ import * as React from 'react'; import { useContext } from 'react'; import { useStore } from 'zustand'; -import { ApiResponse } from '../client.js'; +import { ImageResponse } from '../client.js'; import { StateContext } from '../state.js'; export interface ImageCardProps { - value: ApiResponse; + value: ImageResponse; - onDelete?: (key: ApiResponse) => void; + onDelete?: (key: ImageResponse) => void; } export function GridItem(props: { xs: number; children: React.ReactNode }) { diff --git a/gui/src/components/LoadingCard.tsx b/gui/src/components/LoadingCard.tsx index 7fa48bd0..1656ffc6 100644 --- a/gui/src/components/LoadingCard.tsx +++ b/gui/src/components/LoadingCard.tsx @@ -5,12 +5,12 @@ import { useContext } from 'react'; import { useQuery } from 'react-query'; import { useStore } from 'zustand'; -import { ApiResponse } from '../client.js'; +import { ImageResponse } from '../client.js'; import { POLL_TIME } from '../config.js'; import { ClientContext, StateContext } from '../state.js'; export interface LoadingCardProps { - loading: ApiResponse; + loading: ImageResponse; } export function LoadingCard(props: LoadingCardProps) { diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index b71ccbb4..9b49e072 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -8,6 +8,7 @@ import { Inpaint } from './Inpaint.js'; import { ModelControl } from './ModelControl.js'; import { Settings } from './Settings.js'; import { Txt2Img } from './Txt2Img.js'; +import { Upscale } from './Upscale.js'; const { useState } = React; @@ -32,6 +33,7 @@ export function OnnxWeb() { + @@ -44,6 +46,9 @@ export function OnnxWeb() { + + + diff --git a/gui/src/components/Upscale.tsx b/gui/src/components/Upscale.tsx new file mode 100644 index 00000000..b0f4bc89 --- /dev/null +++ b/gui/src/components/Upscale.tsx @@ -0,0 +1,66 @@ +import { mustExist } from '@apextoaster/js-utils'; +import { Box, Button, Stack } from '@mui/material'; +import * as React from 'react'; +import { useMutation, useQueryClient } from 'react-query'; +import { useStore } from 'zustand'; + +import { IMAGE_FILTER } from '../config.js'; +import { ClientContext, ConfigContext, StateContext } from '../state.js'; +import { ImageInput } from './ImageInput.js'; +import { NumericField } from './NumericField.js'; +import { UpscaleControl } from './UpscaleControl.js'; + +const { useContext } = React; + +export function Upscale() { + const config = mustExist(useContext(ConfigContext)); + + async function uploadSource() { + const { model, upscale } = state.getState(); + + const output = await client.upscale(model, { + ...params, + source: mustExist(params.source), // TODO: show an error if this doesn't exist + }, upscale); + + setLoading(output); + } + + const client = mustExist(useContext(ClientContext)); + const query = useQueryClient(); + const upload = useMutation(uploadSource, { + onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }), + }); + + const state = mustExist(useContext(StateContext)); + const params = useStore(state, (s) => s.upscaleTab); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setSource = useStore(state, (s) => s.setUpscaleTab); + // eslint-disable-next-line @typescript-eslint/unbound-method + const setLoading = useStore(state, (s) => s.setLoading); + + return + + { + setSource({ + source: file, + }); + }} /> + { + setSource({ + strength: value, + }); + }} + /> + + + + ; +} diff --git a/gui/src/components/UpscaleControl.tsx b/gui/src/components/UpscaleControl.tsx index 14866bc8..b8e32ccd 100644 --- a/gui/src/components/UpscaleControl.tsx +++ b/gui/src/components/UpscaleControl.tsx @@ -92,7 +92,7 @@ export function UpscaleControl(props: UpscaleControlProps) { = ConfigFiles> & ConfigState>; +type TabState = ConfigFiles> & ConfigState>; interface Txt2ImgSlice { txt2img: TabState; @@ -42,14 +43,14 @@ interface InpaintSlice { } interface HistorySlice { - history: Array; + history: Array; limit: number; - loading: Maybe; + loading: Maybe; - pushHistory(image: ApiResponse): void; - removeHistory(image: ApiResponse): void; + pushHistory(image: ImageResponse): void; + removeHistory(image: ImageResponse): void; setLimit(limit: number): void; - setLoading(image: Maybe): void; + setLoading(image: Maybe): void; } interface DefaultSlice { @@ -72,8 +73,11 @@ interface BrushSlice { interface UpscaleSlice { upscale: UpscaleParams; + upscaleTab: TabState; setUpscale(upscale: Partial): void; + setUpscaleTab(params: Partial): void; + resetUpscaleTab(): void; } interface ModelSlice { @@ -252,14 +256,34 @@ export function createStateSlices(base: ConfigParams) { outscale: 1, faceStrength: 0.5, }, + upscaleTab: { + source: null, + strength: 1.0, + }, setUpscale(upscale) { set((prev) => ({ upscale: { ...prev.upscale, ...upscale, - } + }, })); }, + setUpscaleTab(source) { + set((prev) => ({ + upscaleTab: { + ...prev.upscaleTab, + ...source, + }, + })); + }, + resetUpscaleTab() { + set({ + upscaleTab: { + source: null, + strength: 1.0, + }, + }); + }, }); const createDefaultSlice: StateCreator = (set) => ({