feat: add img2img loopback (#331)
This commit is contained in:
parent
7b0095a665
commit
00fb64ba82
|
@ -47,6 +47,70 @@ def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tup
|
||||||
return loras, inversions
|
return loras, inversions
|
||||||
|
|
||||||
|
|
||||||
|
def run_loopback(
|
||||||
|
job: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
params: ImageParams,
|
||||||
|
image: Image.Image,
|
||||||
|
progress: ProgressCallback,
|
||||||
|
inversions: List[Tuple[str, float]],
|
||||||
|
loras: List[Tuple[str, float]],
|
||||||
|
) -> Image.Image:
|
||||||
|
if params.loopback == 0:
|
||||||
|
return image
|
||||||
|
|
||||||
|
# load img2img pipeline once
|
||||||
|
pipe_type = "lpw" if params.lpw() else "img2img"
|
||||||
|
pipe = load_pipeline(
|
||||||
|
server,
|
||||||
|
pipe_type,
|
||||||
|
params.model,
|
||||||
|
params.scheduler,
|
||||||
|
job.get_device(),
|
||||||
|
inversions=inversions,
|
||||||
|
loras=loras,
|
||||||
|
)
|
||||||
|
|
||||||
|
def loopback_iteration(source: Image.Image):
|
||||||
|
if params.lpw():
|
||||||
|
logger.debug("using LPW pipeline for loopback")
|
||||||
|
rng = torch.manual_seed(params.seed)
|
||||||
|
result = pipe.img2img(
|
||||||
|
source,
|
||||||
|
params.prompt,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
strength=params.strength,
|
||||||
|
eta=params.eta,
|
||||||
|
callback=progress,
|
||||||
|
)
|
||||||
|
return result.images[0]
|
||||||
|
else:
|
||||||
|
logger.debug("using img2img pipeline for loopback")
|
||||||
|
rng = np.random.RandomState(params.seed)
|
||||||
|
result = pipe(
|
||||||
|
params.prompt,
|
||||||
|
source,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
strength=params.strength,
|
||||||
|
eta=params.eta,
|
||||||
|
callback=progress,
|
||||||
|
)
|
||||||
|
return result.images[0]
|
||||||
|
|
||||||
|
for _i in range(params.loopback):
|
||||||
|
image = loopback_iteration(image)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
def run_highres(
|
def run_highres(
|
||||||
job: WorkerContext,
|
job: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -58,7 +122,7 @@ def run_highres(
|
||||||
progress: ProgressCallback,
|
progress: ProgressCallback,
|
||||||
inversions: List[Tuple[str, float]],
|
inversions: List[Tuple[str, float]],
|
||||||
loras: List[Tuple[str, float]],
|
loras: List[Tuple[str, float]],
|
||||||
) -> None:
|
) -> Image.Image:
|
||||||
if highres.scale <= 1:
|
if highres.scale <= 1:
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
@ -137,6 +201,7 @@ def run_highres(
|
||||||
)
|
)
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
else:
|
else:
|
||||||
|
logger.debug("using img2img pipeline for highres")
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = highres_pipe(
|
result = highres_pipe(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
|
@ -232,6 +297,15 @@ def run_txt2img_pipeline(
|
||||||
del pipe
|
del pipe
|
||||||
|
|
||||||
for image, output in image_outputs:
|
for image, output in image_outputs:
|
||||||
|
image = run_loopback(
|
||||||
|
job,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
progress,
|
||||||
|
inversions,
|
||||||
|
loras,
|
||||||
|
)
|
||||||
|
|
||||||
image = run_highres(
|
image = run_highres(
|
||||||
job,
|
job,
|
||||||
server,
|
server,
|
||||||
|
|
|
@ -175,6 +175,7 @@ class ImageParams:
|
||||||
control: Optional[NetworkModel]
|
control: Optional[NetworkModel]
|
||||||
input_prompt: str
|
input_prompt: str
|
||||||
input_negative_prompt: str
|
input_negative_prompt: str
|
||||||
|
loopback: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -191,6 +192,7 @@ class ImageParams:
|
||||||
control: Optional[NetworkModel] = None,
|
control: Optional[NetworkModel] = None,
|
||||||
input_prompt: Optional[str] = None,
|
input_prompt: Optional[str] = None,
|
||||||
input_negative_prompt: Optional[str] = None,
|
input_negative_prompt: Optional[str] = None,
|
||||||
|
loopback: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
@ -205,6 +207,7 @@ class ImageParams:
|
||||||
self.control = control
|
self.control = control
|
||||||
self.input_prompt = input_prompt or prompt
|
self.input_prompt = input_prompt or prompt
|
||||||
self.input_negative_prompt = input_negative_prompt or negative_prompt
|
self.input_negative_prompt = input_negative_prompt or negative_prompt
|
||||||
|
self.loopback = loopback
|
||||||
|
|
||||||
def lpw(self):
|
def lpw(self):
|
||||||
return self.pipeline == "lpw"
|
return self.pipeline == "lpw"
|
||||||
|
@ -224,6 +227,7 @@ class ImageParams:
|
||||||
"control": self.control.name if self.control is not None else "",
|
"control": self.control.name if self.control is not None else "",
|
||||||
"input_prompt": self.input_prompt,
|
"input_prompt": self.input_prompt,
|
||||||
"input_negative_prompt": self.input_negative_prompt,
|
"input_negative_prompt": self.input_negative_prompt,
|
||||||
|
"loopback": self.loopback,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_args(self, **kwargs):
|
def with_args(self, **kwargs):
|
||||||
|
@ -241,6 +245,7 @@ class ImageParams:
|
||||||
kwargs.get("control", self.control),
|
kwargs.get("control", self.control),
|
||||||
kwargs.get("input_prompt", self.input_prompt),
|
kwargs.get("input_prompt", self.input_prompt),
|
||||||
kwargs.get("input_negative_prompt", self.input_negative_prompt),
|
kwargs.get("input_negative_prompt", self.input_negative_prompt),
|
||||||
|
kwargs.get("loopback", self.loopback),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -91,6 +91,13 @@ def pipeline_from_request(
|
||||||
get_config_value("eta", "max"),
|
get_config_value("eta", "max"),
|
||||||
get_config_value("eta", "min"),
|
get_config_value("eta", "min"),
|
||||||
)
|
)
|
||||||
|
loopback = get_and_clamp_int(
|
||||||
|
request.args,
|
||||||
|
"loopback",
|
||||||
|
get_config_value("loopback"),
|
||||||
|
get_config_value("loopback", "max"),
|
||||||
|
get_config_value("loopback", "min"),
|
||||||
|
)
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args,
|
request.args,
|
||||||
"steps",
|
"steps",
|
||||||
|
@ -145,6 +152,7 @@ def pipeline_from_request(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
control=control,
|
control=control,
|
||||||
|
loopback=loopback,
|
||||||
)
|
)
|
||||||
size = Size(width, height)
|
size = Size(width, height)
|
||||||
return (device, params, size)
|
return (device, params, size)
|
||||||
|
|
|
@ -106,6 +106,12 @@
|
||||||
"max": 512,
|
"max": 512,
|
||||||
"step": 8
|
"step": 8
|
||||||
},
|
},
|
||||||
|
"loopback": {
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 10,
|
||||||
|
"step": 1
|
||||||
|
},
|
||||||
"model": {
|
"model": {
|
||||||
"default": "stable-diffusion-onnx-v1-5",
|
"default": "stable-diffusion-onnx-v1-5",
|
||||||
"keys": []
|
"keys": []
|
||||||
|
|
|
@ -104,6 +104,12 @@
|
||||||
"max": 512,
|
"max": 512,
|
||||||
"step": 8
|
"step": 8
|
||||||
},
|
},
|
||||||
|
"loopback": {
|
||||||
|
"default": 0,
|
||||||
|
"min": 0,
|
||||||
|
"max": 10,
|
||||||
|
"step": 1
|
||||||
|
},
|
||||||
"model": {
|
"model": {
|
||||||
"default": "stable-diffusion-onnx-v1-5",
|
"default": "stable-diffusion-onnx-v1-5",
|
||||||
"keys": []
|
"keys": []
|
||||||
|
|
|
@ -68,6 +68,7 @@ export interface Txt2ImgParams extends BaseImgParams {
|
||||||
export interface Img2ImgParams extends BaseImgParams {
|
export interface Img2ImgParams extends BaseImgParams {
|
||||||
source: Blob;
|
source: Blob;
|
||||||
|
|
||||||
|
loopback: number;
|
||||||
sourceFilter?: string;
|
sourceFilter?: string;
|
||||||
strength: number;
|
strength: number;
|
||||||
}
|
}
|
||||||
|
@ -518,6 +519,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
const url = makeImageURL(root, 'img2img', params);
|
const url = makeImageURL(root, 'img2img', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
url.searchParams.append('loopback', params.loopback.toFixed(FIXED_INTEGER));
|
||||||
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
|
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
|
||||||
|
|
||||||
if (doesExist(params.sourceFilter)) {
|
if (doesExist(params.sourceFilter)) {
|
||||||
|
|
|
@ -46,6 +46,7 @@ export function Img2Img() {
|
||||||
const source = useStore(state, (s) => s.img2img.source);
|
const source = useStore(state, (s) => s.img2img.source);
|
||||||
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
|
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
|
||||||
const strength = useStore(state, (s) => s.img2img.strength);
|
const strength = useStore(state, (s) => s.img2img.strength);
|
||||||
|
const loopback = useStore(state, (s) => s.img2img.loopback);
|
||||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
const setImg2Img = useStore(state, (s) => s.setImg2Img);
|
const setImg2Img = useStore(state, (s) => s.setImg2Img);
|
||||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||||
|
@ -112,6 +113,18 @@ export function Img2Img() {
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<NumericField
|
||||||
|
label={t('parameter.loopback')}
|
||||||
|
min={params.loopback.min}
|
||||||
|
max={params.loopback.max}
|
||||||
|
step={params.loopback.step}
|
||||||
|
value={loopback}
|
||||||
|
onChange={(value) => {
|
||||||
|
setImg2Img({
|
||||||
|
loopback: value,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</Stack>
|
</Stack>
|
||||||
<HighresControl />
|
<HighresControl />
|
||||||
<UpscaleControl />
|
<UpscaleControl />
|
||||||
|
|
|
@ -257,6 +257,7 @@ export function createStateSlices(server: ServerParams) {
|
||||||
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
|
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
|
||||||
img2img: {
|
img2img: {
|
||||||
...base,
|
...base,
|
||||||
|
loopback: server.loopback.default,
|
||||||
source: null,
|
source: null,
|
||||||
sourceFilter: '',
|
sourceFilter: '',
|
||||||
strength: server.strength.default,
|
strength: server.strength.default,
|
||||||
|
@ -273,6 +274,7 @@ export function createStateSlices(server: ServerParams) {
|
||||||
set({
|
set({
|
||||||
img2img: {
|
img2img: {
|
||||||
...base,
|
...base,
|
||||||
|
loopback: server.loopback.default,
|
||||||
source: null,
|
source: null,
|
||||||
sourceFilter: '',
|
sourceFilter: '',
|
||||||
strength: server.strength.default,
|
strength: server.strength.default,
|
||||||
|
|
Loading…
Reference in New Issue