Mirror of https://github.com/codelion/optillm.git, synced 2025-05-28 09:39:38 +03:00
Update entropy_decoding.py
@@ -27,24 +27,26 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
    return entropy, varentropy

def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
    # attention_weights are already probabilities (post-softmax)
    attention_probs = attention_weights

    # Calculate entropy
    attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)

    # Calculate variance of entropy
    attn_varentropy = torch.var(attn_entropy, dim=-1)
    # Calculate variance of entropy with unbiased=False to avoid df issues
    # Also add a check for singleton dimensions
    if attn_entropy.size(-1) > 1:
        attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
    else:
        attn_varentropy = torch.zeros_like(attn_entropy)

    attn_varentropy = torch.where(torch.isnan(attn_varentropy),
                                  torch.zeros_like(attn_varentropy),
                                  attn_varentropy)

    # Calculate mean attention and agreement
    # Rest remains the same
    mean_attention = torch.mean(attention_probs, dim=1)
    agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))

    # For interaction strength, we can use log probabilities to approximate the original scores
    # This maintains the relative relationships while providing a reasonable proxy for attention strength
    attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
    interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
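For context on the change above: a minimal sketch (not part of the commit) of the degrees-of-freedom issue the new comments refer to. With the default unbiased estimator, torch.var divides by N - 1, so reducing a dimension of size 1 produces NaN; unbiased=False divides by N and stays finite, and the added if/else plus the torch.where pass cover the singleton case.

import torch

# A single entropy value along the reduced dimension (size 1).
single = torch.tensor([[1.5]])

# Default unbiased=True divides by N - 1 = 0 and yields NaN.
print(torch.var(single, dim=-1))                  # tensor([nan])

# Population variance (unbiased=False) divides by N and stays finite.
print(torch.var(single, dim=-1, unbiased=False))  # tensor([0.])

# The hunk combines both ideas: population variance when the dimension has more
# than one element, zeros otherwise, then torch.where to scrub residual NaNs.
attn_entropy = torch.rand(2, 4)  # hypothetical (batch, heads) entropy values
if attn_entropy.size(-1) > 1:
    attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
else:
    attn_varentropy = torch.zeros_like(attn_entropy)
attn_varentropy = torch.where(torch.isnan(attn_varentropy),
                              torch.zeros_like(attn_varentropy),
                              attn_varentropy)
print(attn_varentropy.shape)  # torch.Size([2])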
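And a hedged sketch of the tensor shapes the remaining metrics produce, assuming a (batch, heads, query_len, key_len) layout for the post-softmax attention weights; that layout is inferred from the dim arguments in the hunk, not stated by the diff itself.

import torch

torch.manual_seed(0)

# Hypothetical post-softmax attention weights: (batch=2, heads=4, query_len=8, key_len=8).
attention_probs = torch.softmax(torch.randn(2, 4, 8, 8), dim=-1)

# Entropy over the key dimension: shape (2, 4, 8).
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)

# Head-averaged attention (2, 8, 8); mean absolute deviation from it, reduced
# over heads and queries, gives "agreement" of shape (2, 8).
mean_attention = torch.mean(attention_probs, dim=1)
agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))

# Log-probabilities as a proxy for raw scores; reducing over heads, queries and
# keys gives one "interaction strength" value per batch element: shape (2,).
attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))

print(attn_entropy.shape, agreement.shape, interaction_strength.shape)
# torch.Size([2, 4, 8]) torch.Size([2, 8]) torch.Size([2])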