diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 75af59bd..2863d3f5 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -47,6 +47,70 @@ def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tup return loras, inversions +def run_loopback( + job: WorkerContext, + server: ServerContext, + params: ImageParams, + image: Image.Image, + progress: ProgressCallback, + inversions: List[Tuple[str, float]], + loras: List[Tuple[str, float]], +) -> Image.Image: + if params.loopback == 0: + return image + + # load img2img pipeline once + pipe_type = "lpw" if params.lpw() else "img2img" + pipe = load_pipeline( + server, + pipe_type, + params.model, + params.scheduler, + job.get_device(), + inversions=inversions, + loras=loras, + ) + + def loopback_iteration(source: Image.Image): + if params.lpw(): + logger.debug("using LPW pipeline for loopback") + rng = torch.manual_seed(params.seed) + result = pipe.img2img( + source, + params.prompt, + generator=rng, + guidance_scale=params.cfg, + negative_prompt=params.negative_prompt, + num_images_per_prompt=1, + num_inference_steps=params.steps, + strength=params.strength, + eta=params.eta, + callback=progress, + ) + return result.images[0] + else: + logger.debug("using img2img pipeline for loopback") + rng = np.random.RandomState(params.seed) + result = pipe( + params.prompt, + source, + generator=rng, + guidance_scale=params.cfg, + negative_prompt=params.negative_prompt, + num_images_per_prompt=1, + num_inference_steps=params.steps, + strength=params.strength, + eta=params.eta, + callback=progress, + ) + return result.images[0] + + for _i in range(params.loopback): + image = loopback_iteration(image) + + return image + + def run_highres( job: WorkerContext, server: ServerContext, @@ -58,7 +122,7 @@ def run_highres( progress: ProgressCallback, inversions: List[Tuple[str, float]], loras: List[Tuple[str, float]], -) -> None: +) -> Image.Image: if highres.scale <= 1: return image @@ -137,6 +201,7 @@ def run_highres( ) return result.images[0] else: + logger.debug("using img2img pipeline for highres") rng = np.random.RandomState(params.seed) result = highres_pipe( params.prompt, @@ -232,6 +297,15 @@ def run_txt2img_pipeline( del pipe for image, output in image_outputs: + image = run_loopback( + job, + server, + params, + progress, + inversions, + loras, + ) + image = run_highres( job, server, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index eaaa774c..eca66a6c 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -175,6 +175,7 @@ class ImageParams: control: Optional[NetworkModel] input_prompt: str input_negative_prompt: str + loopback: int def __init__( self, @@ -191,6 +192,7 @@ class ImageParams: control: Optional[NetworkModel] = None, input_prompt: Optional[str] = None, input_negative_prompt: Optional[str] = None, + loopback: int = 0, ) -> None: self.model = model self.pipeline = pipeline @@ -205,6 +207,7 @@ class ImageParams: self.control = control self.input_prompt = input_prompt or prompt self.input_negative_prompt = input_negative_prompt or negative_prompt + self.loopback = loopback def lpw(self): return self.pipeline == "lpw" @@ -224,6 +227,7 @@ class ImageParams: "control": self.control.name if self.control is not None else "", "input_prompt": self.input_prompt, "input_negative_prompt": self.input_negative_prompt, + "loopback": self.loopback, } def with_args(self, **kwargs): @@ -241,6 +245,7 @@ class ImageParams: kwargs.get("control", self.control), kwargs.get("input_prompt", self.input_prompt), kwargs.get("input_negative_prompt", self.input_negative_prompt), + kwargs.get("loopback", self.loopback), ) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 6db5bb3f..ef1123f4 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -91,6 +91,13 @@ def pipeline_from_request( get_config_value("eta", "max"), get_config_value("eta", "min"), ) + loopback = get_and_clamp_int( + request.args, + "loopback", + get_config_value("loopback"), + get_config_value("loopback", "max"), + get_config_value("loopback", "min"), + ) steps = get_and_clamp_int( request.args, "steps", @@ -145,6 +152,7 @@ def pipeline_from_request( negative_prompt=negative_prompt, batch=batch, control=control, + loopback=loopback, ) size = Size(width, height) return (device, params, size) diff --git a/api/params.json b/api/params.json index 9212fb2a..6b591819 100644 --- a/api/params.json +++ b/api/params.json @@ -106,6 +106,12 @@ "max": 512, "step": 8 }, + "loopback": { + "default": 0, + "min": 0, + "max": 10, + "step": 1 + }, "model": { "default": "stable-diffusion-onnx-v1-5", "keys": [] diff --git a/gui/examples/config.json b/gui/examples/config.json index 5514d239..7cbd41b3 100644 --- a/gui/examples/config.json +++ b/gui/examples/config.json @@ -104,6 +104,12 @@ "max": 512, "step": 8 }, + "loopback": { + "default": 0, + "min": 0, + "max": 10, + "step": 1 + }, "model": { "default": "stable-diffusion-onnx-v1-5", "keys": [] diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index cebcbef1..03cb37f7 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -68,6 +68,7 @@ export interface Txt2ImgParams extends BaseImgParams { export interface Img2ImgParams extends BaseImgParams { source: Blob; + loopback: number; sourceFilter?: string; strength: number; } @@ -518,6 +519,7 @@ export function makeClient(root: string, f = fetch): ApiClient { const url = makeImageURL(root, 'img2img', params); appendModelToURL(url, model); + url.searchParams.append('loopback', params.loopback.toFixed(FIXED_INTEGER)); url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); if (doesExist(params.sourceFilter)) { diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index 939813f5..11a05169 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -46,6 +46,7 @@ export function Img2Img() { const source = useStore(state, (s) => s.img2img.source); const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter); const strength = useStore(state, (s) => s.img2img.strength); + const loopback = useStore(state, (s) => s.img2img.loopback); // eslint-disable-next-line @typescript-eslint/unbound-method const setImg2Img = useStore(state, (s) => s.setImg2Img); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -112,6 +113,18 @@ export function Img2Img() { }); }} /> + { + setImg2Img({ + loopback: value, + }); + }} + /> diff --git a/gui/src/state.ts b/gui/src/state.ts index f089e8b7..11f442b7 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -257,6 +257,7 @@ export function createStateSlices(server: ServerParams) { const createImg2ImgSlice: Slice = (set) => ({ img2img: { ...base, + loopback: server.loopback.default, source: null, sourceFilter: '', strength: server.strength.default, @@ -273,6 +274,7 @@ export function createStateSlices(server: ServerParams) { set({ img2img: { ...base, + loopback: server.loopback.default, source: null, sourceFilter: '', strength: server.strength.default,