move tensor logs to trace level
This commit is contained in:
parent
0e43acc0f7
commit
e6d7d30225
|
@ -78,13 +78,13 @@ def expand_prompt(
|
|||
logger.trace("encoding group: %s", group.shape)
|
||||
|
||||
text_result = self.text_encoder(input_ids=group.astype(np.int32))
|
||||
logger.info("text encoder result: %s", text_result)
|
||||
logger.trace("text encoder produced %s outputs: %s", len(text_result), text_result)
|
||||
|
||||
last_state, _pooled_output, *hidden_states = text_result
|
||||
if skip_clip_states > 0:
|
||||
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||
norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]).detach())
|
||||
logger.info("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
||||
logger.trace("normalized results after skipping %s layers: %s", skip_clip_states, norm_state.shape)
|
||||
group_embeds.append(norm_state)
|
||||
else:
|
||||
group_embeds.append(last_state)
|
||||
|
|
Loading…
Reference in New Issue