fix region parsing
This commit is contained in:
parent 8498252c75
commit baecb38343
|
@@ -284,6 +284,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# corresponds to doing no classifier free guidance.
|
# corresponds to doing no classifier free guidance.
|
||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
prompt, regions = parse_regions(prompt)
|
||||||
|
|
||||||
# 3. Encode input prompt
|
# 3. Encode input prompt
|
||||||
(
|
(
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
|
@@ -302,7 +304,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3.b. Encode region prompts
|
# 3.b. Encode region prompts
|
||||||
regions = parse_regions(prompt)
|
|
||||||
region_embeds: List[
|
region_embeds: List[
|
||||||
Tuple[
|
Tuple[
|
||||||
List[np.ndarray],
|
List[np.ndarray],
|
||||||
|
@@ -314,7 +315,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
add_region_embeds: List[np.ndarray] = []
|
add_region_embeds: List[np.ndarray] = []
|
||||||
|
|
||||||
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
||||||
current_region_embeds = self._encode_prompt(
|
(
|
||||||
|
region_prompt_embeds,
|
||||||
|
region_negative_prompt_embeds,
|
||||||
|
region_pooled_prompt_embeds,
|
||||||
|
region_negative_pooled_prompt_embeds,
|
||||||
|
) = self._encode_prompt(
|
||||||
region_prompt,
|
region_prompt,
|
||||||
num_images_per_prompt,
|
num_images_per_prompt,
|
||||||
do_classifier_free_guidance,
|
do_classifier_free_guidance,
|
||||||
|
@@ -326,16 +332,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
current_region_embeds[0] = np.concatenate(
|
region_prompt_embeds = np.concatenate(
|
||||||
(current_region_embeds[1], current_region_embeds[0]), axis=0
|
(region_negative_prompt_embeds, region_prompt_embeds), axis=0
|
||||||
)
|
)
|
||||||
add_region_embeds.append(
|
add_region_embeds.append(
|
||||||
np.concatenate(
|
np.concatenate(
|
||||||
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
(region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds), axis=0
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
region_embeds.append(current_region_embeds)
|
region_embeds.append(region_prompt_embeds)
|
||||||
|
|
||||||
# 4. Prepare timesteps
|
# 4. Prepare timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
@@ -441,7 +447,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
for i in range(len(regions)):
|
for i in range(len(regions)):
|
||||||
top, left, bottom, right, mode, prompt = regions[i]
|
top, left, bottom, right, mode, prompt = regions[i]
|
||||||
print("running region prompt", top, left, bottom, right, mode, prompt)
|
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mode, prompt)
|
||||||
|
|
||||||
# convert coordinates to latent space
|
# convert coordinates to latent space
|
||||||
h_start = top // 8
|
h_start = top // 8
|
||||||
|
@@ -464,8 +470,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
latent_model_input = latent_model_input.cpu().numpy()
|
latent_model_input = latent_model_input.cpu().numpy()
|
||||||
|
|
||||||
# fetch region embeds
|
# fetch region embeds
|
||||||
region_1 = region_embeds[i][0]
|
region_1 = region_embeds[i]
|
||||||
region_2 = add_region_embeds[i]
|
region_2 = add_region_embeds[i]
|
||||||
|
logger.debug("region embeds shape: %s, %s", region_1.shape, region_2.shape)
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
|
|
@@ -458,5 +458,5 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
||||||
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
|
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
|
||||||
|
|
||||||
|
|
||||||
def parse_regions(prompt: str) -> List[Region]:
|
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
||||||
return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)
|
return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)
|
||||||
|
|
Loading…
Reference in New Issue