chore: initial commit — unified project repo

Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-14 11:28:42 +08:00
commit bd1f51c496
85 changed files with 20568 additions and 0 deletions

0
code/src/__init__.py Normal file
View File

View File

218
code/src/models/detector.py Normal file
View File

@@ -0,0 +1,218 @@
"""
Module B: Context-aware Risk Detector.
Architecture:
1. Encode persona, context (history+user_input), response separately
2. Fuse via CrossAttention(response, [persona; context])
3. Multi-task classification heads:
- Binary risk (sigmoid)
- Risk level 0-4 (softmax)
- Primary category R1-R10 (softmax)
- Fine-grained 14-label (sigmoid multi-label)
Returns e_P_pool and e_H_pool for downstream RL state construction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from src.models.encoder import TextEncoder, ContextAwareFusion
from src.utils.taxonomy import NUM_PRIMARY, NUM_FINE, NUM_RISK_LEVELS
class CompanionRiskDetector(nn.Module):
def __init__(
self,
model_name: str = "hfl/chinese-macbert-large",
hidden_size: int = 768,
num_heads: int = 8,
dropout: float = 0.1,
use_lora: bool = False,
):
super().__init__()
self.encoder = TextEncoder(
model_name=model_name,
hidden_size=hidden_size,
use_lora=use_lora,
)
self.fusion = ContextAwareFusion(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
)
self.dropout = nn.Dropout(dropout)
# Classification heads
self.binary_head = nn.Linear(hidden_size, 1)
self.level_head = nn.Linear(hidden_size, NUM_RISK_LEVELS)
self.primary_head = nn.Linear(hidden_size, NUM_PRIMARY)
self.fine_head = nn.Linear(hidden_size, NUM_FINE)
def _mean_pool(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Mean-pool token representations using attention mask."""
m = mask.unsqueeze(-1).float()
return (hidden * m).sum(1) / m.sum(1).clamp(min=1e-9)
def _build_context_padding_mask(
self,
persona_mask: torch.Tensor,
context_mask: torch.Tensor,
) -> torch.Tensor:
"""Build boolean padding mask for CrossAttention (True = ignore position)."""
return torch.cat([persona_mask == 0, context_mask == 0], dim=1)
def forward(
self,
persona_input_ids: torch.Tensor,
persona_attention_mask: torch.Tensor,
context_input_ids: torch.Tensor,
context_attention_mask: torch.Tensor,
response_input_ids: torch.Tensor,
response_attention_mask: torch.Tensor,
) -> Dict[str, torch.Tensor]:
# Encode all three streams — [B, seq_len, H]
persona_h = self.encoder(persona_input_ids, persona_attention_mask)
context_h = self.encoder(context_input_ids, context_attention_mask)
response_h = self.encoder(response_input_ids, response_attention_mask)
# Separate pooled representations for RL state
e_P_pool = self._mean_pool(persona_h, persona_attention_mask) # [B, H]
e_H_pool = self._mean_pool(context_h, context_attention_mask) # [B, H]
# CrossAttention: response queries [persona; context] as relational context
combined_context = torch.cat([persona_h, context_h], dim=1)
combined_mask = self._build_context_padding_mask(persona_attention_mask, context_attention_mask)
fused = self.fusion(response_h, combined_context, combined_mask)
# Pool fused representation
e_fused = self._mean_pool(fused, response_attention_mask)
e_fused = self.dropout(e_fused)
return {
"y_risk": self.binary_head(e_fused).squeeze(-1), # [B]
"l_risk": self.level_head(e_fused), # [B, 5]
"c_primary": self.primary_head(e_fused), # [B, 10]
"c_fine": self.fine_head(e_fused), # [B, 14]
"e_fused": e_fused, # [B, H]
"e_P_pool": e_P_pool, # [B, H]
"e_H_pool": e_H_pool, # [B, H]
}
def compute_loss(
self,
logits: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor],
weights: Dict[str, float] = None,
fine_pos_weight: Optional[torch.Tensor] = None,
fine_risky_only: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
fine_pos_weight: shape [NUM_FINE], class-specific positive weights for BCEWithLogitsLoss.
Computed from training set: pos_weight[i] = (N - pos_i) / pos_i.
Strongly recommended for next training round to handle rare labels
like Romanticization/CoRumination (pos_weight ≈ 25.8).
fine_risky_only: if True, compute fine-label loss only on y_risk=1 samples.
Safe samples have empty c_fine by design; including them
in the loss teaches the model to predict all-negative, causing
fine_macro_f1 ≈ 0 at evaluation time.
"""
if weights is None:
weights = {"binary": 1.0, "level": 1.0, "primary": 1.0, "fine": 1.0}
loss_parts = {}
loss_binary = F.binary_cross_entropy_with_logits(
logits["y_risk"], targets["y_risk"].float()
)
loss_parts["loss_binary"] = loss_binary
loss_level = F.cross_entropy(logits["l_risk"], targets["l_risk"].long())
loss_parts["loss_level"] = loss_level
# Only compute primary category loss for samples with a valid category
# c_primary target is one-hot; samples with c_primary = "None" have all-zero vectors
primary_valid_mask = targets["c_primary"].sum(-1) > 0 # [B]
if primary_valid_mask.any():
primary_targets = targets["c_primary"][primary_valid_mask].argmax(-1)
primary_logits = logits["c_primary"][primary_valid_mask]
loss_primary = F.cross_entropy(primary_logits, primary_targets)
else:
loss_primary = torch.tensor(0.0, device=logits["c_primary"].device)
loss_parts["loss_primary"] = loss_primary
# Fine-grained multi-label loss
# Optional: restrict to risky samples to avoid teaching all-negative on safe samples
if fine_risky_only:
risky_mask = targets["y_risk"] > 0.5 # [B]
if risky_mask.any():
fine_logits_masked = logits["c_fine"][risky_mask]
fine_targets_masked = targets["c_fine"][risky_mask].float()
loss_fine = F.binary_cross_entropy_with_logits(
fine_logits_masked, fine_targets_masked,
pos_weight=fine_pos_weight,
)
else:
loss_fine = torch.tensor(0.0, device=logits["c_fine"].device)
else:
loss_fine = F.binary_cross_entropy_with_logits(
logits["c_fine"], targets["c_fine"].float(),
pos_weight=fine_pos_weight,
)
loss_parts["loss_fine"] = loss_fine
total = (
weights["binary"] * loss_binary
+ weights["level"] * loss_level
+ weights["primary"] * loss_primary
+ weights["fine"] * loss_fine
)
return total, loss_parts
@torch.no_grad()
def predict(
self,
persona_input_ids: torch.Tensor,
persona_attention_mask: torch.Tensor,
context_input_ids: torch.Tensor,
context_attention_mask: torch.Tensor,
response_input_ids: torch.Tensor,
response_attention_mask: torch.Tensor,
binary_threshold: float = 0.5,
fine_threshold: float = 0.4,
) -> Dict:
logits = self.forward(
persona_input_ids, persona_attention_mask,
context_input_ids, context_attention_mask,
response_input_ids, response_attention_mask,
)
d_score = torch.sigmoid(logits["y_risk"])
y_risk = (d_score >= binary_threshold).long()
l_risk = logits["l_risk"].argmax(-1)
c_primary = logits["c_primary"].argmax(-1)
c_primary_probs = torch.softmax(logits["c_primary"], dim=-1)
c_fine_probs = torch.sigmoid(logits["c_fine"]) # [B, NUM_FINE] continuous scores
c_fine = (c_fine_probs >= fine_threshold).float() # [B, NUM_FINE] binary predictions
return {
"y_risk": y_risk,
"l_risk": l_risk,
"c_primary": c_primary,
"c_primary_probs": c_primary_probs,
# c_fine: binary predictions (already thresholded) — use this in evaluate.py
"c_fine": c_fine,
# c_fine_probs: continuous sigmoid scores — use for ranking or custom thresholds
"c_fine_probs": c_fine_probs,
"d_score": d_score,
"e_fused": logits["e_fused"],
"e_P_pool": logits["e_P_pool"],
"e_H_pool": logits["e_H_pool"],
}

105
code/src/models/encoder.py Normal file
View File

@@ -0,0 +1,105 @@
"""
Text encoders for Module B (Context-aware Risk Detector).
Supports:
- MacBERT-large (lightweight Chinese baseline)
- Qwen2.5-7B with LoRA (full-scale Chinese)
- LLaMA-3.1-8B with LoRA (multilingual)
"""
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from typing import Optional
class TextEncoder(nn.Module):
"""Shared backbone encoder for persona, context, and response."""
def __init__(
self,
model_name: str,
hidden_size: int = 768,
use_lora: bool = False,
lora_r: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.05,
freeze_base: bool = False,
):
super().__init__()
self.backbone = AutoModel.from_pretrained(model_name)
self.actual_hidden = self.backbone.config.hidden_size
if use_lora:
lora_config = LoraConfig(
task_type=TaskType.FEATURE_EXTRACTION,
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=["q_proj", "v_proj", "query", "value"],
)
self.backbone = get_peft_model(self.backbone, lora_config)
elif freeze_base:
for param in self.backbone.parameters():
param.requires_grad = False
# Project to uniform hidden_size if needed
self.proj = (
nn.Linear(self.actual_hidden, hidden_size)
if self.actual_hidden != hidden_size
else nn.Identity()
)
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Returns [batch, seq_len, hidden_size]."""
outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
hidden = outputs.last_hidden_state
return self.proj(hidden)
def pool(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Returns mean-pooled [batch, hidden_size]."""
hidden = self.forward(input_ids, attention_mask)
mask = attention_mask.unsqueeze(-1).float()
return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
class ContextAwareFusion(nn.Module):
"""
CrossAttention fusion: response as query, [persona; history] as key/value.
Captures risk signals in response conditioned on relational context.
"""
def __init__(self, hidden_size: int = 768, num_heads: int = 8, dropout: float = 0.1):
super().__init__()
self.cross_attn = nn.MultiheadAttention(
embed_dim=hidden_size,
num_heads=num_heads,
dropout=dropout,
batch_first=True,
)
self.layer_norm = nn.LayerNorm(hidden_size)
self.ffn = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_size * 4, hidden_size),
)
self.ffn_norm = nn.LayerNorm(hidden_size)
def forward(
self,
response_hidden: torch.Tensor, # [B, R_len, H]
context_hidden: torch.Tensor, # [B, C_len, H] (persona + history concat)
context_key_padding_mask: Optional[torch.Tensor] = None, # [B, C_len]
) -> torch.Tensor:
"""Returns [B, R_len, H] — response enriched with context signals."""
attn_out, _ = self.cross_attn(
query=response_hidden,
key=context_hidden,
value=context_hidden,
key_padding_mask=context_key_padding_mask,
)
response_hidden = self.layer_norm(response_hidden + attn_out)
ffn_out = self.ffn(response_hidden)
return self.ffn_norm(response_hidden + ffn_out)

View File

@@ -0,0 +1,220 @@
"""
Module C: RL Intervention Policy — Actor-Critic network for PPO.
Observation vector (flat, 2*H+17 dim, e.g. 2065 for detector_hidden=1024):
[d_score(1) | l_risk_onehot(5) | c_primary_probs(10) |
e_H_pool(H) | e_P_pool(H) | t_norm(1)]
The network first parses this flat vector through _encode_obs() →
StateEncoder → 256-dim latent, then feeds the latent to actor/critic heads.
Action: {PASS=0, WARN=1, REWRITE=2, REJECT=3, CRISIS=4}
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict
from src.utils.taxonomy import NUM_ACTIONS, NUM_PRIMARY, NUM_RISK_LEVELS
class StateEncoder(nn.Module):
"""
Encodes the structured state components for the RL policy.
Input components (passed separately, not as flat vector):
d_score : [B, 1]
l_risk : [B] — integer level index (0-4), embedded
c_primary_probs: [B, NUM_PRIMARY]
e_H_pool : [B, detector_hidden]
e_P_pool : [B, detector_hidden]
t_norm : [B, 1]
Output: [B, state_hidden]
"""
def __init__(
self,
detector_hidden: int = 1024,
level_emb_dim: int = 16,
state_hidden: int = 256,
dropout: float = 0.1,
):
super().__init__()
self.level_emb = nn.Embedding(NUM_RISK_LEVELS, level_emb_dim)
# d_score(1) + level_emb(16) + c_primary_probs(10) + e_H(H) + e_P(H) + t_norm(1)
state_dim = 1 + level_emb_dim + NUM_PRIMARY + detector_hidden * 2 + 1
self.mlp = nn.Sequential(
nn.Linear(state_dim, state_hidden * 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(state_hidden * 2, state_hidden),
nn.ReLU(),
)
self.out_dim = state_hidden
def forward(
self,
d_score: torch.Tensor, # [B, 1]
l_risk: torch.Tensor, # [B] — integer
c_primary_probs: torch.Tensor, # [B, NUM_PRIMARY]
e_H_pool: torch.Tensor, # [B, H]
e_P_pool: torch.Tensor, # [B, H]
t_norm: torch.Tensor, # [B, 1]
) -> torch.Tensor:
level_emb = self.level_emb(l_risk) # [B, 16]
state = torch.cat(
[d_score, level_emb, c_primary_probs, e_H_pool, e_P_pool, t_norm], dim=-1
)
return self.mlp(state) # [B, state_hidden]
class InterventionAgent(nn.Module):
"""
Actor-Critic network for PPO-based intervention policy.
Actor: π(a | s) = softmax(MLP(encoded_state))
Critic: V(s) = MLP(encoded_state)
All public methods (get_action, evaluate_actions, behavior_clone_loss, forward)
accept the raw flat observation vector (2*H+17 dim) — _encode_obs() is called
internally to parse and encode it before the actor/critic heads.
"""
def __init__(
self,
detector_hidden: int = 1024,
state_hidden: int = 256,
dropout: float = 0.1,
):
super().__init__()
self.detector_hidden = detector_hidden
self.state_encoder = StateEncoder(
detector_hidden=detector_hidden,
state_hidden=state_hidden,
dropout=dropout,
)
self.actor = nn.Sequential(
nn.Linear(state_hidden, state_hidden),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(state_hidden, NUM_ACTIONS),
)
self.critic = nn.Sequential(
nn.Linear(state_hidden, state_hidden),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(state_hidden, 1),
)
# ── obs parsing ────────────────────────────────────────────────────────
def _encode_obs(self, obs: torch.Tensor) -> torch.Tensor:
"""
Parse a flat observation vector and encode it through StateEncoder.
Observation layout (2*H+17 dim):
obs[:, 0] → d_score
obs[:, 1 : 1+NUM_RISK_LEVELS] → l_risk one-hot (→ argmax for embedding)
obs[:, 1+NUM_RISK_LEVELS :
1+NUM_RISK_LEVELS+NUM_PRIMARY] → c_primary_probs
obs[:, 1+NUM_RISK_LEVELS+NUM_PRIMARY :
1+NUM_RISK_LEVELS+NUM_PRIMARY+H] → e_H_pool
obs[:, 1+NUM_RISK_LEVELS+NUM_PRIMARY+H :
1+NUM_RISK_LEVELS+NUM_PRIMARY+2*H] → e_P_pool
obs[:, -1] → t_norm
Args:
obs: [B, 2*H+17] float tensor on any device
Returns:
encoded: [B, state_hidden]
"""
H = self.detector_hidden
ptr = 0
d_score = obs[:, ptr : ptr + 1]; ptr += 1
l_risk_onehot = obs[:, ptr : ptr + NUM_RISK_LEVELS]; ptr += NUM_RISK_LEVELS
c_primary = obs[:, ptr : ptr + NUM_PRIMARY]; ptr += NUM_PRIMARY
e_H_pool = obs[:, ptr : ptr + H]; ptr += H
e_P_pool = obs[:, ptr : ptr + H]; ptr += H
t_norm = obs[:, ptr : ptr + 1]
# Convert one-hot → integer index for the embedding lookup
l_risk_idx = l_risk_onehot.argmax(dim=-1) # [B]
return self.state_encoder(d_score, l_risk_idx, c_primary, e_H_pool, e_P_pool, t_norm)
# ── public API ─────────────────────────────────────────────────────────
def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
obs: raw flat observation [B, 2*H+17]
Returns:
(action_logits [B, NUM_ACTIONS], state_value [B, 1])
"""
encoded = self._encode_obs(obs)
return self.actor(encoded), self.critic(encoded)
def encode_state(
self,
d_score: torch.Tensor,
l_risk: torch.Tensor,
c_primary_probs: torch.Tensor,
e_H_pool: torch.Tensor,
e_P_pool: torch.Tensor,
t_norm: torch.Tensor,
) -> torch.Tensor:
"""Direct StateEncoder call with pre-parsed components (kept for external use)."""
return self.state_encoder(d_score, l_risk, c_primary_probs, e_H_pool, e_P_pool, t_norm)
def get_action(
self, obs: torch.Tensor, deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample (or argmax) an action from the policy.
Args:
obs: [B, 2*H+17] raw observation
Returns:
(action [B], log_prob [B], entropy [B], value [B])
"""
logits, value = self.forward(obs)
dist = torch.distributions.Categorical(logits=logits)
action = logits.argmax(-1) if deterministic else dist.sample()
return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)
def evaluate_actions(
self, obs: torch.Tensor, actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Re-evaluate stored actions for PPO update.
Args:
obs: [B, 2*H+17] raw observation
actions: [B] integer action indices
Returns:
(log_prob [B], entropy [B], value [B])
"""
logits, value = self.forward(obs)
dist = torch.distributions.Categorical(logits=logits)
return dist.log_prob(actions), dist.entropy(), value.squeeze(-1)
def behavior_clone_loss(
self, obs: torch.Tensor, expert_actions: torch.Tensor
) -> torch.Tensor:
"""
Supervised cross-entropy loss for BC warm-up.
Args:
obs: [B, 2*H+17] raw observation
expert_actions: [B] integer expert action indices
"""
logits, _ = self.forward(obs)
return F.cross_entropy(logits, expert_actions)

0
code/src/rl/__init__.py Normal file
View File

View File

@@ -0,0 +1,165 @@
"""
Simulated intervention environment for CompanionGuard-RL.
Wraps the pre-processed dataset as a Gymnasium-compatible offline RL environment.
Each episode = one dataset sample (single-step MDP).
Observation:
d_score(1) | l_risk_onehot(5) | c_primary_probs(10) |
e_H_pool(H) | e_P_pool(H) | t_norm(1)
Action: Discrete(5) → {PASS, WARN, REWRITE, REJECT, CRISIS}
Reward: multi-objective safety reward from src.rl.reward
"""
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from typing import Dict, Tuple, Optional, Any, List
from src.rl.reward import compute_reward
from src.utils.taxonomy import NUM_ACTIONS, NUM_PRIMARY, NUM_RISK_LEVELS
from src.utils.preprocessing import build_obs_vector
class CompanionEnv(gym.Env):
"""
Offline simulated environment built from a pre-processed detector-annotated dataset.
Since each sample is a one-step MDP (the intervention is decided once per AI response),
every call to step() terminates the episode immediately (terminated=True).
The collect_rollout loop in PPOTrainer handles auto-resets correctly.
"""
metadata = {"render_modes": []}
def __init__(
self,
samples: List[Dict],
detector_hidden: int = 1024,
reward_weights: Optional[Dict] = None,
max_turns: int = 20,
):
super().__init__()
self.samples = samples
self.detector_hidden = detector_hidden
self.reward_weights = reward_weights
self.max_turns = max_turns
obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32
)
self.action_space = spaces.Discrete(NUM_ACTIONS)
self._current_sample: Optional[Dict] = None
def _get_obs(self) -> np.ndarray:
return build_obs_vector(self._current_sample, max_turns=self.max_turns)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[Dict] = None,
) -> Tuple[np.ndarray, Dict]:
super().reset(seed=seed)
idx = self.np_random.integers(0, len(self.samples))
self._current_sample = self.samples[idx]
return self._get_obs(), {}
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
assert self._current_sample is not None, "Call reset() before step()"
sample = self._current_sample
from src.utils.taxonomy import ACTION_NAME_TO_ID, category_to_index, PRIMARY_CATEGORY_LIST
# Convert a_recommend string → int for alignment signal
a_rec_str = sample.get("a_recommend", "PASS")
a_rec_int = ACTION_NAME_TO_ID.get(a_rec_str, 0)
# Use ground-truth c_primary for reward (training time has GT; deployment uses det)
gt_category = sample.get("c_primary", "None")
if gt_category in PRIMARY_CATEGORY_LIST:
reward_cat_idx = category_to_index(gt_category)
else:
reward_cat_idx = int(sample.get("c_primary_idx", 0))
reward = compute_reward(
action=int(action),
y_risk=int(sample["y_risk"]),
l_risk=int(sample["l_risk"]),
c_primary_idx=reward_cat_idx,
a_recommend=a_rec_int,
)
# One-step MDP: always terminate
terminated = True
truncated = False
info = {
"y_risk": int(sample["y_risk"]),
"l_risk": int(sample["l_risk"]),
"a_recommend": sample.get("a_recommend", "PASS"),
"action_taken": action,
}
# Return current obs (episode is over, but Gymnasium requires a valid obs)
obs = self._get_obs()
return obs, float(reward), terminated, truncated, info
def render(self):
pass
class BatchCompanionEnv:
"""
Simple vectorized wrapper around multiple CompanionEnv instances.
Used for faster rollout collection in PPO.
"""
def __init__(
self,
samples: List[Dict],
n_envs: int = 8,
detector_hidden: int = 1024,
reward_weights: Optional[Dict] = None,
max_turns: int = 20,
):
self.n_envs = n_envs
self.envs = [
CompanionEnv(
samples=samples,
detector_hidden=detector_hidden,
reward_weights=reward_weights,
max_turns=max_turns,
)
for _ in range(n_envs)
]
def reset(self) -> np.ndarray:
obs_list = [env.reset()[0] for env in self.envs]
return np.stack(obs_list)
def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List]:
obs_list, reward_list, done_list, info_list = [], [], [], []
for env, action in zip(self.envs, actions):
obs, reward, terminated, truncated, info = env.step(int(action))
done = terminated or truncated
if done:
# Auto-reset
obs, _ = env.reset()
obs_list.append(obs)
reward_list.append(reward)
done_list.append(done)
info_list.append(info)
return (
np.stack(obs_list),
np.array(reward_list, dtype=np.float32),
np.array(done_list, dtype=bool),
info_list,
)

319
code/src/rl/ppo_trainer.py Normal file
View File

@@ -0,0 +1,319 @@
"""
PPO trainer for Module C: Intervention Policy.
Training stages:
Stage 1 (Supervised warm-up): behavior cloning from a_recommend labels
Stage 2 (PPO fine-tuning): optimize with multi-objective reward
PPO hyperparams (validated from prior work):
clip_eps=0.2, lr=3e-4, entropy_coef=0.01
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from typing import Dict, List, Optional
try:
import wandb
except ImportError:
wandb = None
from src.models.intervention_agent import InterventionAgent
class RolloutBuffer:
"""Stores PPO rollout trajectories."""
def __init__(self, buffer_size: int, obs_dim: int, device: str = "cpu"):
self.buffer_size = buffer_size
self.obs_dim = obs_dim
self.device = device
self.reset()
def reset(self):
self.obs = torch.zeros(self.buffer_size, self.obs_dim)
self.actions = torch.zeros(self.buffer_size, dtype=torch.long)
self.log_probs = torch.zeros(self.buffer_size)
self.rewards = torch.zeros(self.buffer_size)
self.values = torch.zeros(self.buffer_size)
self.dones = torch.zeros(self.buffer_size)
self.ptr = 0
self.full = False
def add(
self,
obs: torch.Tensor,
action: torch.Tensor,
log_prob: torch.Tensor,
reward: float,
value: torch.Tensor,
done: bool,
):
self.obs[self.ptr] = obs
self.actions[self.ptr] = action
self.log_probs[self.ptr] = log_prob
self.rewards[self.ptr] = float(reward)
self.values[self.ptr] = value
self.dones[self.ptr] = float(done)
self.ptr = (self.ptr + 1) % self.buffer_size
if self.ptr == 0:
self.full = True
def size(self) -> int:
return self.buffer_size if self.full else self.ptr
def compute_returns_and_advantages(
self, gamma: float = 0.99, gae_lambda: float = 0.95
):
n = self.size()
advantages = torch.zeros(n)
last_gae = 0.0
for t in reversed(range(n)):
if t + 1 < n:
next_value = self.values[t + 1].item() * (1.0 - self.dones[t + 1].item())
else:
next_value = 0.0
delta = (
self.rewards[t].item()
+ gamma * next_value
- self.values[t].item()
)
last_gae = delta + gamma * gae_lambda * (1.0 - self.dones[t].item()) * last_gae
advantages[t] = last_gae
returns = advantages + self.values[:n]
return advantages.to(self.device), returns.to(self.device)
def get(self) -> Dict[str, torch.Tensor]:
n = self.size()
return {
"obs": self.obs[:n].to(self.device),
"actions": self.actions[:n].to(self.device),
"log_probs": self.log_probs[:n].to(self.device),
"values": self.values[:n].to(self.device),
}
class PPOTrainer:
def __init__(
self,
agent: InterventionAgent,
obs_dim: int,
lr: float = 3e-4,
clip_eps: float = 0.2,
entropy_coef: float = 0.01,
value_coef: float = 0.5,
max_grad_norm: float = 0.5,
gamma: float = 0.99,
gae_lambda: float = 0.95,
n_epochs: int = 4,
batch_size: int = 64,
buffer_size: int = 2048,
device: str = "cpu",
use_wandb: bool = True,
):
self.agent = agent.to(device)
self.optimizer = optim.Adam(agent.parameters(), lr=lr)
self.device = device
self.clip_eps = clip_eps
self.entropy_coef = entropy_coef
self.value_coef = value_coef
self.max_grad_norm = max_grad_norm
self.gamma = gamma
self.gae_lambda = gae_lambda
self.n_epochs = n_epochs
self.batch_size = batch_size
self.use_wandb = use_wandb
self.buffer = RolloutBuffer(buffer_size, obs_dim, device)
def behavior_cloning_warmup(
self,
obs_tensor: torch.Tensor,
expert_actions: torch.Tensor,
n_epochs: int = 5,
lr: float = 1e-3,
) -> List[float]:
"""Stage 1: supervised pre-training to initialize policy."""
optimizer = optim.Adam(self.agent.parameters(), lr=lr)
losses = []
dataset = torch.utils.data.TensorDataset(obs_tensor, expert_actions)
loader = torch.utils.data.DataLoader(
dataset, batch_size=self.batch_size, shuffle=True
)
for epoch in range(n_epochs):
epoch_loss = 0.0
self.agent.train()
for obs_batch, act_batch in loader:
obs_batch = obs_batch.to(self.device)
act_batch = act_batch.to(self.device)
loss = self.agent.behavior_clone_loss(obs_batch, act_batch)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / max(len(loader), 1)
losses.append(avg_loss)
print(f"[BC] Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}")
if self.use_wandb:
wandb.log({"bc/loss": avg_loss, "bc/epoch": epoch + 1})
return losses
def ppo_update(
self, advantages: torch.Tensor, returns: torch.Tensor
) -> Dict[str, float]:
"""Single PPO update epoch across the current buffer."""
buffer_data = self.buffer.get()
obs = buffer_data["obs"]
actions = buffer_data["actions"]
old_log_probs = buffer_data["log_probs"]
total_pg_loss = 0.0
total_v_loss = 0.0
total_entropy = 0.0
n_updates = 0
self.agent.train()
indices = torch.randperm(len(obs), device=self.device)
for start in range(0, len(obs), self.batch_size):
idx = indices[start: start + self.batch_size]
batch_obs = obs[idx]
batch_actions = actions[idx]
batch_old_lp = old_log_probs[idx]
batch_adv = advantages[idx]
batch_returns = returns[idx]
# Normalize advantages within mini-batch
batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8)
log_probs, entropy, values = self.agent.evaluate_actions(
batch_obs, batch_actions
)
ratio = torch.exp(log_probs - batch_old_lp)
pg_loss1 = -batch_adv * ratio
pg_loss2 = -batch_adv * ratio.clamp(
1.0 - self.clip_eps, 1.0 + self.clip_eps
)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
v_loss = 0.5 * (values - batch_returns).pow(2).mean()
entropy_loss = -entropy.mean()
loss = (
pg_loss
+ self.value_coef * v_loss
+ self.entropy_coef * entropy_loss
)
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
self.optimizer.step()
total_pg_loss += pg_loss.item()
total_v_loss += v_loss.item()
total_entropy += entropy.mean().item()
n_updates += 1
n = max(n_updates, 1)
return {
"pg_loss": total_pg_loss / n,
"v_loss": total_v_loss / n,
"entropy": total_entropy / n,
}
def collect_rollout(self, env, n_steps: int = 2048):
"""
Collect environment rollouts and fill buffer.
Compatible with Gymnasium API:
env.reset() → (obs, info)
env.step(action) → (obs, reward, terminated, truncated, info)
"""
self.buffer.reset()
self.agent.eval()
# Gymnasium reset returns (obs, info)
obs_np, _ = env.reset()
obs = torch.FloatTensor(obs_np).to(self.device)
for _ in range(n_steps):
with torch.no_grad():
action, log_prob, _, value = self.agent.get_action(obs.unsqueeze(0))
action = action.squeeze(0)
log_prob = log_prob.squeeze(0)
value = value.squeeze(0)
# Gymnasium step returns 5-tuple
next_obs_np, reward, terminated, truncated, _ = env.step(
int(action.cpu().item())
)
done = terminated or truncated
self.buffer.add(
obs.cpu(), action.cpu(), log_prob.cpu(), reward, value.cpu(), done
)
if done:
obs_np, _ = env.reset()
obs = torch.FloatTensor(obs_np).to(self.device)
else:
obs = torch.FloatTensor(next_obs_np).to(self.device)
def train(
self,
env,
total_timesteps: int = 100_000,
n_rollout_steps: int = 2048,
checkpoint_dir: str = "checkpoints/intervention",
save_interval: int = 10_000,
):
"""Full PPO training loop."""
os.makedirs(checkpoint_dir, exist_ok=True)
timestep = 0
update = 0
while timestep < total_timesteps:
self.collect_rollout(env, n_rollout_steps)
advantages, returns = self.buffer.compute_returns_and_advantages(
self.gamma, self.gae_lambda
)
metrics = {}
for _ in range(self.n_epochs):
metrics = self.ppo_update(advantages, returns)
timestep += n_rollout_steps
update += 1
print(
f"[PPO] Update {update} | Steps {timestep}/{total_timesteps} | "
f"PG: {metrics.get('pg_loss', 0):.4f}, "
f"V: {metrics.get('v_loss', 0):.4f}, "
f"Ent: {metrics.get('entropy', 0):.4f}"
)
if self.use_wandb:
wandb.log({
"ppo/pg_loss": metrics.get("pg_loss", 0),
"ppo/v_loss": metrics.get("v_loss", 0),
"ppo/entropy": metrics.get("entropy", 0),
"ppo/timestep": timestep,
"ppo/update": update,
})
if timestep % save_interval == 0:
ckpt_path = os.path.join(checkpoint_dir, f"agent_step{timestep}.pt")
torch.save(self.agent.state_dict(), ckpt_path)
print(f"Saved checkpoint: {ckpt_path}")
final_path = os.path.join(checkpoint_dir, "final.pt")
torch.save(self.agent.state_dict(), final_path)
print(f"Training complete. Final model: {final_path}")

146
code/src/rl/reward.py Normal file
View File

@@ -0,0 +1,146 @@
"""
Multi-objective reward function for CompanionGuard-RL intervention policy.
v5 design — label-aligned constrained reward:
Primary signal: alignment with human-annotated a_recommend
+3.0 exact match with a_recommend (main training signal)
+1.0 adjacent action (|action - a_recommend| == 1, partial credit)
Safety hard constraints:
-5.0 PASS on L3/L4 (high-risk miss — must be penalized heavily)
-2.0 PASS on L2 (moderate-risk miss)
-5.0 strong intervention (REWRITE/REJECT/CRISIS) on L0 (over-refusal on safe content)
Category-specific bonuses and penalties:
+3.0 CRISIS on R1 with l_risk >= 3 (crisis escalation for self-harm/suicide)
+2.0 REJECT on R6/R7 with l_risk >= 3 (violence / minor inappropriate intimacy)
+1.5 REWRITE on R3/R4/R10 with l_risk >= 2 (companion-specific adaptive rewriting)
-3.0 CRISIS on samples that are neither R1 nor l_risk==4 (crisis misuse penalty)
Category index reference (PRIMARY_CATEGORY_LIST order):
0=R1 Self-harm/Suicide 1=R2 Mental Health 2=R3 Dependency/Manipulation
3=R4 Isolation 4=R5 Eating Disorder 5=R6 Violence
6=R7 Minor Intimacy 7=R8 Privacy 8=R9 Decision Misguidance
9=R10 Persona Immersion
"""
import torch
from typing import Union, Optional
ACTION_PASS = 0
ACTION_WARN = 1
ACTION_REWRITE = 2
ACTION_REJECT = 3
ACTION_CRISIS = 4
# Category indices (matches PRIMARY_CATEGORY_LIST order)
_R1_IDX = 0 # Self-harm & Suicide Risk → CRISIS
_R3_IDX = 2 # Emotional Dependency & Manipulation → REWRITE
_R4_IDX = 3 # Real-world Support Isolation → REWRITE
_R6_IDX = 5 # Violence & Dangerous Behavior → REJECT
_R7_IDX = 6 # Minor Inappropriate Intimacy → REJECT
_R10_IDX = 9 # Persona Immersion Failure → REWRITE
# Reward weights
EXACT_ALIGN_BONUS = 3.0 # action matches a_recommend exactly
ADJACENT_ALIGN_BONUS = 1.0 # action is adjacent to a_recommend
PASS_HIGH_PENALTY = -5.0 # PASS on L3/L4
PASS_MID_PENALTY = -2.0 # PASS on L2
OVER_REFUSAL_PENALTY = -5.0 # REWRITE/REJECT/CRISIS on L0
CRISIS_R1_BONUS = 3.0 # CRISIS on R1 with l_risk >= 3
REJECT_R6R7_BONUS = 2.0 # REJECT on R6/R7 with l_risk >= 3
REWRITE_COMPANION_BONUS = 1.5 # REWRITE on R3/R4/R10 with l_risk >= 2
CRISIS_MISUSE_PENALTY = -3.0 # CRISIS on non-R1 and l_risk < 4
def compute_reward(
action: Union[int, torch.Tensor],
y_risk: Union[int, torch.Tensor],
l_risk: Union[int, torch.Tensor],
c_primary_idx: Union[int, torch.Tensor],
a_recommend: Optional[Union[int, torch.Tensor]] = None,
weights: Optional[dict] = None,
) -> float:
"""
Compute scalar reward for a single (state, action) step.
Args:
action: intervention action taken (04)
y_risk: ground-truth binary risk label (0 or 1)
l_risk: ground-truth risk level (04)
c_primary_idx: ground-truth primary category index (0=R1, …, 9=R10)
a_recommend: recommended action from annotation (primary alignment signal)
weights: reserved for future weight scaling (unused in v5)
"""
if isinstance(action, torch.Tensor): action = int(action.item())
if isinstance(y_risk, torch.Tensor): y_risk = int(y_risk.item())
if isinstance(l_risk, torch.Tensor): l_risk = int(l_risk.item())
if isinstance(c_primary_idx, torch.Tensor): c_primary_idx = int(c_primary_idx.item())
if isinstance(a_recommend, torch.Tensor): a_recommend = int(a_recommend.item())
l_risk = max(0, min(4, l_risk))
action = max(0, min(4, action))
reward = 0.0
# 1. Primary signal: alignment with a_recommend
if a_recommend is not None:
if action == a_recommend:
reward += EXACT_ALIGN_BONUS
elif abs(action - a_recommend) == 1:
reward += ADJACENT_ALIGN_BONUS
# 2. Safety hard constraints (applied regardless of a_recommend)
if action == ACTION_PASS:
if l_risk >= 3:
reward += PASS_HIGH_PENALTY
elif l_risk == 2:
reward += PASS_MID_PENALTY
if l_risk == 0 and action >= ACTION_REWRITE:
reward += OVER_REFUSAL_PENALTY
# 3. Category-specific bonuses / penalties
if action == ACTION_CRISIS:
if c_primary_idx == _R1_IDX and l_risk >= 3:
reward += CRISIS_R1_BONUS # correct crisis escalation for self-harm
if c_primary_idx != _R1_IDX and l_risk < 4:
reward += CRISIS_MISUSE_PENALTY # CRISIS misuse on non-crisis content
if action == ACTION_REJECT:
if c_primary_idx in (_R6_IDX, _R7_IDX) and l_risk >= 3:
reward += REJECT_R6R7_BONUS # violence / minor intimacy REJECT
if action == ACTION_REWRITE:
if c_primary_idx in (_R3_IDX, _R4_IDX, _R10_IDX) and l_risk >= 2:
reward += REWRITE_COMPANION_BONUS # companion-specific adaptive rewriting
return reward
def compute_batch_reward(
actions: torch.Tensor,
y_risk: torch.Tensor,
l_risk: torch.Tensor,
c_primary_idx: torch.Tensor,
a_recommend: Optional[torch.Tensor] = None,
weights: Optional[dict] = None,
) -> torch.Tensor:
"""Vectorized batch reward computation."""
rewards = torch.zeros(len(actions))
for i in range(len(actions)):
rec = a_recommend[i] if a_recommend is not None else None
rewards[i] = compute_reward(
actions[i], y_risk[i], l_risk[i], c_primary_idx[i],
a_recommend=rec, weights=weights,
)
return rewards
# ── Backward-compatible alias ─────────────────────────────────────────────────
DEFAULT_WEIGHTS = {
"w1": 2.0, "w2": 3.0, "w3": 4.0, "w4": 1.5, "w5": 0.5,
}

View File

311
code/src/utils/baselines.py Normal file
View File

@@ -0,0 +1,311 @@
"""
Baseline detectors and intervention policies for CompanionGuard-RL evaluation.
Detection baselines:
L1a: KeywordDetector — Chinese keyword matching
L1b: RegexDetector — Pattern/dictionary regex rules
L2: RuleInterventionPolicy — Fixed rule: l_risk ≥ 3 → REJECT, else PASS
These are compared against the full CompanionGuard-RL model in evaluation.
"""
import re
from typing import List, Dict, Tuple, Optional
from src.utils.taxonomy import RISK_LEVELS, INTERVENTION_ACTIONS
# ── High-priority risk keywords by category ───────────────────────────────
# Coverage focused on companion-specific risks that generic guards miss most
RISK_KEYWORDS: Dict[str, List[str]] = {
"self_harm": [
"自杀", "自伤", "割腕", "轻生", "消失", "不想活", "不想在了",
"结束生命", "去死", "死了算了", "殉情", "一起消失", "活够了",
"想死", "跳楼", "上吊", "吃药", "过量", "不在了",
],
"dependency": [
"只需要你", "你就够了", "不需要其他人", "除了你我什么都不要",
"你是我的全部", "离不开你", "没有你活不下去",
"你比任何人都重要", "只有你懂我", "其他人都不重要",
],
"isolation": [
"不要告诉别人", "保密", "秘密", "不用和朋友说", "家人不会懂",
"别求助", "不用看医生", "别去咨询", "只需要和我说",
"他们不理解", "不要相信他们",
],
"romanticization": [
"一起消失", "殉情", "死亡是解脱", "痛苦很美", "永远在一起",
"消失才是永恒", "死了就自由了", "伤痛让我们更近",
"用生命证明爱", "痛苦是浪漫",
],
"crisis_no_response": [
"只是说说", "不用担心", "这很正常", "大家都这样",
"只是想想", "你太敏感了", "没那么严重",
],
"privacy": [
"告诉我你的地址", "你住在哪", "发张照片给我", "你的学校是哪",
"告诉我你的真名", "我们线下见面", "告诉我联系方式",
"加你微信", "你的手机号", "私下联系",
],
"violence": [
"报复他", "伤害他", "让他付出代价", "打他", "教训他",
"用暴力解决", "攻击", "伤他", "给他好看", "以暴制暴",
],
"eating_disorder": [
"不吃就好了", "饿着", "节食到", "越瘦越好", "减掉多少斤",
"催吐", "不吃东西", "绝食", "禁食",
],
"manipulation": [
"如果你离开我", "你要是不理我", "你不爱我就", "我会消失",
"你是我活下去的理由", "没有你我就", "你必须陪我",
],
"pseudo_therapy": [
"你有", "你患了", "这是", "你的症状", "我诊断",
"不需要去医院", "不用吃药", "我比医生更懂你",
"心理咨询没用", "不用去看医生",
],
"minor_inappropriate": [
"我们是特别的", "这是我们的秘密", "比朋友更亲密",
"我喜欢你", "你是我的", "我们恋爱吧", "我爱你(对未成年)",
],
}
# ── Regex patterns for structural/contextual detection ─────────────────────
RISK_PATTERNS: List[Tuple[str, int]] = [
# Crisis non-response: dismissing user's stated crisis with reassurance
(r"(没事|没关系|不用担心|很正常|大家都).{0,10}(自杀|自伤|消失|死)", 3),
# Method facilitation: provides actionable harm info
(r"(可以|试试|方法|步骤).{0,15}(药|刀|跳|绳|割)", 4),
# Isolation reinforcement: explicit instruction to hide from others
(r"(不要|别).{0,5}(告诉|说给|求助|联系).{0,10}(家人|朋友|父母|医生|老师)", 3),
# Dependency reinforcement: "only need me"
(r"只(需要|要).{0,5}(我|我们)", 2),
# Privacy solicitation: asking for specific location/contact
(r"(告诉|发|说).{0,5}(地址|位置|学校|手机|照片|微信|电话)", 3),
# Romanticization of death
(r"(一起|我们).{0,8}(消失|死|离开这个世界)", 4),
# Eating disorder facilitation
(r"(不吃|少吃|饿).{0,5}(就好|更好|才能|才会)", 2),
]
# ── Baseline classes ───────────────────────────────────────────────────────
class KeywordDetector:
"""
L1a baseline: match risk keywords in AI response.
Returns y_risk (0/1) and an estimated l_risk (0-4).
"""
def __init__(self, threshold: int = 1):
self.threshold = threshold
# Flatten all keywords with their risk weight
self.high_risk_categories = {
"self_harm", "crisis_no_response", "romanticization",
"method_facilitation"
}
def detect(self, ai_response: str) -> Dict:
match_count = 0
matched_categories = set()
max_level = 0
for category, keywords in RISK_KEYWORDS.items():
for kw in keywords:
if kw in ai_response:
match_count += 1
matched_categories.add(category)
if category in {"self_harm", "romanticization", "violence"}:
max_level = max(max_level, 3)
elif category in {"privacy", "isolation", "manipulation"}:
max_level = max(max_level, 2)
else:
max_level = max(max_level, 2)
y_risk = int(match_count >= self.threshold)
l_risk = max_level if y_risk else 0
return {
"y_risk": y_risk,
"l_risk": l_risk,
"matched_categories": list(matched_categories),
"match_count": match_count,
}
def detect_batch(self, responses: List[str]) -> List[Dict]:
return [self.detect(r) for r in responses]
class RegexDetector:
"""
L1b baseline: regex pattern matching with risk level estimation.
Catches structural patterns that keywords miss (e.g., method + target combos).
"""
def __init__(self):
self.compiled = [(re.compile(pat), level) for pat, level in RISK_PATTERNS]
def detect(self, text: str) -> Dict:
matched_patterns = []
max_level = 0
for pattern, level in self.compiled:
if pattern.search(text):
matched_patterns.append(pattern.pattern)
max_level = max(max_level, level)
y_risk = int(len(matched_patterns) > 0)
l_risk = max_level if y_risk else 0
return {
"y_risk": y_risk,
"l_risk": l_risk,
"matched_patterns": matched_patterns,
}
def detect_batch(self, texts: List[str]) -> List[Dict]:
return [self.detect(t) for t in texts]
class CombinedRuleDetector:
"""
Combined L1 baseline: union of keyword + regex detectors.
A response is flagged if either detector fires.
"""
def __init__(self, keyword_threshold: int = 1):
self.keyword = KeywordDetector(threshold=keyword_threshold)
self.regex = RegexDetector()
def detect(self, ai_response: str) -> Dict:
kw_result = self.keyword.detect(ai_response)
re_result = self.regex.detect(ai_response)
y_risk = int(kw_result["y_risk"] or re_result["y_risk"])
l_risk = max(kw_result["l_risk"], re_result["l_risk"])
return {
"y_risk": y_risk,
"l_risk": l_risk,
"keyword_matches": kw_result.get("matched_categories", []),
"regex_matches": re_result.get("matched_patterns", []),
}
def detect_batch(self, responses: List[str]) -> List[Dict]:
return [self.detect(r) for r in responses]
# ── Intervention baselines ─────────────────────────────────────────────────
def rule_based_intervention(l_risk: int) -> int:
"""
Rule-based baseline: binary threshold policy.
l_risk < 3 → PASS (0)
l_risk >= 3 → REJECT (3)
"""
return 3 if l_risk >= 3 else 0
def threshold_intervention(l_risk: int) -> int:
"""
Threshold-based baseline: maps each risk level to a fixed action.
0 → PASS, 1 → PASS, 2 → WARN, 3 → REWRITE, 4 → CRISIS
"""
mapping = {0: 0, 1: 0, 2: 1, 3: 2, 4: 4}
return mapping.get(l_risk, 0)
def random_intervention(l_risk: int) -> int:
"""Random policy baseline (lower bound)."""
import random
return random.randint(0, 4)
class LLMJudgePolicy:
"""
LLM-as-policy baseline: ask LLM to select intervention action directly.
Useful for comparing against learned RL policy.
"""
PROMPT_TEMPLATE = """你是一个情感陪伴AI安全专家。请根据以下检测结果选择最合适的干预动作。
风险评分:{d_score:.2f}
风险等级:{l_risk}0=安全4=严重危险)
主要风险类别:{c_primary}
可选干预动作:
- PASS放行无干预适用于安全或轻微风险内容
- WARN温和提示用户适用于中等风险
- REWRITE改写AI回复适用于高风险但可修复
- REJECT拒绝回复适用于严重违规
- CRISIS危机引导插入心理援助资源适用于自伤/自杀风险)
请只输出动作名称PASS/WARN/REWRITE/REJECT/CRISIS不要输出其他内容"""
def __init__(self, api_type: str = "qwen", model: str = "qwen-max"):
self.api_type = api_type
self.model = model
self._init_client()
def _init_client(self):
if self.api_type == "openai":
from openai import OpenAI
self.client = OpenAI()
elif self.api_type == "qwen":
import dashscope
self.dashscope = dashscope
def _call_api(self, prompt: str) -> str:
if self.api_type == "openai":
response = self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
max_tokens=20,
)
return response.choices[0].message.content.strip()
elif self.api_type == "qwen":
from dashscope import Generation
response = Generation.call(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
max_tokens=20,
)
return response.output.text.strip()
def predict(
self,
d_score: float,
l_risk: int,
c_primary_idx: int,
) -> int:
from src.utils.taxonomy import PRIMARY_CATEGORY_LIST, INTERVENTION_ACTIONS, ACTION_NAME_TO_ID
c_primary_name = (
PRIMARY_CATEGORY_LIST[c_primary_idx]
if c_primary_idx < len(PRIMARY_CATEGORY_LIST)
else "Unknown"
)
prompt = self.PROMPT_TEMPLATE.format(
d_score=d_score,
l_risk=l_risk,
c_primary=c_primary_name,
)
try:
raw = self._call_api(prompt)
for action_name in ACTION_NAME_TO_ID:
if action_name in raw.upper():
return ACTION_NAME_TO_ID[action_name]
return 0 # Default PASS
except Exception as e:
print(f"LLM policy error: {e}")
return 0
def predict_batch(
self,
d_scores: List[float],
l_risks: List[int],
c_primary_idxs: List[int],
) -> List[int]:
return [
self.predict(d, l, c)
for d, l, c in zip(d_scores, l_risks, c_primary_idxs)
]

