feat(api): add img2img mode to panorama pipeline
This commit is contained in:
parent
29c616d99e
commit
47d80b07b3
|
@ -15,13 +15,14 @@
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
|
import PIL
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import deprecate, logging
|
from diffusers.utils import deprecate, logging, PIL_INTERPOLATION
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
@ -30,6 +31,27 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(image):
    """Convert an img2img input into a batched torch tensor.

    A ``torch.Tensor`` is handed back untouched. A single PIL image is treated
    as a one-element list. A list of PIL images is resized (to the first
    image's size rounded down to a multiple of 64), stacked along a new batch
    axis, rescaled from ``[0, 255]`` to ``[-1, 1]`` and moved to channels-first
    layout. A list of tensors is concatenated along dim 0. Any other input is
    returned as received.
    """
    # tensors are assumed to be ready for the VAE encoder already
    if isinstance(image, torch.Tensor):
        return image

    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # resize to integer multiple of 64
        width -= width % 64
        height -= height % 64

        frames = []
        for picture in image:
            resized = picture.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
            frames.append(np.array(resized)[None, :])

        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        # batch-height-width-channel -> batch-channel-height-width
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        image = torch.from_numpy(batch)
    elif isinstance(first, torch.Tensor):
        image = torch.cat(image, dim=0)

    return image
|
||||||
|
|
||||||
|
|
||||||
class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
vae_encoder: OnnxRuntimeModel
|
vae_encoder: OnnxRuntimeModel
|
||||||
vae_decoder: OnnxRuntimeModel
|
vae_decoder: OnnxRuntimeModel
|
||||||
|
@ -491,3 +513,236 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
return (image, has_nsfw_concept)
|
return (image, has_nsfw_concept)
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||||
|
|
||||||
|
def img2img(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[np.ndarray, PIL.Image.Image] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[np.random.RandomState] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Run image-to-image generation over overlapping panorama views.

    The source image is encoded into latents, noised up to the timestep selected by `strength`, and then
    denoised view by view; the per-view results are averaged where views overlap.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process. It is passed through `preprocess` before being encoded.
        strength (`float`, *optional*, defaults to 0.8):
            How much noise to add to the encoded image, between 0 and 1. The number of denoising steps that are
            actually run is `int(num_inference_steps * strength)` plus the scheduler's `steps_offset`, capped
            at `num_inference_steps`. With a resulting count of 0 the image is neither noised nor denoised.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only passed to
            schedulers whose `step` method accepts an `eta` argument.
        generator (`np.random.RandomState`, *optional*):
            Source of the noise added to the encoded image. Falls back to the global `np.random` module when
            not provided.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` converts the result to `PIL.Image.Image`; any
            other value returns the `np.ndarray` batch.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step, timestep, latents)`, where `latents` is the
            merged `np.ndarray` for the whole panorama.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Raises:
        ValueError: If `strength` is outside of `[0.0, 1.0]`.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is the generated images, and the second element is a list of
        `bool`s with the per-image result of the `safety_checker`, or `None` when no safety checker is set.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

    # prep image first, so that height and width describe what is actually encoded: `preprocess` may
    # resize PIL inputs, and list or tensor inputs have no `.height`/`.width` attributes
    image = preprocess(image).cpu().numpy()
    height, width = image.shape[-2:]

    # check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if generator is None:
        generator = np.random

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # the latents share the dtype of the text embeddings
    latents_dtype = prompt_embeds.dtype

    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps)

    # encode the init image into latents and scale the latents (undone by 1 / 0.18215 before decoding)
    image = image.astype(latents_dtype)
    latents = self.vae_encoder(sample=image)[0]
    latents = 0.18215 * latents
    # NOTE: the latents are deliberately not multiplied by `scheduler.init_noise_sigma` here. They are an
    # encoded image rather than pure noise, and `add_noise` below applies the noise level for the timestep.

    # get the original timestep using init_timestep
    offset = self.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    if init_timestep > 0:
        # a negative index of 0 would select the first (noisiest) timestep, hence the guard above
        noise_timestep = self.scheduler.timesteps.numpy()[-init_timestep]
        noise_timesteps = np.array([noise_timestep] * batch_size * num_images_per_prompt)

        noise = generator.randn(*latents.shape).astype(latents_dtype)
        latents = self.scheduler.add_noise(
            torch.from_numpy(latents), torch.from_numpy(noise), torch.from_numpy(noise_timesteps)
        )
        latents = latents.numpy()

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    timestep_dtype = next(
        (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    # panorama additions: accumulators are allocated after noising so they match the final latent shape
    views = self.get_views(height, width)
    count = np.zeros_like(latents)
    value = np.zeros_like(latents)

    # only run the tail of the schedule, starting from the timestep the noise was added for
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    denoise_timesteps = self.scheduler.timesteps[t_start:]

    for i, t in enumerate(self.progress_bar(denoise_timesteps)):
        count.fill(0)
        value.fill(0)

        for h_start, h_end, w_start, w_end in views:
            # get the latents corresponding to the current view coordinates
            latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
            )
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.cpu().numpy()

            # predict the noise residual
            timestep = np.array([t], dtype=timestep_dtype)
            noise_pred = self.unet(
                sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds
            )
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
            )
            latents_view_denoised = scheduler_output.prev_sample.numpy()

            value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
            count[:, :, h_start:h_end, w_start:w_end] += 1

        # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
        latents = np.where(count > 0, value / count, value)

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    latents = 1 / 0.18215 * latents
    # image = self.vae_decoder(latent_sample=latents)[0]
    # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
    image = np.concatenate(
        [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
    )

    # map the decoder output from [-1, 1] to [0, 1] and move channels last
    image = np.clip(image / 2 + 0.5, 0, 1)
    image = image.transpose((0, 2, 3, 1))

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)

        # the safety checker is run one image at a time
        images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||||
