1
0
Fork 0

replace previous latents when region multiplier passes threshold

This commit is contained in:
Sean Sube 2023-11-05 21:41:40 -06:00
parent 05f63a32b7
commit 2de4eb92b2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 16 additions and 14 deletions

View File

@ -250,12 +250,12 @@ class ChainPipeline:
stage_sources = stage_outputs stage_sources = stage_outputs
break break
except Exception: except Exception:
worker.retries = worker.retries - 1
logger.exception( logger.exception(
"error while running stage pipeline, %s retries left", worker.retries "error while running stage pipeline, %s retries left", worker.retries
) )
server.cache.clear() server.cache.clear()
run_gc([worker.get_device()]) run_gc([worker.get_device()])
worker.retries = worker.retries - 1
if worker.retries <= 0: if worker.retries <= 0:
raise RetryException("exhausted retries on stage") raise RetryException("exhausted retries on stage")

View File

@ -262,7 +262,11 @@ def load_pipeline(
for vae in VAE_COMPONENTS: for vae in VAE_COMPONENTS:
if hasattr(pipe, vae): if hasattr(pipe, vae):
getattr(pipe, vae).set_tiled(tiled=params.tiled_vae) vae_model = getattr(pipe, vae)
vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(
params.vae_tile // 8, params.vae_overlap
)
# update panorama params # update panorama params
if params.is_panorama(): if params.is_panorama():
@ -276,12 +280,6 @@ def load_pipeline(
) )
pipe.set_window_size(params.unet_tile // 8, unet_stride) pipe.set_window_size(params.unet_tile // 8, unet_stride)
for vae in VAE_COMPONENTS:
if hasattr(pipe, vae):
getattr(pipe, vae).set_window_size(
params.vae_tile // 8, params.vae_overlap
)
run_gc([device]) run_gc([device])
return pipe return pipe

View File

@ -445,8 +445,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
for i in range(len(regions)): for r in range(len(regions)):
top, left, bottom, right, mult, prompt = regions[i] top, left, bottom, right, mult, prompt = regions[r]
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt) logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt)
# convert coordinates to latent space # convert coordinates to latent space
@ -475,8 +475,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
noise_pred = self.unet( noise_pred = self.unet(
sample=latent_region_input, sample=latent_region_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=region_embeds[i], encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[i], text_embeds=add_region_embeds[r],
time_ids=add_time_ids, time_ids=add_time_ids,
) )
noise_pred = noise_pred[0] noise_pred = noise_pred[0]
@ -504,6 +504,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult > 1000.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult
count[:, :, h_start:h_end, w_start:w_end] = mult
else:
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
count[:, :, h_start:h_end, w_start:w_end] += mult count[:, :, h_start:h_end, w_start:w_end] += mult