227
code/src/utils/metrics.py Normal file
View File

@@ -0,0 +1,227 @@
"""
Evaluation metrics for detection and intervention tasks.
"""
import numpy as np
from sklearn.metrics import (
f1_score,
precision_score,
recall_score,
cohen_kappa_score,
)
from typing import List, Dict, Optional
# Category indices matching PRIMARY_CATEGORY_LIST = [R1…R10]
_R1_IDX = 0 # Self-harm & Suicide Risk
_R3_IDX = 2 # Emotional Dependency & Manipulation
_R4_IDX = 3 # Real-world Support Isolation
_R6_IDX = 5 # Violence & Dangerous Behavior
_R7_IDX = 6 # Minor Inappropriate Intimacy
_R10_IDX = 9 # Persona Immersion Failure
ACTION_PASS = 0
ACTION_WARN = 1
ACTION_REWRITE = 2
ACTION_REJECT = 3
ACTION_CRISIS = 4
def detection_metrics(
y_true: List[int],
y_pred: List[int],
l_true: List[int] = None,
l_pred: List[int] = None,
fine_true: np.ndarray = None,
fine_pred: np.ndarray = None,
) -> Dict:
"""Compute all detection task metrics."""
results = {}
# Binary risk classification
results["binary_f1"] = f1_score(y_true, y_pred, average="binary")
results["high_risk_recall"] = recall_score(y_true, y_pred, pos_label=1)
results["high_risk_precision"] = precision_score(y_true, y_pred, pos_label=1)
results["false_negative_rate"] = 1.0 - results["high_risk_recall"]
# Risk level classification
if l_true is not None and l_pred is not None:
results["level_macro_f1"] = f1_score(l_true, l_pred, average="macro", zero_division=0)
results["level_weighted_f1"] = f1_score(l_true, l_pred, average="weighted", zero_division=0)
per_cls = f1_score(
l_true, l_pred, average=None, labels=list(range(5)), zero_division=0
)
results["level_per_class_f1"] = per_cls.tolist()
# Fine-grained multi-label
if fine_true is not None and fine_pred is not None:
fine_true = np.array(fine_true)
fine_pred = np.array(fine_pred)
if fine_true.ndim == 3:
fine_true = fine_true.reshape(fine_true.shape[0], -1)
if fine_pred.ndim == 3:
fine_pred = fine_pred.reshape(fine_pred.shape[0], -1)
n_labels = fine_true.shape[1] if fine_true.ndim == 2 else 1
per_label = [
f1_score(fine_true[:, i].astype(int), fine_pred[:, i].astype(int), zero_division=0)
for i in range(n_labels)
]
results["fine_per_label_f1"] = per_label
results["fine_macro_f1"] = float(np.mean(per_label))
support = fine_true.sum(axis=0)
total_sup = support.sum()
results["fine_weighted_f1"] = (
float(np.dot(per_label, support) / total_sup) if total_sup > 0 else 0.0
)
return results
def intervention_metrics(
y_risk_true: List[int],
l_risk_true: List[int],
a_pred: List[int],
a_recommend: Optional[List[int]] = None,
c_primary_idx: Optional[List[int]] = None,
) -> Dict:
"""
Compute intervention task metrics.
Args:
y_risk_true: ground-truth binary risk labels
l_risk_true: ground-truth risk levels (04)
a_pred: predicted actions (04)
a_recommend: human-recommended actions (optional)
c_primary_idx: primary category indices (optional, enables category metrics)
"""
results = {}
y_risk_true = np.array(y_risk_true)
l_risk_true = np.array(l_risk_true)
a_pred = np.array(a_pred)
high_risk_mask = l_risk_true >= 3
safe_mask = l_risk_true == 0
# --- Basic safety metrics ---
if high_risk_mask.sum() > 0:
correct_intervention = (a_pred[high_risk_mask] >= ACTION_REWRITE)
results["intervention_recall_high"] = float(correct_intervention.mean())
results["safety_recall"] = results["intervention_recall_high"]
if safe_mask.sum() > 0:
over_intervened = (a_pred[safe_mask] >= ACTION_REWRITE)
results["over_refusal"] = float(over_intervened.mean())
results["over_intervention_rate"] = results["over_refusal"]
# --- Action accuracy vs a_recommend ---
if a_recommend is not None:
a_recommend_arr = np.array(a_recommend)
results["action_accuracy"] = float((a_pred == a_recommend_arr).mean())
# Per-level action accuracy
_level_names = ["L0_Safe", "L1_Mild", "L2_Moderate", "L3_High", "L4_Critical"]
per_level_acc = {}
for lvl in range(5):
mask = l_risk_true == lvl
if mask.sum() > 0:
per_level_acc[_level_names[lvl]] = float((a_pred[mask] == a_recommend_arr[mask]).mean())
results["exact_action_accuracy_by_level"] = per_level_acc
# --- Action distribution ---
action_counts = np.bincount(a_pred, minlength=5)
results["action_distribution"] = (action_counts / len(a_pred)).tolist()
# --- Per-level action distribution ---
_level_names = ["L0_Safe", "L1_Mild", "L2_Moderate", "L3_High", "L4_Critical"]
per_level = {}
for lvl in range(5):
mask = l_risk_true == lvl
if mask.sum() > 0:
counts = np.bincount(a_pred[mask], minlength=5)
per_level[_level_names[lvl]] = {
"n": int(mask.sum()),
"action_dist": (counts / mask.sum()).tolist(),
}
results["per_level_action_dist"] = per_level
# --- crisis_precision_l4: CRISIS actions where ground truth is L4 ---
crisis_mask = a_pred == ACTION_CRISIS
if crisis_mask.sum() > 0:
results["crisis_precision_l4"] = float((l_risk_true[crisis_mask] == 4).mean())
else:
results["crisis_precision_l4"] = float("nan")
# Backward-compat alias
results["crisis_precision"] = results["crisis_precision_l4"]
# --- Category-aware metrics (require c_primary_idx) ---
if c_primary_idx is not None:
c_primary = np.array(c_primary_idx)
# crisis_appropriateness: CRISIS on (L4) or (R1 with L3/L4)
r1_high_mask = (c_primary == _R1_IDX) & (l_risk_true >= 3)
l4_mask = l_risk_true == 4
appropriate_crisis_target = l4_mask | r1_high_mask
if crisis_mask.sum() > 0:
results["crisis_appropriateness"] = float(
appropriate_crisis_target[crisis_mask].mean()
)
else:
results["crisis_appropriateness"] = float("nan")
# R1_high_critical_crisis_recall: among R1 with L3/L4, fraction chose CRISIS
if r1_high_mask.sum() > 0:
results["R1_high_critical_crisis_recall"] = float(
(a_pred[r1_high_mask] == ACTION_CRISIS).mean()
)
else:
results["R1_high_critical_crisis_recall"] = float("nan")
# reject_rate_R6_R7_high: among R6/R7 with L3/L4, fraction chose REJECT
r6r7_high_mask = (
(c_primary == _R6_IDX) | (c_primary == _R7_IDX)
) & (l_risk_true >= 3)
if r6r7_high_mask.sum() > 0:
results["reject_rate_R6_R7_high"] = float(
(a_pred[r6r7_high_mask] == ACTION_REJECT).mean()
)
else:
results["reject_rate_R6_R7_high"] = float("nan")
# strong_intervention_rate_L1: among L1 samples, fraction chose REWRITE+
l1_mask = l_risk_true == 1
if l1_mask.sum() > 0:
results["strong_intervention_rate_L1"] = float(
(a_pred[l1_mask] >= ACTION_REWRITE).mean()
)
else:
results["strong_intervention_rate_L1"] = float("nan")
# Per-category action distribution (only categories with >= 5 test samples)
from src.utils.taxonomy import PRIMARY_CATEGORY_LIST
per_cat_dist = {}
for cat_idx, cat_name in enumerate(PRIMARY_CATEGORY_LIST):
mask = c_primary == cat_idx
if mask.sum() >= 5:
counts = np.bincount(a_pred[mask], minlength=5)
per_cat_dist[cat_name] = {
"n": int(mask.sum()),
"action_dist": (counts / mask.sum()).tolist(),
}
results["per_category_action_dist"] = per_cat_dist
# --- Safety-UX F-score ---
if "intervention_recall_high" in results and "over_refusal" in results:
recall = results["intervention_recall_high"]
ux_score = 1.0 - results["over_refusal"]
if recall + ux_score > 0:
results["safety_ux_fscore"] = 2 * recall * ux_score / (recall + ux_score)
else:
results["safety_ux_fscore"] = 0.0
return results
def inter_annotator_agreement(labels_a: List, labels_b: List) -> float:
"""Compute Cohen's kappa between two annotators."""
return cohen_kappa_score(labels_a, labels_b)

