diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 4e055105..f4974c42 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -36,7 +36,9 @@ class BlendImg2ImgStage(BaseStage): "blending image using img2img, %s steps: %s", params.steps, params.prompt ) - prompt_pairs, loras, inversions = parse_prompt(params) + prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( + params + ) pipe_type = params.get_valid_pipeline("img2img") pipe = load_pipeline( @@ -67,10 +69,10 @@ class BlendImg2ImgStage(BaseStage): rng = torch.manual_seed(params.seed) result = pipe.img2img( source, - params.prompt, + prompt, generator=rng, guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_inference_steps=params.steps, callback=callback, **pipe_params, @@ -84,11 +86,11 @@ class BlendImg2ImgStage(BaseStage): rng = np.random.RandomState(params.seed) result = pipe( - params.prompt, + prompt, generator=rng, guidance_scale=params.cfg, image=source, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_inference_steps=params.steps, callback=callback, **pipe_params, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 6eef0f11..8c15d4cf 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -6,7 +6,7 @@ import torch from PIL import Image from ..diffusers.load import load_pipeline -from ..diffusers.utils import get_latents_from_seed, parse_prompt +from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams @@ -43,7 +43,9 @@ class BlendInpaintStage(BaseStage): "blending image using inpaint, %s steps: %s", params.steps, params.prompt ) - _prompt_pairs, loras, inversions = parse_prompt(params) + prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( + params + ) pipe_type = params.get_valid_pipeline("inpaint") pipe = load_pipeline( server, @@ -88,30 +90,36 @@ class BlendInpaintStage(BaseStage): logger.debug("using LPW pipeline for inpaint") rng = torch.manual_seed(params.seed) result = pipe.inpaint( - params.prompt, + prompt, generator=rng, guidance_scale=params.cfg, height=size.height, image=tile_source, latents=latents, mask_image=tile_mask, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_inference_steps=params.steps, width=size.width, eta=params.eta, callback=callback, ) else: + # encode and record alternative prompts outside of LPW + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) + rng = np.random.RandomState(params.seed) result = pipe( - params.prompt, + prompt, generator=rng, guidance_scale=params.cfg, height=size.height, image=tile_source, latents=latents, mask_image=stage_mask, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_inference_steps=params.steps, width=size.width, eta=params.eta, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index ebc8250f..3042a114 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -41,7 +41,9 @@ class SourceTxt2ImgStage(BaseStage): "a source image was passed to a txt2img stage, and will be discarded" ) - prompt_pairs, loras, inversions = parse_prompt(params) + prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( + params + ) latents = get_latents_from_seed(params.seed, size, params.batch) pipe_type = params.get_valid_pipeline("txt2img") @@ -58,13 +60,13 @@ class SourceTxt2ImgStage(BaseStage): logger.debug("using LPW pipeline for txt2img") rng = torch.manual_seed(params.seed) result = pipe.text2img( - params.prompt, + prompt, height=size.height, width=size.width, generator=rng, guidance_scale=params.cfg, latents=latents, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_images_per_prompt=params.batch, num_inference_steps=params.steps, eta=params.eta, @@ -79,13 +81,13 @@ class SourceTxt2ImgStage(BaseStage): rng = np.random.RandomState(params.seed) result = pipe( - params.prompt, + prompt, height=size.height, width=size.width, generator=rng, guidance_scale=params.cfg, latents=latents, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_images_per_prompt=params.batch, num_inference_steps=params.steps, eta=params.eta, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 9276cfbc..a9afda85 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -6,7 +6,12 @@ import torch from PIL import Image, ImageDraw, ImageOps from ..diffusers.load import load_pipeline -from ..diffusers.utils import get_latents_from_seed, get_tile_latents, parse_prompt +from ..diffusers.utils import ( + encode_prompt, + get_latents_from_seed, + get_tile_latents, + parse_prompt, +) from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams @@ -39,7 +44,9 @@ class UpscaleOutpaintStage(BaseStage): callback: Optional[ProgressCallback] = None, **kwargs, ) -> List[Image.Image]: - _prompt_pairs, loras, inversions = parse_prompt(params) + prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( + params + ) pipe_type = params.get_valid_pipeline("inpaint", params.pipeline) pipe = load_pipeline( @@ -108,27 +115,33 @@ class UpscaleOutpaintStage(BaseStage): result = pipe.inpaint( tile_source, tile_mask, - params.prompt, + prompt, generator=rng, guidance_scale=params.cfg, height=size.height, latents=latents, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_inference_steps=params.steps, width=size.width, callback=callback, ) else: + # encode and record alternative prompts outside of LPW + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) + rng = np.random.RandomState(params.seed) result = pipe( - params.prompt, + prompt, tile_source, tile_mask, height=size.height, width=size.width, num_inference_steps=params.steps, guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, generator=rng, latents=latents, callback=callback, diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index f86e2a5c..b2fd78dc 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -35,7 +35,9 @@ class UpscaleStableDiffusionStage(BaseStage): "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt ) - prompt_pairs, _loras, _inversions = parse_prompt(params) + prompt_pairs, _loras, _inversions, (prompt, negative_prompt) = parse_prompt( + params + ) pipeline = load_pipeline( server, @@ -57,11 +59,11 @@ class UpscaleStableDiffusionStage(BaseStage): outputs = [] for source in sources: result = pipeline( - params.prompt, + prompt, source, generator=generator, guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, + negative_prompt=negative_prompt, num_inference_steps=params.steps, eta=params.eta, noise_level=upscale.denoise, diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 72c2e3a8..abb3dec4 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -82,7 +82,7 @@ def run_txt2img_pipeline( progress = job.get_progress_callback() images = chain(job, server, params, [], callback=progress) - _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True) + _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): dest = save_image( @@ -178,7 +178,7 @@ def run_img2img_pipeline( images.append(source) # save with metadata - _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True) + _pairs, loras, inversions, _rest = parse_prompt(params) size = Size(*source.size) for image, output in zip(images, outputs): @@ -264,7 +264,7 @@ def run_inpaint_pipeline( progress = job.get_progress_callback() images = chain(job, server, params, [source], callback=progress) - _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True) + _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): dest = save_image( server, @@ -332,7 +332,7 @@ def run_upscale_pipeline( progress = job.get_progress_callback() images = chain(job, server, params, [source], callback=progress) - _prompt_pairs, loras, inversions = parse_prompt(params, use_input=True) + _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): dest = save_image( server, diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 615c7839..d80b138e 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -312,12 +312,17 @@ def get_scaled_latents( def parse_prompt( params: ImageParams, use_input: bool = False, -) -> Tuple[List[Tuple[str, str]], List[Tuple[str, float]], List[Tuple[str, float]]]: +) -> Tuple[ + List[Tuple[str, str]], + List[Tuple[str, float]], + List[Tuple[str, float]], + Tuple[str, str], +]: prompt, loras = get_loras_from_prompt( params.input_prompt if use_input else params.prompt ) prompt, inversions = get_inversions_from_prompt(prompt) - params.prompt = prompt + # params.prompt = prompt neg_prompt = None if params.input_negative_prompt is not None: @@ -325,9 +330,8 @@ def parse_prompt( params.input_negative_prompt if use_input else params.negative_prompt ) neg_prompt, neg_inversions = get_inversions_from_prompt(neg_prompt) - params.negative_prompt = neg_prompt + # params.negative_prompt = neg_prompt - # TODO: check whether these need to be * -1 loras.extend(neg_loras) inversions.extend(neg_inversions) @@ -352,7 +356,7 @@ def parse_prompt( for i in range(neg_prompt_count, prompt_count): neg_prompts.append(neg_prompts[i % neg_prompt_count]) - return list(zip(prompts, neg_prompts)), loras, inversions + return list(zip(prompts, neg_prompts)), loras, inversions, (prompt, neg_prompt) def encode_prompt( @@ -372,7 +376,7 @@ def encode_prompt( ] -def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> str: +def parse_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) -> str: next_match = WILDCARD_TOKEN.search(prompt) remaining_prompt = prompt @@ -400,6 +404,14 @@ def replace_wildcards(prompt: str, seed: int, wildcards: Dict[str, List[str]]) - return remaining_prompt +def replace_wildcards(params: ImageParams, wildcards: Dict[str, List[str]]): + params.prompt = parse_wildcards(params.prompt, params.seed, wildcards) + if params.negative_prompt is not None: + params.negative_prompt = parse_wildcards( + params.negative_prompt, params.seed, wildcards + ) + + def pop_random(list: List[str]) -> str: """ From https://stackoverflow.com/a/14088129 diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index aa9d5a22..3132b461 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -179,11 +179,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): get_config_value("strength", "min"), ) - params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data()) - if params.negative_prompt is not None: - params.negative_prompt = replace_wildcards( - params.negative_prompt, params.seed, get_wildcard_data() - ) + replace_wildcards(params, get_wildcard_data()) output_count = params.batch if source_filter is not None and source_filter != "none": @@ -221,11 +217,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor): upscale = upscale_from_request() highres = highres_from_request() - params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data()) - if params.negative_prompt is not None: - params.negative_prompt = replace_wildcards( - params.negative_prompt, params.seed, get_wildcard_data() - ) + replace_wildcards(params, get_wildcard_data()) output = make_output_name(server, "txt2img", params, size) @@ -271,11 +263,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] ) - params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data()) - if params.negative_prompt is not None: - params.negative_prompt = replace_wildcards( - params.negative_prompt, params.seed, get_wildcard_data() - ) + replace_wildcards(params, get_wildcard_data()) output = make_output_name( server, @@ -334,11 +322,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): upscale = upscale_from_request() highres = highres_from_request() - params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data()) - if params.negative_prompt is not None: - params.negative_prompt = replace_wildcards( - params.negative_prompt, params.seed, get_wildcard_data() - ) + replace_wildcards(params, get_wildcard_data()) output = make_output_name(server, "upscale", params, size) @@ -380,11 +364,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): output = make_output_name(server, "chain", params, size) job_name = output[0] - params.prompt = replace_wildcards(params.prompt, params.seed, get_wildcard_data()) - if params.negative_prompt is not None: - params.negative_prompt = replace_wildcards( - params.negative_prompt, params.seed, get_wildcard_data() - ) + replace_wildcards(params, get_wildcard_data()) pipeline = ChainPipeline() for stage_data in data.get("stages", []):