fix region parsing
This commit is contained in:
parent
8498252c75
commit
baecb38343
|
@ -284,6 +284,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
prompt, regions = parse_regions(prompt)
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
|
@ -302,7 +304,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
)
|
||||
|
||||
# 3.b. Encode region prompts
|
||||
regions = parse_regions(prompt)
|
||||
region_embeds: List[
|
||||
Tuple[
|
||||
List[np.ndarray],
|
||||
|
@ -314,7 +315,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
add_region_embeds: List[np.ndarray] = []
|
||||
|
||||
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
||||
current_region_embeds = self._encode_prompt(
|
||||
(
|
||||
region_prompt_embeds,
|
||||
region_negative_prompt_embeds,
|
||||
region_pooled_prompt_embeds,
|
||||
region_negative_pooled_prompt_embeds,
|
||||
) = self._encode_prompt(
|
||||
region_prompt,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
|
@ -326,16 +332,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
current_region_embeds[0] = np.concatenate(
|
||||
(current_region_embeds[1], current_region_embeds[0]), axis=0
|
||||
region_prompt_embeds = np.concatenate(
|
||||
(region_negative_prompt_embeds, region_prompt_embeds), axis=0
|
||||
)
|
||||
add_region_embeds.append(
|
||||
np.concatenate(
|
||||
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
||||
(region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds), axis=0
|
||||
)
|
||||
)
|
||||
|
||||
region_embeds.append(current_region_embeds)
|
||||
region_embeds.append(region_prompt_embeds)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
@ -441,7 +447,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
|
||||
for i in range(len(regions)):
|
||||
top, left, bottom, right, mode, prompt = regions[i]
|
||||
print("running region prompt", top, left, bottom, right, mode, prompt)
|
||||
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mode, prompt)
|
||||
|
||||
# convert coordinates to latent space
|
||||
h_start = top // 8
|
||||
|
@ -464,8 +470,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
# fetch region embeds
|
||||
region_1 = region_embeds[i][0]
|
||||
region_1 = region_embeds[i]
|
||||
region_2 = add_region_embeds[i]
|
||||
logger.debug("region embeds shape: %s, %s", region_1.shape, region_2.shape)
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
|
|
|
@ -458,5 +458,5 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
|||
Region = Tuple[int, int, int, int, Literal["add", "replace"], str]
|
||||
|
||||
|
||||
def parse_regions(prompt: str) -> List[Region]:
|
||||
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
||||
return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)
|
||||
|
|
Loading…
Reference in New Issue