mirror of
https://github.com/codelion/optillm.git
synced 2025-05-28 09:39:38 +03:00
Update entropy_decoding.py
This commit is contained in:
@@ -27,24 +27,26 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup
|
|||||||
return entropy, varentropy
|
return entropy, varentropy
|
||||||
|
|
||||||
def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
|
def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||||
# attention_weights are already probabilities (post-softmax)
|
|
||||||
attention_probs = attention_weights
|
attention_probs = attention_weights
|
||||||
|
|
||||||
# Calculate entropy
|
# Calculate entropy
|
||||||
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
|
attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
|
||||||
|
|
||||||
# Calculate variance of entropy
|
# Calculate variance of entropy with unbiased=False to avoid degrees-of-freedom issues
|
||||||
attn_varentropy = torch.var(attn_entropy, dim=-1)
|
# Only call torch.var when the last dimension has more than one element
|
||||||
|
if attn_entropy.size(-1) > 1:
|
||||||
|
attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False)
|
||||||
|
else:
|
||||||
|
attn_varentropy = torch.zeros_like(attn_entropy)
|
||||||
|
|
||||||
attn_varentropy = torch.where(torch.isnan(attn_varentropy),
|
attn_varentropy = torch.where(torch.isnan(attn_varentropy),
|
||||||
torch.zeros_like(attn_varentropy),
|
torch.zeros_like(attn_varentropy),
|
||||||
attn_varentropy)
|
attn_varentropy)
|
||||||
|
|
||||||
# Calculate mean attention and agreement
|
# Calculate mean attention and agreement
|
||||||
mean_attention = torch.mean(attention_probs, dim=1)
|
mean_attention = torch.mean(attention_probs, dim=1)
|
||||||
agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
|
agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
|
||||||
|
|
||||||
# For interaction strength, we can use log probabilities to approximate the original scores
|
|
||||||
# This maintains the relative relationships while providing a reasonable proxy for attention strength
|
|
||||||
attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
|
attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0))
|
||||||
interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
|
interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user