1
0
Fork 0

fix coords and threshold

This commit is contained in:
Sean Sube 2023-11-05 22:19:55 -06:00
parent 2de4eb92b2
commit 997891b255
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 6 additions and 13 deletions

View File

@ -304,14 +304,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
# 3.b. Encode region prompts # 3.b. Encode region prompts
region_embeds: List[ region_embeds: List[np.ndarray] = []
Tuple[
List[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
Optional[np.ndarray],
]
] = []
add_region_embeds: List[np.ndarray] = [] add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mode, region_prompt in regions: for _top, _left, _bottom, _right, _mode, region_prompt in regions:
@ -450,10 +443,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt) logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt)
# convert coordinates to latent space # convert coordinates to latent space
h_start = left // 8 h_start = top // 8
h_end = right // 8 h_end = bottom // 8
w_start = top // 8 w_start = left // 8
w_end = bottom // 8 w_end = right // 8
# get the latents corresponding to the current view coordinates # get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
@ -504,7 +497,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult > 1000.0: if mult >= 1000.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult
count[:, :, h_start:h_end, w_start:w_end] = mult count[:, :, h_start:h_end, w_start:w_end] = mult
else: else: