fix(api): switch pipeline ctor based on VAE presence, improve panorama logging
This commit is contained in:
parent
ae34e466ef
commit
92311281df
|
@ -229,10 +229,9 @@ def load_pipeline(
|
||||||
tokenizer_2=components.get("tokenizer_2", None),
|
tokenizer_2=components.get("tokenizer_2", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("assembling SD pipeline for %s", pipeline_class.__name__)
|
if "vae" in components:
|
||||||
|
|
||||||
if pipeline_class == OnnxStableDiffusionUpscalePipeline:
|
|
||||||
# upscale uses a single VAE
|
# upscale uses a single VAE
|
||||||
|
logger.debug("assembling SD pipeline for %s with single VAE", pipeline_class.__name__)
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
components["vae"],
|
components["vae"],
|
||||||
components["text_encoder"],
|
components["text_encoder"],
|
||||||
|
@ -242,6 +241,7 @@ def load_pipeline(
|
||||||
scheduler,
|
scheduler,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
logger.debug("assembling SD pipeline for %s with VAE codec", pipeline_class.__name__)
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
components["vae_encoder"],
|
components["vae_encoder"],
|
||||||
components["vae_decoder"],
|
components["vae_decoder"],
|
||||||
|
|
|
@ -563,6 +563,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
|
|
||||||
# panorama additions
|
# panorama additions
|
||||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
count = np.zeros(resize_latent_shape(latents, resize))
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
value = np.zeros(resize_latent_shape(latents, resize))
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
@ -977,6 +979,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
|
|
||||||
# panorama additions
|
# panorama additions
|
||||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
count = np.zeros(resize_latent_shape(latents, resize))
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
value = np.zeros(resize_latent_shape(latents, resize))
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
@ -1298,6 +1302,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
|
|
||||||
# panorama additions
|
# panorama additions
|
||||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
count = np.zeros(resize_latent_shape(latents, resize))
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
value = np.zeros(resize_latent_shape(latents, resize))
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
|
|
@ -394,6 +394,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
# 8. Panorama additions
|
# 8. Panorama additions
|
||||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
count = np.zeros(resize_latent_shape(latents, resize))
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
value = np.zeros(resize_latent_shape(latents, resize))
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
@ -819,6 +821,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
|
|
||||||
# 8. Panorama additions
|
# 8. Panorama additions
|
||||||
views, resize = self.get_views(height, width, self.window, self.stride)
|
views, resize = self.get_views(height, width, self.window, self.stride)
|
||||||
|
logger.trace("panorama resized latents to %s", resize)
|
||||||
|
|
||||||
count = np.zeros(resize_latent_shape(latents, resize))
|
count = np.zeros(resize_latent_shape(latents, resize))
|
||||||
value = np.zeros(resize_latent_shape(latents, resize))
|
value = np.zeros(resize_latent_shape(latents, resize))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue