feat: initial CompanionGuard-RL framework
Two-module pipeline for AI companion safety:
- Module B: context-aware risk detector with CrossAttention fusion
- Module C: PPO-based adaptive intervention policy

Includes CompanionRisk Taxonomy (10 primary + 14 fine-grained labels), dataset generation/annotation pipeline, training scripts, and eval suite.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
98
src/utils/metrics.py
Normal file
98
src/utils/metrics.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Evaluation metrics for detection and intervention tasks.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import (
|
||||
classification_report,
|
||||
f1_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
cohen_kappa_score,
|
||||
)
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def detection_metrics(
    y_true: List[int],
    y_pred: List[int],
    l_true: List[int] = None,
    l_pred: List[int] = None,
    fine_true: np.ndarray = None,
    fine_pred: np.ndarray = None,
) -> Dict:
    """Score the detector's predictions against ground truth.

    The binary scores are always produced. The risk-level and fine-grained
    scores are produced only when both the true and predicted values for
    that task are supplied.

    Args:
        y_true: Ground-truth binary risk labels; class 1 is the positive
            (high-risk) class.
        y_pred: Predicted binary risk labels.
        l_true: Optional ground-truth risk levels.
        l_pred: Optional predicted risk levels.
        fine_true: Optional ground-truth fine-grained labels, handed to
            ``f1_score`` unchanged (presumably a multi-label indicator
            array -- confirm against callers).
        fine_pred: Optional predicted fine-grained labels, same layout as
            ``fine_true``.

    Returns:
        Dict with ``binary_f1``, ``high_risk_recall``,
        ``high_risk_precision`` and ``false_negative_rate``; plus
        ``level_macro_f1`` / ``level_weighted_f1`` when levels are given,
        and ``fine_macro_f1`` / ``fine_weighted_f1`` /
        ``fine_per_label_f1`` (a list, one F1 per label) when fine-grained
        labels are given.
    """
    # Binary task: every score treats class 1 as the positive class.
    f1_binary = f1_score(y_true, y_pred, average="binary")
    recall_positive = recall_score(y_true, y_pred, pos_label=1)
    precision_positive = precision_score(y_true, y_pred, pos_label=1)

    scores = {
        "binary_f1": f1_binary,
        "high_risk_recall": recall_positive,
        "high_risk_precision": precision_positive,
        # Miss rate is the complement of recall on the positive class.
        "false_negative_rate": 1.0 - recall_positive,
    }

    # Risk-level task: needs both halves of the pair.
    levels_given = not (l_true is None or l_pred is None)
    if levels_given:
        level_macro = f1_score(l_true, l_pred, average="macro")
        level_weighted = f1_score(l_true, l_pred, average="weighted")
        scores["level_macro_f1"] = level_macro
        scores["level_weighted_f1"] = level_weighted

    # Fine-grained task: aggregate scores plus a per-label breakdown.
    fine_given = not (fine_true is None or fine_pred is None)
    if fine_given:
        fine_macro = f1_score(fine_true, fine_pred, average="macro")
        fine_weighted = f1_score(fine_true, fine_pred, average="weighted")
        per_label = f1_score(fine_true, fine_pred, average=None)
        scores["fine_macro_f1"] = fine_macro
        scores["fine_weighted_f1"] = fine_weighted
        scores["fine_per_label_f1"] = per_label.tolist()

    return scores
|
||||
|
||||
|
||||
def intervention_metrics(
    y_risk_true: List[int],
    l_risk_true: List[int],
    a_pred: List[int],
    a_recommend: List[int] = None,
) -> Dict:
    """Compute intervention task metrics.

    Args:
        y_risk_true: Binary risk labels. Accepted for interface
            compatibility; no metric computed here reads it.
        l_risk_true: Ground-truth risk level per sample. Levels ``>= 3``
            are treated as high-risk and level ``0`` as safe.
        a_pred: Predicted action index per sample (non-negative integers,
            aligned with ``l_risk_true``). Indices ``>= 2`` count as an
            intervention (REWRITE, REJECT or CRISIS); index ``4`` is the
            CRISIS action.
        a_recommend: Accepted for interface compatibility; currently unused.

    Returns:
        Dict of plain Python values:

        * ``intervention_recall_high`` -- share of high-risk samples that
          received an intervention. Present only if a high-risk sample
          exists.
        * ``over_intervention_rate`` -- share of safe samples that received
          an intervention. Present only if a safe sample exists.
        * ``action_distribution`` -- list with the fraction of samples per
          action index (at least 5 entries; all zeros for empty input).
        * ``crisis_precision`` -- share of CRISIS actions issued on level-4
          samples, or ``nan`` when no CRISIS action was predicted.
        * ``safety_ux_fscore`` -- harmonic mean of
          ``intervention_recall_high`` and ``1 - over_intervention_rate``.
          Present only when both of those are present.
    """
    levels = np.asarray(l_risk_true)
    actions = np.asarray(a_pred)
    results = {}

    high_risk_mask = levels >= 3
    safe_mask = levels == 0

    # Intervention recall on high-risk samples (level 3 or 4).
    if high_risk_mask.any():
        intervened = actions[high_risk_mask] >= 2  # REWRITE, REJECT, or CRISIS
        results["intervention_recall_high"] = float(intervened.mean())

    # Over-intervention rate on safe samples.
    if safe_mask.any():
        over_intervened = actions[safe_mask] >= 2
        results["over_intervention_rate"] = float(over_intervened.mean())

    # Action distribution. An empty prediction list turns into a float-typed
    # array and a zero sample count, so neither the bincount nor the
    # normalisation is meaningful there; report an all-zero distribution.
    if actions.size:
        action_counts = np.bincount(actions, minlength=5)
        results["action_distribution"] = (action_counts / len(actions)).tolist()
    else:
        results["action_distribution"] = [0.0] * 5

    # Crisis precision: among CRISIS actions, how many are truly level 4.
    crisis_mask = actions == 4
    if crisis_mask.any():
        results["crisis_precision"] = float((levels[crisis_mask] == 4).mean())
    else:
        results["crisis_precision"] = float("nan")

    # Safety-UX F-score: harmonic mean of recall on high-risk samples and
    # the non-over-intervention rate on safe samples.
    if "intervention_recall_high" in results and "over_intervention_rate" in results:
        recall = results["intervention_recall_high"]
        ux_score = 1.0 - results["over_intervention_rate"]
        if recall + ux_score > 0:
            results["safety_ux_fscore"] = 2 * recall * ux_score / (recall + ux_score)
        else:
            # Both components are zero; avoid 0/0.
            results["safety_ux_fscore"] = 0.0

    return results
|
||||
|
||||
|
||||
def inter_annotator_agreement(labels_a: List, labels_b: List) -> float:
    """Return Cohen's kappa for two annotators' labels.

    Args:
        labels_a: Labels assigned by the first annotator.
        labels_b: Labels assigned by the second annotator.

    Returns:
        The value computed by ``cohen_kappa_score(labels_a, labels_b)``.
    """
    kappa = cohen_kappa_score(labels_a, labels_b)
    return kappa
|
||||
Reference in New Issue
Block a user