1
0
Fork 0

fix(api): convert back to model format after blending, convert samples as needed (#274)

This commit is contained in:
Sean Sube 2023-03-21 22:05:14 -05:00
parent 0315a8cbc6
commit c8aad8554e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 22 additions and 28 deletions

View File

@ -97,8 +97,8 @@ def convert_diffusion_diffusers(
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
torch_dtype = ctx.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
dtype = ctx.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
dest_path = path.join(ctx.model_path, name)
model_index = path.join(dest_path, "model_index.json")
@ -117,7 +117,7 @@ def convert_diffusion_diffusers(
pipeline = StableDiffusionPipeline.from_pretrained(
source,
torch_dtype=torch_dtype,
torch_dtype=dtype,
use_auth_token=ctx.token,
).to(ctx.training_device)
output_path = Path(dest_path)
@ -174,11 +174,11 @@ def convert_diffusion_diffusers(
pipeline.unet,
model_args=(
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
device=ctx.training_device, dtype=torch_dtype
device=ctx.training_device, dtype=dtype
),
torch.randn(2).to(device=ctx.training_device, dtype=torch_dtype),
torch.randn(2).to(device=ctx.training_device, dtype=dtype),
torch.randn(2, num_tokens, text_hidden_size).to(
device=ctx.training_device, dtype=torch_dtype
device=ctx.training_device, dtype=dtype
),
unet_scale,
),
@ -230,7 +230,7 @@ def convert_diffusion_diffusers(
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=ctx.training_device, dtype=torch_dtype),
).to(device=ctx.training_device, dtype=dtype),
False,
),
output_path=output_path / "vae" / "model.onnx",
@ -255,7 +255,7 @@ def convert_diffusion_diffusers(
vae_encoder,
model_args=(
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
device=ctx.training_device, dtype=torch_dtype
device=ctx.training_device, dtype=dtype
),
False,
),
@ -279,7 +279,7 @@ def convert_diffusion_diffusers(
model_args=(
torch.randn(
1, vae_latent_channels, unet_sample_size, unet_sample_size
).to(device=ctx.training_device, dtype=torch_dtype),
).to(device=ctx.training_device, dtype=dtype),
False,
),
output_path=output_path / "vae_decoder" / "model.onnx",

View File

@ -62,7 +62,7 @@ def blend_loras(
):
# always load to CPU for blending
device = torch.device("cpu")
dtype = context.torch_dtype()
dtype = torch.float32
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

View File

@ -22,7 +22,7 @@ def blend_textual_inversions(
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = context.numpy_dtype()
dtype = np.float32
embeds = {}
for name, weight, base_token, inversion_format in inversions:
@ -131,11 +131,11 @@ def blend_textual_inversions(
for n in text_encoder.graph.initializer
if n.name == "text_model.embeddings.token_embedding.weight"
][0]
embedding_weights = numpy_helper.to_array(embedding_node)
base_weights = numpy_helper.to_array(embedding_node)
weights_dim = embedding_weights.shape[1]
weights_dim = base_weights.shape[1]
zero_weights = np.zeros((num_added_tokens, weights_dim))
embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
for token, weights in embeds.items():
token_id = tokenizer.convert_tokens_to_ids(token)
@ -149,7 +149,7 @@ def blend_textual_inversions(
== "text_model.embeddings.token_embedding.weight"
):
new_initializer = numpy_helper.from_array(
embedding_weights.astype(dtype), embedding_node.name
embedding_weights.astype(base_weights.dtype), embedding_node.name
)
logger.trace("new initializer data type: %s", new_initializer.data_type)
del text_encoder.graph.initializer[i]

View File

@ -360,15 +360,15 @@ class UNetWrapper(object):
global timestep_dtype
timestep_dtype = timestep.dtype
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
if "onnx-fp16" in self.server.optimizations:
logger.info("converting UNet sample to ONNX fp16")
sample = sample.astype(np.float16)
encoder_hidden_states = encoder_hidden_states.astype(np.float16)
elif sample.dtype != timestep.dtype:
logger.info("converting UNet sample to timestep dtype")
logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)
if sample.dtype != timestep.dtype:
logger.trace("converting UNet hidden states to timestep dtype")
encoder_hidden_states = encoder_hidden_states.astype(np.float16)
return self.wrapped(
sample=sample,
timestep=timestep,

View File

@ -86,9 +86,3 @@ class ServerContext:
return torch.float16
else:
return torch.float32
def numpy_dtype(self):
if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
return np.float16
else:
return np.float32