1
0
Fork 0

fix region parsing

This commit is contained in:
Sean Sube 2023-11-05 16:25:48 -06:00
parent 8498252c75
commit baecb38343
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 9 deletions

View File

@ -284,6 +284,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
prompt, regions = parse_regions(prompt)
# 3. Encode input prompt
(
prompt_embeds,
@ -302,7 +304,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
)
# 3.b. Encode region prompts
regions = parse_regions(prompt)
region_embeds: List[
Tuple[
List[np.ndarray],
@ -314,7 +315,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
current_region_embeds = self._encode_prompt(
(
region_prompt_embeds,
region_negative_prompt_embeds,
region_pooled_prompt_embeds,
region_negative_pooled_prompt_embeds,
) = self._encode_prompt(
region_prompt,
num_images_per_prompt,
do_classifier_free_guidance,
@ -326,16 +332,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
)
if do_classifier_free_guidance:
current_region_embeds[0] = np.concatenate(
(current_region_embeds[1], current_region_embeds[0]), axis=0
region_prompt_embeds = np.concatenate(
(region_negative_prompt_embeds, region_prompt_embeds), axis=0
)
add_region_embeds.append(
np.concatenate(
(current_region_embeds[3], current_region_embeds[2]), axis=0
(region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds), axis=0
)
)
region_embeds.append(current_region_embeds)
region_embeds.append(region_prompt_embeds)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
@ -441,7 +447,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
for i in range(len(regions)):
top, left, bottom, right, mode, prompt = regions[i]
print("running region prompt", top, left, bottom, right, mode, prompt)
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mode, prompt)
# convert coordinates to latent space
h_start = top // 8
@ -464,8 +470,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
latent_model_input = latent_model_input.cpu().numpy()
# fetch region embeds
region_1 = region_embeds[i][0]
region_1 = region_embeds[i]
region_2 = add_region_embeds[i]
logger.debug("region embeds shape: %s, %s", region_1.shape, region_2.shape)
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)

View File

@ -458,5 +458,5 @@ def slice_prompt(prompt: str, slice: int) -> str:
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
def parse_regions(prompt: str) -> List[Region]:
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)