move tensor logs to trace level
This commit is contained in:
parent
0e43acc0f7
commit
e6d7d30225
|
@ -78,13 +78,13 @@ def expand_prompt(
|
||||||
logger.trace("encoding group: %s", group.shape)
|
logger.trace("encoding group: %s", group.shape)
|
||||||
|
|
||||||
text_result = self.text_encoder(input_ids=group.astype(np.int32))
|
text_result = self.text_encoder(input_ids=group.astype(np.int32))
|
||||||
logger.info("text encoder result: %s", text_result)
|
logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
|
||||||
|
|
||||||
last_state, _pooled_output, *hidden_states = text_result
|
last_state, _pooled_output, *hidden_states = text_result
|
||||||
if skip_clip_states > 0:
|
if skip_clip_states > 0:
|
||||||
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||||
norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
|
norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
|
||||||
logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
||||||
group_embeds.append(norm_state)
|
group_embeds.append(norm_state)
|
||||||
else:
|
else:
|
||||||
group_embeds.append(last_state)
|
group_embeds.append(last_state)
|
||||||
|
|
Loading…
Reference in New Issue