|
|
|
@ -47,7 +47,9 @@ def run_loopback(
|
||||||
loopback_progress = ChainProgress.from_progress(progress)
|
loopback_progress = ChainProgress.from_progress(progress)
|
||||||
|
|
||||||
# load img2img pipeline once
|
# load img2img pipeline once
|
||||||
pipe_type = "lpw" if params.lpw() else "img2img"
|
pipe_type = params.get_valid_pipeline("img2img")
|
||||||
|
logger.debug("using %s pipeline for loopback", pipe_type)
|
||||||
|
|
||||||
pipe = pipeline or load_pipeline(
|
pipe = pipeline or load_pipeline(
|
||||||
server,
|
server,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
|
@ -59,8 +61,7 @@ def run_loopback(
|
||||||
)
|
)
|
||||||
|
|
||||||
def loopback_iteration(source: Image.Image):
|
def loopback_iteration(source: Image.Image):
|
||||||
if params.lpw():
|
if pipe_type in ["lpw", "panorama"]:
|
||||||
logger.debug("using LPW pipeline for loopback")
|
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = pipe.img2img(
|
result = pipe.img2img(
|
||||||
source,
|
source,
|
||||||
|
@ -76,7 +77,6 @@ def run_loopback(
|
||||||
)
|
)
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
else:
|
else:
|
||||||
logger.debug("using img2img pipeline for loopback")
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = pipe(
|
result = pipe(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
|
@ -134,7 +134,9 @@ def run_highres(
|
||||||
)
|
)
|
||||||
|
|
||||||
# load img2img pipeline once
|
# load img2img pipeline once
|
||||||
pipe_type = "lpw" if params.lpw() else "img2img"
|
pipe_type = params.get_valid_pipeline("img2img")
|
||||||
|
logger.debug("using %s pipeline for highres", pipe_type)
|
||||||
|
|
||||||
highres_pipe = pipeline or load_pipeline(
|
highres_pipe = pipeline or load_pipeline(
|
||||||
server,
|
server,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
|
@ -172,8 +174,7 @@ def run_highres(
|
||||||
callback=highres_progress,
|
callback=highres_progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.lpw():
|
if pipe_type in ["lpw", "panorama"]:
|
||||||
logger.debug("using LPW pipeline for highres")
|
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = highres_pipe.img2img(
|
result = highres_pipe.img2img(
|
||||||
tile,
|
tile,
|
||||||
|
@ -189,7 +190,6 @@ def run_highres(
|
||||||
)
|
)
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
else:
|
else:
|
||||||
logger.debug("using img2img pipeline for highres")
|
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
result = highres_pipe(
|
result = highres_pipe(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
|
@ -236,7 +236,9 @@ def run_txt2img_pipeline(
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
prompt_pairs, loras, inversions = parse_prompt(params)
|
prompt_pairs, loras, inversions = parse_prompt(params)
|
||||||
|
|
||||||
pipe_type = params.pipeline # TODO: allow txt2img, panorama, etc, but filter out the others
|
pipe_type = params.get_valid_pipeline("txt2img")
|
||||||
|
logger.debug("using %s pipeline for txt2img", pipe_type)
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
|
@ -248,8 +250,7 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
|
|
||||||
if params.lpw():
|
if pipe_type in ["lpw", "panorama"]:
|
||||||
logger.debug("using LPW pipeline for txt2img")
|
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = pipe.text2img(
|
result = pipe.text2img(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
|
@ -345,9 +346,10 @@ def run_img2img_pipeline(
|
||||||
logger.debug("running source filter: %s", f.__name__)
|
logger.debug("running source filter: %s", f.__name__)
|
||||||
source = f(server, source)
|
source = f(server, source)
|
||||||
|
|
||||||
|
pipe_type = params.get_valid_pipeline("img2img")
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
params.pipeline, # this is one of the only places this can actually vary between different pipelines
|
pipe_type,
|
||||||
params.model,
|
params.model,
|
||||||
params.scheduler,
|
params.scheduler,
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
|
@ -357,15 +359,17 @@ def run_img2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe_params = {}
|
pipe_params = {}
|
||||||
if params.pipeline == "controlnet":
|
if pipe_type == "controlnet":
|
||||||
pipe_params["controlnet_conditioning_scale"] = strength
|
pipe_params["controlnet_conditioning_scale"] = strength
|
||||||
elif params.pipeline == "img2img":
|
elif pipe_type == "img2img":
|
||||||
pipe_params["strength"] = strength
|
pipe_params["strength"] = strength
|
||||||
elif params.pipeline == "pix2pix":
|
elif pipe_type == "panorama":
|
||||||
|
pipe_params["strength"] = strength
|
||||||
|
elif pipe_type == "pix2pix":
|
||||||
pipe_params["image_guidance_scale"] = strength
|
pipe_params["image_guidance_scale"] = strength
|
||||||
|
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
if params.lpw():
|
if pipe_type in ["lpw", "panorama"]:
|
||||||
logger.debug("using LPW pipeline for img2img")
|
logger.debug("using LPW pipeline for img2img")
|
||||||
rng = torch.manual_seed(params.seed)
|
rng = torch.manual_seed(params.seed)
|
||||||
result = pipe.img2img(
|
result = pipe.img2img(
|
||||||
|
|
|
@ -220,6 +220,27 @@ class ImageParams:
|
||||||
def do_cfg(self):
|
def do_cfg(self):
|
||||||
return self.cfg > 1.0
|
return self.cfg > 1.0
|
||||||
|
|
||||||
|
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
    """Pick the pipeline name to load for a request handled by `group`.

    The candidate is the `pipeline` argument, or `self.pipeline` when the
    argument is empty. It is returned when it equals `group` or is one of the
    extra pipelines listed for that group; otherwise a debug message is logged
    and `group` itself is returned as the fallback.
    """
    requested = pipeline or self.pipeline

    # additional pipelines accepted by each group, besides the group's own name
    extras_by_group = {
        "img2img": ("controlnet", "lpw", "panorama", "pix2pix"),
        "inpaint": ("controlnet",),
        "txt2img": ("panorama",),
    }

    if group == requested or requested in extras_by_group.get(group, ()):
        return requested

    logger.debug("pipeline %s is not valid for %s", requested, group)
    return group
|
||||||
|
|
||||||
def lpw(self):
|
def lpw(self):
|
||||||
return self.pipeline == "lpw"
|
return self.pipeline == "lpw"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue