diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 72fee1f7..f5fe372f 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -42,6 +42,9 @@ def run_highres( inversions: List[Tuple[str, float]], loras: List[Tuple[str, float]], ) -> None: + if highres.scale <= 1: + return image + highres_progress = ChainProgress.from_progress(progress) if upscale.faces and ( @@ -213,19 +216,18 @@ def run_txt2img_pipeline( del pipe for image, output in image_outputs: - if highres.scale > 1: - image = run_highres( - job, - server, - params, - size, - upscale, - highres, - image, - progress, - inversions, - loras, - ) + image = run_highres( + job, + server, + params, + size, + upscale, + highres, + image, + progress, + inversions, + loras, + ) image = run_upscale_correction( job, @@ -321,19 +323,18 @@ def run_img2img_pipeline( images.append(source) for image, output in zip(images, outputs): - if highres.scale > 1: - image = run_highres( - job, - server, - params, - Size(source.width, source.height), - upscale, - highres, - image, - progress, - inversions, - loras, - ) + image = run_highres( + job, + server, + params, + Size(source.width, source.height), + upscale, + highres, + image, + progress, + inversions, + loras, + ) image = run_upscale_correction( job, @@ -373,6 +374,10 @@ def run_inpaint_pipeline( progress = job.get_progress_callback() stage = StageParams(tile_order=tile_order) + (prompt, loras) = get_loras_from_prompt(params.prompt) + (prompt, inversions) = get_inversions_from_prompt(prompt) + params.prompt = prompt + # calling the upscale_outpaint stage directly needs accumulating progress progress = ChainProgress.from_progress(progress) @@ -391,6 +396,19 @@ def run_inpaint_pipeline( callback=progress, ) + image = run_highres( + job, + server, + params, + size, + upscale, + highres, + image, + progress, + inversions, + loras, + ) + image = run_upscale_correction( job, server, @@ -424,10 +442,28 @@ def run_upscale_pipeline( progress = job.get_progress_callback() stage = StageParams() + (prompt, loras) = get_loras_from_prompt(params.prompt) + (prompt, inversions) = get_inversions_from_prompt(prompt) + params.prompt = prompt + image = run_upscale_correction( job, server, stage, params, source, upscale=upscale, callback=progress ) + # TODO: should this come first? + image = run_highres( + job, + server, + params, + size, + upscale, + highres, + image, + progress, + inversions, + loras, + ) + dest = save_image(server, outputs[0], image) save_params(server, outputs[0], params, size, upscale=upscale) @@ -463,6 +499,20 @@ def run_blend_pipeline( ) image = image.convert("RGB") + # TODO: blend tab doesn't have a prompt + image = run_highres( + job, + server, + params, + size, + upscale, + highres, + image, + progress, + [], + [], + ) + image = run_upscale_correction( job, server, stage, params, image, upscale=upscale, callback=progress ) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index ebbfc984..46c05826 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -164,6 +164,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(server, "img2img") upscale = upscale_from_request() + highres = highres_from_request() source_filter = get_from_list( request.args, "sourceFilter", list(get_source_filters().keys()) ) @@ -195,13 +196,14 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): params, output, upscale, + highres, source, strength, needs_device=device, source_filter=source_filter, ) - return jsonify(json_params(output, params, size, upscale=upscale)) + return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) def txt2img(server: ServerContext, pool: DevicePoolExecutor): @@ -243,6 +245,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(server, "inpaint") expand = border_from_request() upscale = upscale_from_request() + highres = highres_from_request() fill_color = get_not_empty(request.args, "fillColor", "white") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") @@ -280,6 +283,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): size, output, upscale, + highres, source, mask, expand, @@ -290,7 +294,11 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): needs_device=device, ) - return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) + return jsonify( + json_params( + output, params, size, upscale=upscale, border=expand, highres=highres + ) + ) def upscale(server: ServerContext, pool: DevicePoolExecutor): @@ -302,6 +310,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(server) upscale = upscale_from_request() + highres = highres_from_request() output = make_output_name(server, "upscale", params, size) job_name = output[0] @@ -316,11 +325,12 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): size, output, upscale, + highres, source, needs_device=device, ) - return jsonify(json_params(output, params, size, upscale=upscale)) + return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) def chain(server: ServerContext, pool: DevicePoolExecutor): diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 08030df2..a49d8adf 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -229,21 +229,25 @@ export type RetryParams = { model: ModelParams; params: Img2ImgParams; upscale?: UpscaleParams; + highres?: HighresParams; } | { type: 'inpaint'; model: ModelParams; params: InpaintParams; upscale?: UpscaleParams; + highres?: HighresParams; } | { type: 'outpaint'; model: ModelParams; params: OutpaintParams; upscale?: UpscaleParams; + highres?: HighresParams; } | { type: 'upscale'; model: ModelParams; params: UpscaleReqParams; upscale?: UpscaleParams; + highres?: HighresParams; } | { type: 'blend'; model: ModelParams; @@ -307,22 +311,22 @@ export interface ApiClient { /** * Start an im2img pipeline. */ - img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise; + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; /** * Start an inpaint pipeline. */ - inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise; + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; /** * Start an outpaint pipeline. */ - outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise; + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; /** * Start an upscale pipeline. */ - upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise; + upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; /** * Start a blending pipeline. @@ -431,6 +435,16 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { } } +export function appendHighresToURL(url: URL, highres: HighresParams) { + if (highres.enabled) { + url.searchParams.append('highresIterations', highres.highresIterations.toFixed(FIXED_INTEGER)); + url.searchParams.append('highresMethod', highres.highresMethod); + url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER)); + url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER)); + url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT)); + } +} + /** * Make an API client using the given API root and fetch client. */ @@ -484,7 +498,7 @@ export function makeClient(root: string, f = fetch): ApiClient { translation: Record; }>; }, - async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams): Promise { + async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { const url = makeImageURL(root, 'img2img', params); appendModelToURL(url, model); @@ -498,6 +512,10 @@ export function makeClient(root: string, f = fetch): ApiClient { appendUpscaleToURL(url, upscale); } + if (doesExist(highres)) { + appendHighresToURL(url, highres); + } + const body = new FormData(); body.append('source', params.source, 'source'); @@ -531,12 +549,8 @@ export function makeClient(root: string, f = fetch): ApiClient { appendUpscaleToURL(url, upscale); } - if (doesExist(highres) && highres.enabled) { - url.searchParams.append('highresIterations', highres.highresIterations.toFixed(FIXED_INTEGER)); - url.searchParams.append('highresMethod', highres.highresMethod); - url.searchParams.append('highresScale', highres.highresScale.toFixed(FIXED_INTEGER)); - url.searchParams.append('highresSteps', highres.highresSteps.toFixed(FIXED_INTEGER)); - url.searchParams.append('highresStrength', highres.highresStrength.toFixed(FIXED_FLOAT)); + if (doesExist(highres)) { + appendHighresToURL(url, highres); } const image = await parseRequest(url, { @@ -553,7 +567,7 @@ export function makeClient(root: string, f = fetch): ApiClient { }, }; }, - async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams): Promise { + async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { const url = makeImageURL(root, 'inpaint', params); appendModelToURL(url, model); @@ -566,6 +580,10 @@ export function makeClient(root: string, f = fetch): ApiClient { appendUpscaleToURL(url, upscale); } + if (doesExist(highres)) { + appendHighresToURL(url, highres); + } + const body = new FormData(); body.append('mask', params.mask, 'mask'); body.append('source', params.source, 'source'); @@ -584,7 +602,7 @@ export function makeClient(root: string, f = fetch): ApiClient { }, }; }, - async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams): Promise { + async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { const url = makeImageURL(root, 'inpaint', params); appendModelToURL(url, model); @@ -598,6 +616,10 @@ export function makeClient(root: string, f = fetch): ApiClient { appendUpscaleToURL(url, upscale); } + if (doesExist(highres)) { + appendHighresToURL(url, highres); + } + if (doesExist(params.left)) { url.searchParams.append('left', params.left.toFixed(FIXED_INTEGER)); } @@ -632,7 +654,7 @@ export function makeClient(root: string, f = fetch): ApiClient { }, }; }, - async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams): Promise { + async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { const url = makeApiUrl(root, 'upscale'); appendModelToURL(url, model); @@ -640,6 +662,10 @@ export function makeClient(root: string, f = fetch): ApiClient { appendUpscaleToURL(url, upscale); } + if (doesExist(highres)) { + appendHighresToURL(url, highres); + } + url.searchParams.append('prompt', params.prompt); if (doesExist(params.negativePrompt)) { @@ -714,15 +740,15 @@ export function makeClient(root: string, f = fetch): ApiClient { case 'blend': return this.blend(retry.model, retry.params, retry.upscale); case 'img2img': - return this.img2img(retry.model, retry.params, retry.upscale); + return this.img2img(retry.model, retry.params, retry.upscale, retry.highres); case 'inpaint': - return this.inpaint(retry.model, retry.params, retry.upscale); + return this.inpaint(retry.model, retry.params, retry.upscale, retry.highres); case 'outpaint': - return this.outpaint(retry.model, retry.params, retry.upscale); + return this.outpaint(retry.model, retry.params, retry.upscale, retry.highres); case 'txt2img': return this.txt2img(retry.model, retry.params, retry.upscale, retry.highres); case 'upscale': - return this.upscale(retry.model, retry.params, retry.upscale); + return this.upscale(retry.model, retry.params, retry.upscale, retry.highres); default: throw new InvalidArgumentError('unknown request type'); } diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index d6fd48f1..2fad1f02 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -73,6 +73,21 @@ export function ModelControl() { }); }} /> + { + setModel({ + pipeline, + }); + }} + /> + + - - - { - setModel({ - pipeline, - }); - }} - /> - { - setImg2Img({ - source: file, - }); - }} /> + { + setImg2Img({ + source: file, + }); + }} + /> s.img2img} onChange={setImg2Img} /> +