replace previous latents when region multiplier passes threshold
This commit is contained in:
parent
05f63a32b7
commit
2de4eb92b2
|
@ -250,12 +250,12 @@ class ChainPipeline:
|
||||||
stage_sources = stage_outputs
|
stage_sources = stage_outputs
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
|
worker.retries = worker.retries - 1
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error while running stage pipeline, %s retries left", worker.retries
|
"error while running stage pipeline, %s retries left", worker.retries
|
||||||
)
|
)
|
||||||
server.cache.clear()
|
server.cache.clear()
|
||||||
run_gc([worker.get_device()])
|
run_gc([worker.get_device()])
|
||||||
worker.retries = worker.retries - 1
|
|
||||||
|
|
||||||
if worker.retries <= 0:
|
if worker.retries <= 0:
|
||||||
raise RetryException("exhausted retries on stage")
|
raise RetryException("exhausted retries on stage")
|
||||||
|
|
|
@ -262,7 +262,11 @@ def load_pipeline(
|
||||||
|
|
||||||
for vae in VAE_COMPONENTS:
|
for vae in VAE_COMPONENTS:
|
||||||
if hasattr(pipe, vae):
|
if hasattr(pipe, vae):
|
||||||
getattr(pipe, vae).set_tiled(tiled=params.tiled_vae)
|
vae_model = getattr(pipe, vae)
|
||||||
|
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||||
|
vae_model.set_window_size(
|
||||||
|
params.vae_tile // 8, params.vae_overlap
|
||||||
|
)
|
||||||
|
|
||||||
# update panorama params
|
# update panorama params
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
|
@ -276,12 +280,6 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
pipe.set_window_size(params.unet_tile // 8, unet_stride)
|
||||||
|
|
||||||
for vae in VAE_COMPONENTS:
|
|
||||||
if hasattr(pipe, vae):
|
|
||||||
getattr(pipe, vae).set_window_size(
|
|
||||||
params.vae_tile // 8, params.vae_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return pipe
|
return pipe
|
||||||
|
|
|
@ -445,8 +445,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
for i in range(len(regions)):
|
for r in range(len(regions)):
|
||||||
top, left, bottom, right, mult, prompt = regions[i]
|
top, left, bottom, right, mult, prompt = regions[r]
|
||||||
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt)
|
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt)
|
||||||
|
|
||||||
# convert coordinates to latent space
|
# convert coordinates to latent space
|
||||||
|
@ -475,8 +475,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
sample=latent_region_input,
|
sample=latent_region_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=region_embeds[i],
|
encoder_hidden_states=region_embeds[r],
|
||||||
text_embeds=add_region_embeds[i],
|
text_embeds=add_region_embeds[r],
|
||||||
time_ids=add_time_ids,
|
time_ids=add_time_ids,
|
||||||
)
|
)
|
||||||
noise_pred = noise_pred[0]
|
noise_pred = noise_pred[0]
|
||||||
|
@ -504,8 +504,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
|
if mult > 1000.0:
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += mult
|
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] = mult
|
||||||
|
else:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] += mult
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
Loading…
Reference in New Issue