1
0
Fork 0

fix region parsing

This commit is contained in:
Sean Sube 2023-11-05 16:25:48 -06:00
parent 8498252c75
commit baecb38343
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 9 deletions

View File

@@ -284,6 +284,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
prompt, regions = parse_regions(prompt)
# 3. Encode input prompt # 3. Encode input prompt
( (
prompt_embeds, prompt_embeds,
@@ -302,7 +304,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
# 3.b. Encode region prompts # 3.b. Encode region prompts
regions = parse_regions(prompt)
region_embeds: List[ region_embeds: List[
Tuple[ Tuple[
List[np.ndarray], List[np.ndarray],
@@ -314,7 +315,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
add_region_embeds: List[np.ndarray] = [] add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mode, region_prompt in regions: for _top, _left, _bottom, _right, _mode, region_prompt in regions:
current_region_embeds = self._encode_prompt( (
region_prompt_embeds,
region_negative_prompt_embeds,
region_pooled_prompt_embeds,
region_negative_pooled_prompt_embeds,
) = self._encode_prompt(
region_prompt, region_prompt,
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, do_classifier_free_guidance,
@@ -326,16 +332,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
if do_classifier_free_guidance: if do_classifier_free_guidance:
current_region_embeds[0] = np.concatenate( region_prompt_embeds = np.concatenate(
(current_region_embeds[1], current_region_embeds[0]), axis=0 (region_negative_prompt_embeds, region_prompt_embeds), axis=0
) )
add_region_embeds.append( add_region_embeds.append(
np.concatenate( np.concatenate(
(current_region_embeds[3], current_region_embeds[2]), axis=0 (region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds), axis=0
) )
) )
region_embeds.append(current_region_embeds) region_embeds.append(region_prompt_embeds)
# 4. Prepare timesteps # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
@@ -441,7 +447,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
for i in range(len(regions)): for i in range(len(regions)):
top, left, bottom, right, mode, prompt = regions[i] top, left, bottom, right, mode, prompt = regions[i]
print("running region prompt", top, left, bottom, right, mode, prompt) logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mode, prompt)
# convert coordinates to latent space # convert coordinates to latent space
h_start = top // 8 h_start = top // 8
@@ -464,8 +470,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
latent_model_input = latent_model_input.cpu().numpy() latent_model_input = latent_model_input.cpu().numpy()
# fetch region embeds # fetch region embeds
region_1 = region_embeds[i][0] region_1 = region_embeds[i]
region_2 = add_region_embeds[i] region_2 = add_region_embeds[i]
logger.debug("region embeds shape: %s, %s", region_1.shape, region_2.shape)
# predict the noise residual # predict the noise residual
timestep = np.array([t], dtype=timestep_dtype) timestep = np.array([t], dtype=timestep_dtype)

View File

@@ -458,5 +458,5 @@ def slice_prompt(prompt: str, slice: int) -> str:
Region = Tuple[int, int, int, int, Literal["add", "replace"], str] Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
def parse_regions(prompt: str) -> List[Region]: def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it) return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)