Update entropy_decoding.py

Asankhaya Sharma
2024-10-28 07:21:32 +08:00
parent 129ac8090a
commit 97265191b0


@@ -27,24 +27,26 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
     return entropy, varentropy
 
 def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
     # attention_weights are already probabilities (post-softmax)
     attention_probs = attention_weights
 
     # Calculate entropy
     attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
 
-    # Calculate variance of entropy
-    attn_varentropy = torch.var(attn_entropy, dim=-1)
+    # Calculate variance of entropy with unbiased=False to avoid df issues
+    # Also add a check for singleton dimensions
+    if attn_entropy.size(-1) > 1:
+        attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
+    else:
+        attn_varentropy = torch.zeros_like(attn_entropy)
     attn_varentropy = torch.where(torch.isnan(attn_varentropy),
                                   torch.zeros_like(attn_varentropy),
                                   attn_varentropy)
 
-    # Calculate mean attention and agreement
+    # Rest remains the same
     mean_attention = torch.mean(attention_probs, dim=1)
     agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
 
     # For interaction strength, we can use log probabilities to approximate the original scores
     # This maintains the relative relationships while providing a reasonable proxy for attention strength
     attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
     interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
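For readers skimming the hunk: `torch.var` defaults to `unbiased=True` (Bessel's correction, dividing by n - 1), so taking the variance along a singleton dimension, e.g. a single query position during incremental decoding, produces 0/0 = NaN. The guard above switches to the population variance and short-circuits the singleton case. Below is a minimal standalone sketch, not part of the commit; the `safe_varentropy` helper name and the (batch, heads, seq) shape are illustrative assumptions:

```python
import torch

# Sketch of the failure mode the guard handles: torch.var defaults to
# unbiased=True (divides by n - 1), so a singleton dimension yields 0/0 = NaN.
single = torch.tensor([[0.7]])                    # entropy for one query position
print(torch.var(single, dim=-1))                  # tensor([nan]): df = n - 1 = 0
print(torch.var(single, dim=-1, unbiased=False))  # tensor([0.]): divides by n

# The patched logic, extracted into a hypothetical helper for illustration.
def safe_varentropy(attn_entropy: torch.Tensor) -> torch.Tensor:
    if attn_entropy.size(-1) > 1:
        # Population variance: well-defined for any n >= 1.
        attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
    else:
        # Singleton dim: variance is 0 by definition. Note that zeros_like
        # keeps the trailing dim, so this branch does not reduce the shape.
        attn_varentropy = torch.zeros_like(attn_entropy)
    # Final sweep replacing any remaining NaNs with zeros, as in the diff.
    return torch.where(torch.isnan(attn_varentropy),
                       torch.zeros_like(attn_varentropy),
                       attn_varentropy)

print(safe_varentropy(torch.rand(2, 4, 8)).shape)  # torch.Size([2, 4])
print(safe_varentropy(torch.rand(2, 4, 1)).shape)  # torch.Size([2, 4, 1])
```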
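The "log probabilities as a proxy for the original scores" comment rests on a softmax identity: log(softmax(z)) = z - logsumexp(z), so the log-probabilities equal the pre-softmax scores up to an additive per-row constant, preserving all relative differences between positions. A quick check, independent of the repository's code (the names and 1-D shape are illustrative):

```python
import torch

# softmax is shift-invariant per row, so log(softmax(z)) = z - logsumexp(z):
# the log-probabilities are the raw scores minus a per-row constant.
z = torch.randn(5)                        # hypothetical pre-softmax attention scores
probs = torch.softmax(z, dim=-1)
proxy = torch.log(torch.clamp(probs, 1e-10, 1.0))  # clamp as in the diff

print(torch.allclose(proxy, z - torch.logsumexp(z, dim=-1)))  # True
print(proxy - z)  # a constant vector: every entry is -logsumexp(z)
```

Absolute values of the proxy do absorb that per-row constant, which is why the comment in the diff describes it as a reasonable proxy rather than an exact recovery of the scores.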