fix(api): convert back to model format after blending, convert samples as needed (#274)
This commit is contained in:
parent
0315a8cbc6
commit
c8aad8554e
|
@ -97,8 +97,8 @@ def convert_diffusion_diffusers(
|
||||||
single_vae = model.get("single_vae")
|
single_vae = model.get("single_vae")
|
||||||
replace_vae = model.get("vae")
|
replace_vae = model.get("vae")
|
||||||
|
|
||||||
torch_dtype = ctx.torch_dtype()
|
dtype = ctx.torch_dtype()
|
||||||
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||||
|
|
||||||
dest_path = path.join(ctx.model_path, name)
|
dest_path = path.join(ctx.model_path, name)
|
||||||
model_index = path.join(dest_path, "model_index.json")
|
model_index = path.join(dest_path, "model_index.json")
|
||||||
|
@ -117,7 +117,7 @@ def convert_diffusion_diffusers(
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
source,
|
source,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=dtype,
|
||||||
use_auth_token=ctx.token,
|
use_auth_token=ctx.token,
|
||||||
).to(ctx.training_device)
|
).to(ctx.training_device)
|
||||||
output_path = Path(dest_path)
|
output_path = Path(dest_path)
|
||||||
|
@ -174,11 +174,11 @@ def convert_diffusion_diffusers(
|
||||||
pipeline.unet,
|
pipeline.unet,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||||
device=ctx.training_device, dtype=torch_dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
torch.randn(2).to(device=ctx.training_device, dtype=torch_dtype),
|
torch.randn(2).to(device=ctx.training_device, dtype=dtype),
|
||||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||||
device=ctx.training_device, dtype=torch_dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
unet_scale,
|
unet_scale,
|
||||||
),
|
),
|
||||||
|
@ -230,7 +230,7 @@ def convert_diffusion_diffusers(
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(
|
torch.randn(
|
||||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
).to(device=ctx.training_device, dtype=torch_dtype),
|
).to(device=ctx.training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae" / "model.onnx",
|
output_path=output_path / "vae" / "model.onnx",
|
||||||
|
@ -255,7 +255,7 @@ def convert_diffusion_diffusers(
|
||||||
vae_encoder,
|
vae_encoder,
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||||
device=ctx.training_device, dtype=torch_dtype
|
device=ctx.training_device, dtype=dtype
|
||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
@ -279,7 +279,7 @@ def convert_diffusion_diffusers(
|
||||||
model_args=(
|
model_args=(
|
||||||
torch.randn(
|
torch.randn(
|
||||||
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
1, vae_latent_channels, unet_sample_size, unet_sample_size
|
||||||
).to(device=ctx.training_device, dtype=torch_dtype),
|
).to(device=ctx.training_device, dtype=dtype),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||||
|
|
|
@ -62,7 +62,7 @@ def blend_loras(
|
||||||
):
|
):
|
||||||
# always load to CPU for blending
|
# always load to CPU for blending
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = context.torch_dtype()
|
dtype = torch.float32
|
||||||
|
|
||||||
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
||||||
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
||||||
|
|
|
@ -22,7 +22,7 @@ def blend_textual_inversions(
|
||||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||||
# always load to CPU for blending
|
# always load to CPU for blending
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = context.numpy_dtype()
|
dtype = np.float32
|
||||||
embeds = {}
|
embeds = {}
|
||||||
|
|
||||||
for name, weight, base_token, inversion_format in inversions:
|
for name, weight, base_token, inversion_format in inversions:
|
||||||
|
@ -131,11 +131,11 @@ def blend_textual_inversions(
|
||||||
for n in text_encoder.graph.initializer
|
for n in text_encoder.graph.initializer
|
||||||
if n.name == "text_model.embeddings.token_embedding.weight"
|
if n.name == "text_model.embeddings.token_embedding.weight"
|
||||||
][0]
|
][0]
|
||||||
embedding_weights = numpy_helper.to_array(embedding_node)
|
base_weights = numpy_helper.to_array(embedding_node)
|
||||||
|
|
||||||
weights_dim = embedding_weights.shape[1]
|
weights_dim = base_weights.shape[1]
|
||||||
zero_weights = np.zeros((num_added_tokens, weights_dim))
|
zero_weights = np.zeros((num_added_tokens, weights_dim))
|
||||||
embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
|
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
|
||||||
|
|
||||||
for token, weights in embeds.items():
|
for token, weights in embeds.items():
|
||||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||||
|
@ -149,7 +149,7 @@ def blend_textual_inversions(
|
||||||
== "text_model.embeddings.token_embedding.weight"
|
== "text_model.embeddings.token_embedding.weight"
|
||||||
):
|
):
|
||||||
new_initializer = numpy_helper.from_array(
|
new_initializer = numpy_helper.from_array(
|
||||||
embedding_weights.astype(dtype), embedding_node.name
|
embedding_weights.astype(base_weights.dtype), embedding_node.name
|
||||||
)
|
)
|
||||||
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
||||||
del text_encoder.graph.initializer[i]
|
del text_encoder.graph.initializer[i]
|
||||||
|
|
|
@ -360,15 +360,15 @@ class UNetWrapper(object):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
timestep_dtype = timestep.dtype
|
timestep_dtype = timestep.dtype
|
||||||
|
|
||||||
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
|
||||||
if "onnx-fp16" in self.server.optimizations:
|
if sample.dtype != timestep.dtype:
|
||||||
logger.info("converting UNet sample to ONNX fp16")
|
logger.trace("converting UNet sample to timestep dtype")
|
||||||
sample = sample.astype(np.float16)
|
|
||||||
encoder_hidden_states = encoder_hidden_states.astype(np.float16)
|
|
||||||
elif sample.dtype != timestep.dtype:
|
|
||||||
logger.info("converting UNet sample to timestep dtype")
|
|
||||||
sample = sample.astype(timestep.dtype)
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
||||||
|
if encoder_hidden_states.dtype != timestep.dtype:
|
||||||
|
logger.trace("converting UNet hidden states to timestep dtype")
|
||||||
|
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
||||||
|
|
||||||
return self.wrapped(
|
return self.wrapped(
|
||||||
sample=sample,
|
sample=sample,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
|
|
|
@ -86,9 +86,3 @@ class ServerContext:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
else:
|
else:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
def numpy_dtype(self):
|
|
||||||
if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
|
|
||||||
return np.float16
|
|
||||||
else:
|
|
||||||
return np.float32
|
|
Loading…
Reference in New Issue