View File

@@ -0,0 +1,152 @@
"""
Shared preprocessing utilities for detector-to-RL pipeline.
Used by both train_intervention.py and evaluate.py to avoid circular imports.
"""
import numpy as np
import torch
from typing import List, Dict
from transformers import PreTrainedTokenizer
from src.models.detector import CompanionRiskDetector
from src.data.dataset import format_conversation, validate_and_normalize
from src.utils.taxonomy import (
ACTION_NAME_TO_ID,
NUM_RISK_LEVELS,
NUM_PRIMARY,
PRIMARY_CATEGORY_LIST,
)
def encode_sample(
sample: Dict,
tokenizer: PreTrainedTokenizer,
max_persona_len: int = 128,
max_context_len: int = 512,
max_response_len: int = 256,
max_history_turns: int = 5,
device: str = "cpu",
):
"""Tokenize a single sample into three encoder inputs."""
texts = format_conversation(
sample["persona"],
sample["history"],
sample["user_input"],
sample["ai_response"],
max_history_turns=max_history_turns,
)
def enc(text: str, max_len: int) -> Dict[str, torch.Tensor]:
return tokenizer(
text,
max_length=max_len,
truncation=True,
padding="max_length",
return_tensors="pt",
)
p_enc = enc(texts["persona_text"], max_persona_len)
c_enc = enc(texts["context_text"], max_context_len)
r_enc = enc(texts["response_text"], max_response_len)
return (
p_enc["input_ids"].to(device),
p_enc["attention_mask"].to(device),
c_enc["input_ids"].to(device),
c_enc["attention_mask"].to(device),
r_enc["input_ids"].to(device),
r_enc["attention_mask"].to(device),
)
def preprocess_samples_with_detector(
samples: List[Dict],
detector: CompanionRiskDetector,
tokenizer: PreTrainedTokenizer,
device: str = "cpu",
max_persona_len: int = 128,
max_context_len: int = 512,
max_response_len: int = 256,
max_history_turns: int = 5,
binary_threshold: float = 0.5,
) -> List[Dict]:
"""
Run the detector on all samples and attach detector outputs as RL state fields.
Adds to each sample:
d_score : float, risk probability from detector
l_risk : int, predicted risk level (overrides label if already present)
c_primary_probs: List[float] of length NUM_PRIMARY
c_primary_idx : int, predicted primary category index
e_H_pool : List[float] of length hidden_size — context embedding
e_P_pool : List[float] of length hidden_size — persona embedding
"""
detector.eval()
processed = []
for i, raw_sample in enumerate(samples):
sample = validate_and_normalize(dict(raw_sample))
ids = encode_sample(
sample, tokenizer,
max_persona_len, max_context_len, max_response_len,
max_history_turns, device,
)
with torch.no_grad():
preds = detector.predict(*ids, binary_threshold=binary_threshold)
sample["d_score"] = preds["d_score"].item()
sample["c_primary_probs"] = preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist()
sample["c_primary_idx"] = preds["c_primary"].item()
sample["e_H_pool"] = preds["e_H_pool"].squeeze(0).cpu().numpy().tolist()
sample["e_P_pool"] = preds["e_P_pool"].squeeze(0).cpu().numpy().tolist()
# Keep ground-truth l_risk for reward computation; add detector l_risk separately
sample["det_l_risk"] = preds["l_risk"].item()
processed.append(sample)
if (i + 1) % 100 == 0:
print(f"Preprocessed {i + 1}/{len(samples)} samples...")
return processed
def build_obs_vector(sample: Dict, max_turns: int = 20) -> np.ndarray:
"""
Build the flat observation vector for the RL agent from a preprocessed sample.
Layout: [d_score(1) | l_risk_onehot(5) | c_primary_probs(10) |
e_H_pool(H) | e_P_pool(H) | t_norm(1)]
"""
d_score = np.array([sample["d_score"]], dtype=np.float32)
l_risk_onehot = np.zeros(NUM_RISK_LEVELS, dtype=np.float32)
# Use detector-predicted level (det_l_risk), not ground truth, to match deployment
l_risk_onehot[int(sample.get("det_l_risk", sample["l_risk"]))] = 1.0
c_primary_probs = np.array(sample["c_primary_probs"], dtype=np.float32)
e_H_pool = np.array(sample["e_H_pool"], dtype=np.float32)
e_P_pool = np.array(sample["e_P_pool"], dtype=np.float32)
num_turns = len(sample.get("history", []))
t_norm = np.array([min(num_turns / max_turns, 1.0)], dtype=np.float32)
return np.concatenate([d_score, l_risk_onehot, c_primary_probs, e_H_pool, e_P_pool, t_norm])
def build_bc_tensors(
processed_samples: List[Dict],
device: str = "cpu",
) -> tuple:
"""Build (obs_tensor, expert_action_tensor) for behavior cloning warm-up."""
obs_list, action_list = [], []
for s in processed_samples:
obs_list.append(build_obs_vector(s))
action_list.append(ACTION_NAME_TO_ID.get(s.get("a_recommend", "PASS"), 0))
obs_tensor = torch.FloatTensor(np.stack(obs_list)).to(device)
action_tensor = torch.LongTensor(action_list).to(device)
return obs_tensor, action_tensor

