fix(api): convert back to model format after blending, convert samples as needed (#274)
This commit is contained in:
parent 0315a8cbc6
commit c8aad8554e
@@ -97,8 +97,8 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
 
-    torch_dtype = ctx.torch_dtype()
-    logger.debug("using Torch dtype %s for pipeline", torch_dtype)
+    dtype = ctx.torch_dtype()
+    logger.debug("using Torch dtype %s for pipeline", dtype)
 
     dest_path = path.join(ctx.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
@@ -117,7 +117,7 @@ def convert_diffusion_diffusers(
 
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
-        torch_dtype=torch_dtype,
+        torch_dtype=dtype,
         use_auth_token=ctx.token,
     ).to(ctx.training_device)
     output_path = Path(dest_path)
@@ -174,11 +174,11 @@ def convert_diffusion_diffusers(
         pipeline.unet,
         model_args=(
             torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
-            torch.randn(2).to(device=ctx.training_device, dtype=torch_dtype),
+            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
             unet_scale,
         ),
@@ -230,7 +230,7 @@ def convert_diffusion_diffusers(
         model_args=(
             torch.randn(
                 1, vae_latent_channels, unet_sample_size, unet_sample_size
-            ).to(device=ctx.training_device, dtype=torch_dtype),
+            ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
         output_path=output_path / "vae" / "model.onnx",
@@ -255,7 +255,7 @@ def convert_diffusion_diffusers(
         vae_encoder,
         model_args=(
             torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
             False,
         ),
@@ -279,7 +279,7 @@ def convert_diffusion_diffusers(
         model_args=(
             torch.randn(
                 1, vae_latent_channels, unet_sample_size, unet_sample_size
-            ).to(device=ctx.training_device, dtype=torch_dtype),
+            ).to(device=ctx.training_device, dtype=dtype),
             False,
         ),
         output_path=output_path / "vae_decoder" / "model.onnx",

@@ -62,7 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = context.torch_dtype()
+    dtype = torch.float32
 
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
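
The LoRA blend itself now always runs on the CPU in float32, so the merged tensors have to be converted back to the base model's stored format before they are written into the ONNX graph (the "convert back to model format after blending" part of this commit). A minimal sketch of that round trip, assuming only numpy and onnx; the path, initializer name, and blend weight here are hypothetical, not the project's actual helpers:

import numpy as np
from onnx import load, numpy_helper, save_model

# hypothetical model path and initializer name, for illustration only
model = load("text_encoder/model.onnx")
names = [n.name for n in model.graph.initializer]
idx = names.index("some.layer.weight")
node = model.graph.initializer[idx]

base = numpy_helper.to_array(node)
# stands in for the LoRA down/up matrix product, already in float32
delta = np.random.randn(*base.shape).astype(np.float32)

# blend in float32 for precision, regardless of the stored dtype
blended = base.astype(np.float32) + 0.5 * delta

# convert back to the model's original format before replacing the initializer
new_node = numpy_helper.from_array(blended.astype(base.dtype), node.name)
model.graph.initializer[idx].CopyFrom(new_node)

save_model(model, "text_encoder/model.onnx")
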
@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = context.numpy_dtype()
+    dtype = np.float32
     embeds = {}
 
     for name, weight, base_token, inversion_format in inversions:
@@ -131,11 +131,11 @@ def blend_textual_inversions(
         for n in text_encoder.graph.initializer
         if n.name == "text_model.embeddings.token_embedding.weight"
     ][0]
-    embedding_weights = numpy_helper.to_array(embedding_node)
+    base_weights = numpy_helper.to_array(embedding_node)
 
-    weights_dim = embedding_weights.shape[1]
+    weights_dim = base_weights.shape[1]
     zero_weights = np.zeros((num_added_tokens, weights_dim))
-    embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
+    embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
 
     for token, weights in embeds.items():
         token_id = tokenizer.convert_tokens_to_ids(token)
@@ -149,7 +149,7 @@ def blend_textual_inversions(
             == "text_model.embeddings.token_embedding.weight"
         ):
             new_initializer = numpy_helper.from_array(
-                embedding_weights.astype(dtype), embedding_node.name
+                embedding_weights.astype(base_weights.dtype), embedding_node.name
            )
             logger.trace("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]

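
The detail worth noting in the hunk above is that the blended embedding matrix is cast back to the dtype of the original token_embedding initializer (base_weights.dtype) rather than to a context-level numpy dtype: np.zeros and np.concatenate promote the result to float64, so an fp16 text encoder would otherwise end up with a mismatched initializer. A small, self-contained illustration of that promotion and the cast back, using a toy shape instead of the real CLIP vocabulary:

import numpy as np
from onnx import numpy_helper

# pretend the text encoder stores its token embeddings in fp16
base_weights = np.random.randn(1000, 768).astype(np.float16)

# padding rows for newly added tokens default to float64...
zero_weights = np.zeros((4, base_weights.shape[1]))
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
print(embedding_weights.dtype)  # float64, promoted by concatenate

# ...so cast back to the base initializer's dtype when rebuilding the node
new_initializer = numpy_helper.from_array(
    embedding_weights.astype(base_weights.dtype),
    "text_model.embeddings.token_embedding.weight",
)
print(new_initializer.data_type)  # 10 == TensorProto.FLOAT16, matching the original
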
@@ -360,15 +360,15 @@ class UNetWrapper(object):
         global timestep_dtype
         timestep_dtype = timestep.dtype
 
-        logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        if "onnx-fp16" in self.server.optimizations:
-            logger.info("converting UNet sample to ONNX fp16")
-            sample = sample.astype(np.float16)
-            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
-        elif sample.dtype != timestep.dtype:
-            logger.info("converting UNet sample to timestep dtype")
+        logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)
 
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
+
         return self.wrapped(
             sample=sample,
             timestep=timestep,

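
At inference time the wrapper no longer keys the conversion off the onnx-fp16 optimization flag; it compares the incoming arrays against the timestep dtype reported by the loaded UNet and converts the samples as needed. A standalone sketch of that idea with plain numpy arrays (illustrative names, not the wrapper's real attributes, and it checks the hidden states' own dtype rather than re-checking the sample):

import numpy as np

def match_unet_dtypes(sample, timestep, encoder_hidden_states):
    # convert the latent sample and hidden states to whatever dtype the
    # exported UNet expects for its timestep input (fp16 or fp32)
    if sample.dtype != timestep.dtype:
        sample = sample.astype(timestep.dtype)
    if encoder_hidden_states.dtype != timestep.dtype:
        encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
    return sample, timestep, encoder_hidden_states

# example: fp32 inputs fed to an fp16 UNet export
sample = np.zeros((2, 4, 64, 64), dtype=np.float32)
timestep = np.zeros((1,), dtype=np.float16)
hidden = np.zeros((2, 77, 768), dtype=np.float32)
sample, timestep, hidden = match_unet_dtypes(sample, timestep, hidden)
print(sample.dtype, hidden.dtype)  # float16 float16
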
@@ -86,9 +86,3 @@ class ServerContext:
             return torch.float16
         else:
             return torch.float32
-
-    def numpy_dtype(self):
-        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
-            return np.float16
-        else:
-            return np.float32