1
0
Fork 0

feat: add img2img loopback (#331)

This commit is contained in:
Sean Sube 2023-04-22 10:39:23 -05:00
parent 7b0095a665
commit 00fb64ba82
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 117 additions and 1 deletions

View File

@ -47,6 +47,70 @@ def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tup
return loras, inversions return loras, inversions
def run_loopback(
job: WorkerContext,
server: ServerContext,
params: ImageParams,
image: Image.Image,
progress: ProgressCallback,
inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]],
) -> Image.Image:
if params.loopback == 0:
return image
# load img2img pipeline once
pipe_type = "lpw" if params.lpw() else "img2img"
pipe = load_pipeline(
server,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
)
def loopback_iteration(source: Image.Image):
if params.lpw():
logger.debug("using LPW pipeline for loopback")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=params.strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
else:
logger.debug("using img2img pipeline for loopback")
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=params.strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
for _i in range(params.loopback):
image = loopback_iteration(image)
return image
def run_highres( def run_highres(
job: WorkerContext, job: WorkerContext,
server: ServerContext, server: ServerContext,
@ -58,7 +122,7 @@ def run_highres(
progress: ProgressCallback, progress: ProgressCallback,
inversions: List[Tuple[str, float]], inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]], loras: List[Tuple[str, float]],
) -> None: ) -> Image.Image:
if highres.scale <= 1: if highres.scale <= 1:
return image return image
@ -137,6 +201,7 @@ def run_highres(
) )
return result.images[0] return result.images[0]
else: else:
logger.debug("using img2img pipeline for highres")
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = highres_pipe( result = highres_pipe(
params.prompt, params.prompt,
@ -232,6 +297,15 @@ def run_txt2img_pipeline(
del pipe del pipe
for image, output in image_outputs: for image, output in image_outputs:
image = run_loopback(
job,
server,
params,
progress,
inversions,
loras,
)
image = run_highres( image = run_highres(
job, job,
server, server,

View File

@ -175,6 +175,7 @@ class ImageParams:
control: Optional[NetworkModel] control: Optional[NetworkModel]
input_prompt: str input_prompt: str
input_negative_prompt: str input_negative_prompt: str
loopback: int
def __init__( def __init__(
self, self,
@ -191,6 +192,7 @@ class ImageParams:
control: Optional[NetworkModel] = None, control: Optional[NetworkModel] = None,
input_prompt: Optional[str] = None, input_prompt: Optional[str] = None,
input_negative_prompt: Optional[str] = None, input_negative_prompt: Optional[str] = None,
loopback: int = 0,
) -> None: ) -> None:
self.model = model self.model = model
self.pipeline = pipeline self.pipeline = pipeline
@ -205,6 +207,7 @@ class ImageParams:
self.control = control self.control = control
self.input_prompt = input_prompt or prompt self.input_prompt = input_prompt or prompt
self.input_negative_prompt = input_negative_prompt or negative_prompt self.input_negative_prompt = input_negative_prompt or negative_prompt
self.loopback = loopback
def lpw(self): def lpw(self):
return self.pipeline == "lpw" return self.pipeline == "lpw"
@ -224,6 +227,7 @@ class ImageParams:
"control": self.control.name if self.control is not None else "", "control": self.control.name if self.control is not None else "",
"input_prompt": self.input_prompt, "input_prompt": self.input_prompt,
"input_negative_prompt": self.input_negative_prompt, "input_negative_prompt": self.input_negative_prompt,
"loopback": self.loopback,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
@ -241,6 +245,7 @@ class ImageParams:
kwargs.get("control", self.control), kwargs.get("control", self.control),
kwargs.get("input_prompt", self.input_prompt), kwargs.get("input_prompt", self.input_prompt),
kwargs.get("input_negative_prompt", self.input_negative_prompt), kwargs.get("input_negative_prompt", self.input_negative_prompt),
kwargs.get("loopback", self.loopback),
) )

View File

@ -91,6 +91,13 @@ def pipeline_from_request(
get_config_value("eta", "max"), get_config_value("eta", "max"),
get_config_value("eta", "min"), get_config_value("eta", "min"),
) )
loopback = get_and_clamp_int(
request.args,
"loopback",
get_config_value("loopback"),
get_config_value("loopback", "max"),
get_config_value("loopback", "min"),
)
steps = get_and_clamp_int( steps = get_and_clamp_int(
request.args, request.args,
"steps", "steps",
@ -145,6 +152,7 @@ def pipeline_from_request(
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
batch=batch, batch=batch,
control=control, control=control,
loopback=loopback,
) )
size = Size(width, height) size = Size(width, height)
return (device, params, size) return (device, params, size)

View File

@ -106,6 +106,12 @@
"max": 512, "max": 512,
"step": 8 "step": 8
}, },
"loopback": {
"default": 0,
"min": 0,
"max": 10,
"step": 1
},
"model": { "model": {
"default": "stable-diffusion-onnx-v1-5", "default": "stable-diffusion-onnx-v1-5",
"keys": [] "keys": []

View File

@ -104,6 +104,12 @@
"max": 512, "max": 512,
"step": 8 "step": 8
}, },
"loopback": {
"default": 0,
"min": 0,
"max": 10,
"step": 1
},
"model": { "model": {
"default": "stable-diffusion-onnx-v1-5", "default": "stable-diffusion-onnx-v1-5",
"keys": [] "keys": []

View File

@ -68,6 +68,7 @@ export interface Txt2ImgParams extends BaseImgParams {
export interface Img2ImgParams extends BaseImgParams { export interface Img2ImgParams extends BaseImgParams {
source: Blob; source: Blob;
loopback: number;
sourceFilter?: string; sourceFilter?: string;
strength: number; strength: number;
} }
@ -518,6 +519,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
const url = makeImageURL(root, 'img2img', params); const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model); appendModelToURL(url, model);
url.searchParams.append('loopback', params.loopback.toFixed(FIXED_INTEGER));
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));
if (doesExist(params.sourceFilter)) { if (doesExist(params.sourceFilter)) {

View File

@ -46,6 +46,7 @@ export function Img2Img() {
const source = useStore(state, (s) => s.img2img.source); const source = useStore(state, (s) => s.img2img.source);
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter); const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
const strength = useStore(state, (s) => s.img2img.strength); const strength = useStore(state, (s) => s.img2img.strength);
const loopback = useStore(state, (s) => s.img2img.loopback);
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img); const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
@ -112,6 +113,18 @@ export function Img2Img() {
}); });
}} }}
/> />
<NumericField
label={t('parameter.loopback')}
min={params.loopback.min}
max={params.loopback.max}
step={params.loopback.step}
value={loopback}
onChange={(value) => {
setImg2Img({
loopback: value,
});
}}
/>
</Stack> </Stack>
<HighresControl /> <HighresControl />
<UpscaleControl /> <UpscaleControl />

View File

@ -257,6 +257,7 @@ export function createStateSlices(server: ServerParams) {
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({ const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
img2img: { img2img: {
...base, ...base,
loopback: server.loopback.default,
source: null, source: null,
sourceFilter: '', sourceFilter: '',
strength: server.strength.default, strength: server.strength.default,
@ -273,6 +274,7 @@ export function createStateSlices(server: ServerParams) {
set({ set({
img2img: { img2img: {
...base, ...base,
loopback: server.loopback.default,
source: null, source: null,
sourceFilter: '', sourceFilter: '',
strength: server.strength.default, strength: server.strength.default,