1
0
Fork 0

move tensor logs to trace level

This commit is contained in:
Sean Sube 2023-03-19 09:11:55 -05:00
parent 0e43acc0f7
commit e6d7d30225
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 2 additions and 2 deletions

View File

@@ -78,13 +78,13 @@ def expand_prompt(
             logger.trace("encoding group: %s", group.shape)
             text_result = self.text_encoder(input_ids=group.astype(np.int32))
-            logger.info("text encoder result: %s", text_result)
+            logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
             last_state, _pooled_output, *hidden_states = text_result
             if skip_clip_states > 0:
                 layer_norm = torch.nn.LayerNorm(last_state.shape[2])
                 norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
-                logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
+                logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
             group_embeds.append(norm_state)
             else:
                 group_embeds.append(last_state)