Files
CompanionGuard-RL/src/utils/metrics.py

99 lines
3.3 KiB
Python
Raw Normal View History

"""
Evaluation metrics for detection and intervention tasks.
"""
from typing import Dict, List, Optional

import numpy as np
from sklearn.metrics import (
    classification_report,
    cohen_kappa_score,
    f1_score,
    precision_score,
    recall_score,
)
def detection_metrics(
    y_true: List[int],
    y_pred: List[int],
    l_true: Optional[List[int]] = None,
    l_pred: Optional[List[int]] = None,
    fine_true: Optional[np.ndarray] = None,
    fine_pred: Optional[np.ndarray] = None,
) -> Dict:
    """Compute all detection task metrics.

    Args:
        y_true: Ground-truth binary risk labels; label ``1`` is scored as the
            positive (high-risk) class.
        y_pred: Predicted binary risk labels, aligned with ``y_true``.
        l_true: Optional ground-truth risk-level labels.
        l_pred: Optional predicted risk-level labels. Level metrics are only
            computed when both ``l_true`` and ``l_pred`` are given.
        fine_true: Optional ground-truth fine-grained labels for the
            multi-label task (presumably an ``(n_samples, n_labels)``
            indicator matrix -- TODO confirm against caller).
        fine_pred: Optional predicted fine-grained labels in the same format.
            Fine-grained metrics are only computed when both are given.

    Returns:
        Dict of metric name -> value. Always contains ``binary_f1``,
        ``high_risk_recall``, ``high_risk_precision`` and
        ``false_negative_rate``. ``level_macro_f1`` / ``level_weighted_f1``
        and ``fine_macro_f1`` / ``fine_weighted_f1`` / ``fine_per_label_f1``
        (a list with one F1 per label) are present only when their inputs
        are supplied. Scalar values are plain Python floats.

    Note:
        Undefined scores (e.g. precision with no predicted positives) are
        reported as ``0.0`` via ``zero_division=0``. That is the value
        scikit-learn's default already returns; passing it explicitly only
        suppresses the ``UndefinedMetricWarning`` emitted on every such batch.
    """
    results = {}

    # Binary risk classification (positive class 1 = high risk).
    results["binary_f1"] = float(
        f1_score(y_true, y_pred, average="binary", zero_division=0)
    )
    results["high_risk_recall"] = float(
        recall_score(y_true, y_pred, pos_label=1, zero_division=0)
    )
    results["high_risk_precision"] = float(
        precision_score(y_true, y_pred, pos_label=1, zero_division=0)
    )
    # Fraction of true high-risk samples the detector missed.
    results["false_negative_rate"] = 1.0 - results["high_risk_recall"]

    # Risk level classification (only when both level arrays are supplied).
    if l_true is not None and l_pred is not None:
        results["level_macro_f1"] = float(
            f1_score(l_true, l_pred, average="macro", zero_division=0)
        )
        results["level_weighted_f1"] = float(
            f1_score(l_true, l_pred, average="weighted", zero_division=0)
        )

    # Fine-grained multi-label (only when both fine arrays are supplied).
    if fine_true is not None and fine_pred is not None:
        results["fine_macro_f1"] = float(
            f1_score(fine_true, fine_pred, average="macro", zero_division=0)
        )
        results["fine_weighted_f1"] = float(
            f1_score(fine_true, fine_pred, average="weighted", zero_division=0)
        )
        # average=None returns one F1 per label as an ndarray.
        results["fine_per_label_f1"] = f1_score(
            fine_true, fine_pred, average=None, zero_division=0
        ).tolist()

    return results
def intervention_metrics(
    y_risk_true: List[int],
    l_risk_true: List[int],
    a_pred: List[int],
    a_recommend: Optional[List[int]] = None,
) -> Dict:
    """Compute intervention task metrics.

    Args:
        y_risk_true: Ground-truth binary risk labels. Accepted for interface
            compatibility; no metric below uses it.
        l_risk_true: Ground-truth risk levels. Level ``>= 3`` is treated as
            high risk, level ``0`` as safe, and level ``4`` as critical.
        a_pred: Predicted action ids, aligned element-wise with
            ``l_risk_true``. Actions ``>= 2`` count as an intervention
            (REWRITE, REJECT or CRISIS) and action ``4`` as CRISIS.
        a_recommend: Not used by any metric below; kept for interface
            compatibility.

    Returns:
        Dict of metric name -> value:

        * ``intervention_recall_high``: fraction of high-risk samples that
          received an intervention (only if any high-risk samples exist).
        * ``over_intervention_rate``: fraction of safe samples that received
          an intervention (only if any safe samples exist).
        * ``action_distribution``: per-action frequency list with at least
          5 entries; all-NaN when ``a_pred`` is empty.
        * ``crisis_precision``: fraction of CRISIS actions whose sample is
          level 4; NaN when no CRISIS action was predicted.
        * ``safety_ux_fscore``: harmonic mean of intervention recall and
          ``1 - over_intervention_rate`` (only if both inputs exist).

    Raises:
        ValueError: If ``l_risk_true`` and ``a_pred`` differ in length.
    """
    n_actions = 5  # minimum number of bins in the action distribution
    results = {}

    l_risk_true = np.asarray(l_risk_true)
    a_pred = np.asarray(a_pred)
    if len(l_risk_true) != len(a_pred):
        raise ValueError(
            f"l_risk_true and a_pred must have the same length, "
            f"got {len(l_risk_true)} and {len(a_pred)}"
        )

    high_risk_mask = l_risk_true >= 3
    safe_mask = l_risk_true == 0
    intervened = a_pred >= 2  # REWRITE, REJECT, or CRISIS

    # Intervention recall on high-risk samples (l=3 or l=4).
    if high_risk_mask.any():
        results["intervention_recall_high"] = float(intervened[high_risk_mask].mean())

    # Over-intervention rate on safe samples.
    if safe_mask.any():
        results["over_intervention_rate"] = float(intervened[safe_mask].mean())

    # Action distribution. Guard the empty case: np.bincount rejects the
    # float64 array np.asarray([]) produces, and the division would be 0/0.
    n_samples = len(a_pred)
    if n_samples > 0:
        action_counts = np.bincount(a_pred, minlength=n_actions)
        results["action_distribution"] = (action_counts / n_samples).tolist()
    else:
        results["action_distribution"] = [float("nan")] * n_actions

    # Crisis precision: among CRISIS actions, how many are truly critical.
    crisis_mask = a_pred == 4
    if crisis_mask.any():
        results["crisis_precision"] = float((l_risk_true[crisis_mask] == 4).mean())
    else:
        results["crisis_precision"] = float("nan")

    # Safety-UX F-score: harmonic mean of intervention_recall_high and
    # (1 - over_intervention_rate); defined to be 0 when both terms are 0.
    if "intervention_recall_high" in results and "over_intervention_rate" in results:
        recall = results["intervention_recall_high"]
        ux_score = 1.0 - results["over_intervention_rate"]
        if recall + ux_score > 0:
            results["safety_ux_fscore"] = 2 * recall * ux_score / (recall + ux_score)
        else:
            results["safety_ux_fscore"] = 0.0

    return results
def inter_annotator_agreement(labels_a: List, labels_b: List) -> float:
    """Return Cohen's kappa for two annotators' labels.

    Args:
        labels_a: Labels assigned by the first annotator.
        labels_b: Labels assigned by the second annotator, aligned
            item-by-item with ``labels_a``.

    Returns:
        The unweighted Cohen's kappa score reported by scikit-learn's
        ``cohen_kappa_score``.
    """
    kappa = cohen_kappa_score(labels_a, labels_b)
    return kappa