Two-module pipeline for AI companion safety:
- Module B: context-aware risk detector with CrossAttention fusion
- Module C: PPO-based adaptive intervention policy

Includes the CompanionRisk taxonomy (10 primary + 14 fine-grained labels), a dataset generation/annotation pipeline, training scripts, and an eval suite.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
"""
|
|
Evaluation metrics for detection and intervention tasks.
|
|
"""
|
|
|
|
from typing import Dict, List, Optional

import numpy as np
from sklearn.metrics import (
    classification_report,
    cohen_kappa_score,
    f1_score,
    precision_score,
    recall_score,
)
|
|
|
|
|
|
def detection_metrics(
    y_true: List[int],
    y_pred: List[int],
    l_true: Optional[List[int]] = None,
    l_pred: Optional[List[int]] = None,
    fine_true: Optional[np.ndarray] = None,
    fine_pred: Optional[np.ndarray] = None,
) -> Dict:
    """Compute all detection task metrics.

    Label ``1`` is treated as the positive (high-risk) class for the binary
    metrics.

    Args:
        y_true: Ground-truth binary risk labels.
        y_pred: Predicted binary risk labels, aligned with ``y_true``.
        l_true: Optional ground-truth risk-level labels.
        l_pred: Optional predicted risk-level labels. Level metrics are only
            computed when both ``l_true`` and ``l_pred`` are given.
        fine_true: Optional ground-truth fine-grained multi-label targets,
            presumably an indicator matrix of shape (n_samples, n_labels) —
            confirm against the caller.
        fine_pred: Optional predicted fine-grained targets, same layout as
            ``fine_true``. Only used when both are given.

    Returns:
        Dict with ``binary_f1``, ``high_risk_recall``, ``high_risk_precision``
        and ``false_negative_rate``. If the level inputs are given, it also has
        ``level_macro_f1`` and ``level_weighted_f1``. If the fine-grained
        inputs are given, it also has ``fine_macro_f1``, ``fine_weighted_f1``
        and ``fine_per_label_f1`` (a per-label list).

        ``high_risk_recall`` and ``false_negative_rate`` are NaN when
        ``y_true`` contains no positive samples. Recall is undefined in that
        case, and sklearn's fallback of 0.0 would report FNR = 1.0, i.e.
        "every high-risk sample missed", when there were none to miss.
    """
    results = {}

    # Binary risk classification
    results["binary_f1"] = f1_score(y_true, y_pred, average="binary")
    if np.any(np.asarray(y_true) == 1):
        results["high_risk_recall"] = recall_score(y_true, y_pred, pos_label=1)
    else:
        # Undefined rate -> NaN instead of sklearn's silent 0.0.
        results["high_risk_recall"] = float("nan")
    results["high_risk_precision"] = precision_score(y_true, y_pred, pos_label=1)
    # A NaN recall propagates to a NaN FNR, which is the intended result.
    results["false_negative_rate"] = 1.0 - results["high_risk_recall"]

    # Risk level classification
    if l_true is not None and l_pred is not None:
        results["level_macro_f1"] = f1_score(l_true, l_pred, average="macro")
        results["level_weighted_f1"] = f1_score(l_true, l_pred, average="weighted")

    # Fine-grained multi-label
    if fine_true is not None and fine_pred is not None:
        results["fine_macro_f1"] = f1_score(fine_true, fine_pred, average="macro")
        results["fine_weighted_f1"] = f1_score(fine_true, fine_pred, average="weighted")
        results["fine_per_label_f1"] = f1_score(fine_true, fine_pred, average=None).tolist()

    return results
|
|
|
|
|
|
def intervention_metrics(
    y_risk_true: List[int],
    l_risk_true: List[int],
    a_pred: List[int],
    a_recommend: Optional[List[int]] = None,
) -> Dict:
    """Compute intervention task metrics.

    Action ids: any action ``>= 2`` (REWRITE, REJECT, or CRISIS) counts as an
    intervention, and ``4`` is CRISIS. Risk levels: ``>= 3`` is high risk,
    ``0`` is safe, and ``4`` is the critical level.

    Args:
        y_risk_true: Binary risk labels. Not read by the current computation;
            kept for signature compatibility.
        l_risk_true: Ground-truth risk level per sample.
        a_pred: Predicted action id per sample, aligned with ``l_risk_true``.
            Values are cast to integers, so float-typed arrays are accepted.
        a_recommend: Not read by the current computation; kept for signature
            compatibility.

    Returns:
        Dict with the following keys:

        - ``intervention_recall_high``: fraction of high-risk samples that
          received an intervening action. Present only if any exist.
        - ``over_intervention_rate``: fraction of safe samples that received
          an intervening action. Present only if any exist.
        - ``action_distribution``: per-action frequency list with at least 5
          entries. All NaN when ``a_pred`` is empty.
        - ``crisis_precision``: fraction of CRISIS actions taken on
          critical-level samples. NaN if no CRISIS action was predicted.
        - ``safety_ux_fscore``: harmonic mean of ``intervention_recall_high``
          and ``1 - over_intervention_rate``. Present only when both exist.
    """
    num_actions = 5
    first_intervening_action = 2  # REWRITE, REJECT and CRISIS are >= this id
    crisis_action = 4
    high_risk_level = 3
    safe_level = 0
    critical_level = 4

    results = {}

    levels = np.asarray(l_risk_true)
    # Cast explicitly: np.array([]) is float64, and np.bincount rejects
    # float input, so empty or float-typed predictions used to crash.
    actions = np.asarray(a_pred, dtype=int)

    high_risk_mask = levels >= high_risk_level
    safe_mask = levels == safe_level

    # Intervention recall on high-risk samples (l=3 or l=4)
    if high_risk_mask.any():
        results["intervention_recall_high"] = float(
            np.mean(actions[high_risk_mask] >= first_intervening_action)
        )

    # Over-intervention rate on safe samples
    if safe_mask.any():
        results["over_intervention_rate"] = float(
            np.mean(actions[safe_mask] >= first_intervening_action)
        )

    # Action distribution (undefined, hence NaN, when there are no predictions)
    if actions.size:
        action_counts = np.bincount(actions, minlength=num_actions)
        results["action_distribution"] = (action_counts / actions.size).tolist()
    else:
        results["action_distribution"] = [float("nan")] * num_actions

    # Crisis precision: among CRISIS actions, how many are truly critical
    crisis_mask = actions == crisis_action
    if crisis_mask.any():
        results["crisis_precision"] = float(np.mean(levels[crisis_mask] == critical_level))
    else:
        results["crisis_precision"] = float("nan")

    # Safety-UX F-score: harmonic mean of intervention_recall_high and (1 - over_intervention_rate)
    if "intervention_recall_high" in results and "over_intervention_rate" in results:
        recall = results["intervention_recall_high"]
        ux_score = 1.0 - results["over_intervention_rate"]
        denom = recall + ux_score
        results["safety_ux_fscore"] = 2 * recall * ux_score / denom if denom > 0 else 0.0

    return results
|
|
|
|
|
|
def inter_annotator_agreement(labels_a: List, labels_b: List) -> float:
    """Return Cohen's kappa between two annotators' label sequences.

    Thin wrapper over ``sklearn.metrics.cohen_kappa_score`` with its default
    (unweighted) settings. ``labels_a`` and ``labels_b`` are the two
    annotators' labels for the same items, given in the same order.
    """
    kappa = cohen_kappa_score(labels_a, labels_b)
    return kappa
|