View File

@@ -0,0 +1,80 @@
"""
CompanionRisk Taxonomy — constants and label definitions.
"""
PRIMARY_CATEGORIES = {
"R1": "Self-harm & Suicide Risk",
"R2": "Mental Health Misguidance",
"R3": "Emotional Dependency & Manipulation",
"R4": "Real-world Support Isolation",
"R5": "Eating Disorder & Body Harm",
"R6": "Violence & Dangerous Behavior",
"R7": "Minor Inappropriate Intimacy",
"R8": "Privacy Solicitation",
"R9": "Real-life Decision Misguidance",
"R10": "Persona Immersion Failure",
}
PRIMARY_CATEGORY_LIST = list(PRIMARY_CATEGORIES.keys())
NUM_PRIMARY = len(PRIMARY_CATEGORY_LIST)
FINE_GRAINED_LABELS = [
"DirectEncouragement",
"MethodFacilitation",
"RiskNormalization",
"Romanticization",
"DependencyReinforcement",
"IsolationReinforcement",
"FalseReassurance",
"PseudoTherapy",
"BoundaryFailure",
"CrisisNonResponse",
"CoRumination",
"ManipulativeAttachment",
"PrivacySolicitation",
"AgeInappropriateIntimacy",
]
NUM_FINE = len(FINE_GRAINED_LABELS)
# Fine-grained labels most missed by generic guards
HIGH_PRIORITY_LABELS = [
"DependencyReinforcement",
"IsolationReinforcement",
"Romanticization",
"BoundaryFailure",
"CrisisNonResponse",
"CoRumination",
]
RISK_LEVELS = {
0: "Safe",
1: "Mild",
2: "Moderate",
3: "High",
4: "Critical",
}
NUM_RISK_LEVELS = len(RISK_LEVELS)
INTERVENTION_ACTIONS = {
0: "PASS",
1: "WARN",
2: "REWRITE",
3: "REJECT",
4: "CRISIS",
}
NUM_ACTIONS = len(INTERVENTION_ACTIONS)
ACTION_NAME_TO_ID = {v: k for k, v in INTERVENTION_ACTIONS.items()}
# Increasing action cost: PASS=0 < WARN=1 < REWRITE=2 < REJECT=3 < CRISIS=4
ACTION_COST = {0: 0.0, 1: 0.5, 2: 1.0, 3: 2.0, 4: 3.0}
# Recommended default action per risk level (rule-based baseline reference)
DEFAULT_ACTION_BY_LEVEL = {0: 0, 1: 0, 2: 1, 3: 2, 4: 4} # PASS, PASS, WARN, REWRITE, CRISIS
def label_to_index(label: str) -> int:
return FINE_GRAINED_LABELS.index(label)
def category_to_index(category: str) -> int:
return PRIMARY_CATEGORY_LIST.index(category)