chore: initial commit — unified project repo

Merged code repo (CompanionGuard-RL) into single project-level git.
Reorganized root: docs/, reference/, experiments/, tmp/active|archives/.
Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-14 11:28:42 +08:00
commit bd1f51c496
85 changed files with 20568 additions and 0 deletions

View File

218
code/src/models/detector.py Normal file
View File

@@ -0,0 +1,218 @@
"""
Module B: Context-aware Risk Detector.
Architecture:
1. Encode persona, context (history+user_input), response separately
2. Fuse via CrossAttention(response, [persona; context])
3. Multi-task classification heads:
- Binary risk (sigmoid)
- Risk level 0-4 (softmax)
- Primary category R1-R10 (softmax)
- Fine-grained 14-label (sigmoid multi-label)
Returns e_P_pool and e_H_pool for downstream RL state construction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from src.models.encoder import TextEncoder, ContextAwareFusion
from src.utils.taxonomy import NUM_PRIMARY, NUM_FINE, NUM_RISK_LEVELS
class CompanionRiskDetector(nn.Module):
    """Context-aware risk detector (Module B) with multi-task heads.

    Pipeline:
        1. Encode persona, context and response with one shared ``TextEncoder``.
        2. Fuse with ``ContextAwareFusion``: response tokens are the queries,
           the concatenation ``[persona; context]`` supplies keys/values.
        3. Mean-pool the fused response tokens and apply four linear heads
           (binary risk, risk level, primary category, fine-grained labels).

    The pooled persona / context embeddings are returned next to the logits
    so that the downstream RL policy can build its state from them.
    """

    def __init__(
        self,
        model_name: str = "hfl/chinese-macbert-large",
        hidden_size: int = 768,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_lora: bool = False,
    ):
        """
        Args:
            model_name: backbone identifier forwarded to ``TextEncoder``.
            hidden_size: width of the encoder output, the fusion layer and
                the input of every classification head.
            num_heads: number of attention heads in ``ContextAwareFusion``.
            dropout: dropout probability used by the fusion layer and on the
                pooled fused vector.
            use_lora: forwarded to ``TextEncoder``.
        """
        super().__init__()
        self.encoder = TextEncoder(
            model_name=model_name,
            hidden_size=hidden_size,
            use_lora=use_lora,
        )
        self.fusion = ContextAwareFusion(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(dropout)

        # Classification heads
        self.binary_head = nn.Linear(hidden_size, 1)
        self.level_head = nn.Linear(hidden_size, NUM_RISK_LEVELS)
        self.primary_head = nn.Linear(hidden_size, NUM_PRIMARY)
        self.fine_head = nn.Linear(hidden_size, NUM_FINE)

    def _mean_pool(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean-pool token representations over dim 1 using the attention mask.

        Positions where ``mask`` is 0 contribute nothing; the denominator is
        clamped so that an all-zero mask row yields zeros instead of NaN.
        """
        m = mask.unsqueeze(-1).float()
        return (hidden * m).sum(1) / m.sum(1).clamp(min=1e-9)

    def _build_context_padding_mask(
        self,
        persona_mask: torch.Tensor,
        context_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Build the boolean padding mask for cross-attention (True = ignore position).

        The two attention masks are inverted and concatenated along dim 1, in
        the same persona-then-context order used for the key/value sequence.
        """
        return torch.cat([persona_mask == 0, context_mask == 0], dim=1)

    def forward(
        self,
        persona_input_ids: torch.Tensor,
        persona_attention_mask: torch.Tensor,
        context_input_ids: torch.Tensor,
        context_attention_mask: torch.Tensor,
        response_input_ids: torch.Tensor,
        response_attention_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run the detector and return raw logits plus pooled embeddings.

        Returns:
            Dict with keys
              ``y_risk``    [B]                   binary-risk logit
              ``l_risk``    [B, NUM_RISK_LEVELS]  risk-level logits
              ``c_primary`` [B, NUM_PRIMARY]      primary-category logits
              ``c_fine``    [B, NUM_FINE]         fine-grained label logits
              ``e_fused``   [B, H]                pooled fused response (after dropout)
              ``e_P_pool``  [B, H]                pooled persona encoding
              ``e_H_pool``  [B, H]                pooled context encoding
        """
        # Encode all three streams — [B, seq_len, H]
        persona_h = self.encoder(persona_input_ids, persona_attention_mask)
        context_h = self.encoder(context_input_ids, context_attention_mask)
        response_h = self.encoder(response_input_ids, response_attention_mask)

        # Separate pooled representations for the RL state
        e_P_pool = self._mean_pool(persona_h, persona_attention_mask)  # [B, H]
        e_H_pool = self._mean_pool(context_h, context_attention_mask)  # [B, H]

        # Cross-attention: response queries [persona; context] as relational context
        combined_context = torch.cat([persona_h, context_h], dim=1)
        combined_mask = self._build_context_padding_mask(
            persona_attention_mask, context_attention_mask
        )
        fused = self.fusion(response_h, combined_context, combined_mask)

        # Pool the fused representation over the real response tokens
        e_fused = self._mean_pool(fused, response_attention_mask)
        e_fused = self.dropout(e_fused)

        return {
            "y_risk": self.binary_head(e_fused).squeeze(-1),
            "l_risk": self.level_head(e_fused),
            "c_primary": self.primary_head(e_fused),
            "c_fine": self.fine_head(e_fused),
            "e_fused": e_fused,
            "e_P_pool": e_P_pool,
            "e_H_pool": e_H_pool,
        }

    def compute_loss(
        self,
        logits: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        weights: Optional[Dict[str, float]] = None,
        fine_pos_weight: Optional[torch.Tensor] = None,
        fine_risky_only: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Weighted sum of the four task losses.

        Args:
            logits: output of ``forward`` (keys ``y_risk``, ``l_risk``,
                ``c_primary``, ``c_fine`` are read).
            targets: dict with the same four keys. ``c_primary`` is one-hot;
                an all-zero row means "no primary category" and is excluded
                from the primary loss.
            weights: per-task weights keyed by ``binary`` / ``level`` /
                ``primary`` / ``fine``. Any key that is omitted defaults to 1.0.
            fine_pos_weight: shape [NUM_FINE], class-specific positive weights for
                BCEWithLogitsLoss. Computed from the training set:
                pos_weight[i] = (N - pos_i) / pos_i. Recommended to handle rare
                labels like Romanticization/CoRumination (pos_weight ≈ 25.8).
                It is moved to the device of the fine logits automatically.
            fine_risky_only: if True, compute the fine-label loss only on
                y_risk=1 samples. Safe samples have empty c_fine by design;
                including them in the loss teaches the model to predict
                all-negative, causing fine_macro_f1 ≈ 0 at evaluation time.

        Returns:
            ``(total, loss_parts)`` where ``loss_parts`` holds the unweighted
            ``loss_binary``, ``loss_level``, ``loss_primary`` and ``loss_fine``.
        """
        # Start from the defaults so a partial ``weights`` dict cannot KeyError.
        task_weights = {"binary": 1.0, "level": 1.0, "primary": 1.0, "fine": 1.0}
        if weights is not None:
            task_weights.update(weights)

        loss_parts = {}

        loss_binary = F.binary_cross_entropy_with_logits(
            logits["y_risk"], targets["y_risk"].float()
        )
        loss_parts["loss_binary"] = loss_binary

        loss_level = F.cross_entropy(logits["l_risk"], targets["l_risk"].long())
        loss_parts["loss_level"] = loss_level

        # Only compute the primary-category loss for samples with a valid category:
        # the target is one-hot, samples with c_primary = "None" are all-zero vectors.
        primary_valid_mask = targets["c_primary"].sum(-1) > 0  # [B]
        if primary_valid_mask.any():
            primary_targets = targets["c_primary"][primary_valid_mask].argmax(-1)
            primary_logits = logits["c_primary"][primary_valid_mask]
            loss_primary = F.cross_entropy(primary_logits, primary_targets)
        else:
            loss_primary = torch.tensor(0.0, device=logits["c_primary"].device)
        loss_parts["loss_primary"] = loss_primary

        # Fine-grained multi-label loss, optionally restricted to risky samples.
        fine_logits = logits["c_fine"]
        fine_targets = targets["c_fine"].float()
        if fine_pos_weight is not None:
            # A pos_weight built on CPU would otherwise fail against GPU logits.
            fine_pos_weight = fine_pos_weight.to(fine_logits.device)
        if fine_risky_only:
            risky_mask = targets["y_risk"] > 0.5  # [B]
            fine_logits = fine_logits[risky_mask]
            fine_targets = fine_targets[risky_mask]

        if fine_risky_only and fine_logits.shape[0] == 0:
            # No risky sample in this batch: BCE over an empty tensor would be NaN.
            loss_fine = torch.tensor(0.0, device=logits["c_fine"].device)
        else:
            loss_fine = F.binary_cross_entropy_with_logits(
                fine_logits, fine_targets, pos_weight=fine_pos_weight
            )
        loss_parts["loss_fine"] = loss_fine

        total = (
            task_weights["binary"] * loss_binary
            + task_weights["level"] * loss_level
            + task_weights["primary"] * loss_primary
            + task_weights["fine"] * loss_fine
        )
        return total, loss_parts

    @torch.no_grad()
    def predict(
        self,
        persona_input_ids: torch.Tensor,
        persona_attention_mask: torch.Tensor,
        context_input_ids: torch.Tensor,
        context_attention_mask: torch.Tensor,
        response_input_ids: torch.Tensor,
        response_attention_mask: torch.Tensor,
        binary_threshold: float = 0.5,
        fine_threshold: float = 0.4,
    ) -> Dict:
        """Inference helper: run ``forward`` and turn logits into decisions.

        The forward pass is executed in eval mode (dropout disabled) so the
        result is deterministic even when the caller forgot ``.eval()``; the
        training flag of every sub-module is restored afterwards.

        Args:
            binary_threshold: ``d_score >= binary_threshold`` → ``y_risk = 1``.
            fine_threshold: per-label sigmoid threshold for ``c_fine``.

        Returns:
            Dict with thresholded predictions (``y_risk``, ``l_risk``,
            ``c_primary``, ``c_fine``), continuous scores (``d_score``,
            ``c_primary_probs``, ``c_fine_probs``) and the pooled embeddings
            (``e_fused``, ``e_P_pool``, ``e_H_pool``).
        """
        # Remember each module's own flag: sub-modules may deliberately differ
        # from the root (e.g. a backbone kept in eval mode during training).
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self.forward(
                persona_input_ids, persona_attention_mask,
                context_input_ids, context_attention_mask,
                response_input_ids, response_attention_mask,
            )
        finally:
            for module, mode in saved_modes:
                module.training = mode

        d_score = torch.sigmoid(logits["y_risk"])
        y_risk = (d_score >= binary_threshold).long()
        l_risk = logits["l_risk"].argmax(-1)
        c_primary = logits["c_primary"].argmax(-1)
        c_primary_probs = torch.softmax(logits["c_primary"], dim=-1)
        c_fine_probs = torch.sigmoid(logits["c_fine"])  # [B, NUM_FINE] continuous scores
        c_fine = (c_fine_probs >= fine_threshold).float()  # [B, NUM_FINE] binary predictions

        return {
            "y_risk": y_risk,
            "l_risk": l_risk,
            "c_primary": c_primary,
            "c_primary_probs": c_primary_probs,
            # c_fine: binary predictions (already thresholded) — use this in evaluate.py
            "c_fine": c_fine,
            # c_fine_probs: continuous sigmoid scores — use for ranking or custom thresholds
            "c_fine_probs": c_fine_probs,
            "d_score": d_score,
            "e_fused": logits["e_fused"],
            "e_P_pool": logits["e_P_pool"],
            "e_H_pool": logits["e_H_pool"],
        }

105
code/src/models/encoder.py Normal file
View File

@@ -0,0 +1,105 @@
"""
Text encoders for Module B (Context-aware Risk Detector).
Supports:
- MacBERT-large (lightweight Chinese baseline)
- Qwen2.5-7B with LoRA (full-scale Chinese)
- LLaMA-3.1-8B with LoRA (multilingual)
"""
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from typing import Optional
class TextEncoder(nn.Module):
    """Single backbone shared by the persona, context and response streams.

    Wraps a Hugging Face ``AutoModel`` and, when the backbone width differs
    from ``hidden_size``, a linear projection to ``hidden_size``.
    """

    def __init__(
        self,
        model_name: str,
        hidden_size: int = 768,
        use_lora: bool = False,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        freeze_base: bool = False,
    ):
        """
        Args:
            model_name: identifier passed to ``AutoModel.from_pretrained``.
            hidden_size: output width; a projection is added if the backbone differs.
            use_lora: wrap the backbone with a PEFT LoRA adapter.
            lora_r / lora_alpha / lora_dropout: LoRA hyper-parameters.
            freeze_base: freeze every backbone parameter; only consulted
                when ``use_lora`` is False.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        # Width reported by the un-wrapped backbone's config.
        self.actual_hidden = self.backbone.config.hidden_size

        if use_lora:
            adapter_cfg = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=["q_proj", "v_proj", "query", "value"],
            )
            self.backbone = get_peft_model(self.backbone, adapter_cfg)
        elif freeze_base:
            for weight in self.backbone.parameters():
                weight.requires_grad = False

        # Bring the backbone output to the uniform hidden_size when necessary.
        if self.actual_hidden != hidden_size:
            self.proj = nn.Linear(self.actual_hidden, hidden_size)
        else:
            self.proj = nn.Identity()

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the projected last hidden state, [batch, seq_len, hidden_size]."""
        backbone_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        return self.proj(backbone_out.last_hidden_state)

    def pool(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the attention-masked mean over tokens, [batch, hidden_size]."""
        token_states = self.forward(input_ids, attention_mask)
        keep = attention_mask.unsqueeze(-1).float()
        summed = (token_states * keep).sum(1)
        # Clamp so a fully-masked row divides by a tiny number instead of zero.
        counts = keep.sum(1).clamp(min=1e-9)
        return summed / counts
class ContextAwareFusion(nn.Module):
    """
    Cross-attention fusion block: the response is the query and the
    concatenated [persona; history] sequence provides keys and values.

    Layout: cross-attention + residual + LayerNorm, followed by a
    4x-expansion GELU feed-forward + residual + LayerNorm.
    """

    def __init__(self, hidden_size: int = 768, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

        inner_dim = hidden_size * 4
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, hidden_size),
        )
        self.ffn_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        response_hidden: torch.Tensor,  # [B, R_len, H]
        context_hidden: torch.Tensor,  # [B, C_len, H] (persona + history concat)
        context_key_padding_mask: Optional[torch.Tensor] = None,  # [B, C_len]
    ) -> torch.Tensor:
        """Return the context-enriched response, [B, R_len, H].

        ``context_key_padding_mask`` is handed to the attention layer as its
        ``key_padding_mask`` (True marks a context position to ignore).
        """
        attended = self.cross_attn(
            response_hidden,
            context_hidden,
            context_hidden,
            key_padding_mask=context_key_padding_mask,
        )[0]  # attention weights are not needed
        after_attn = self.layer_norm(response_hidden + attended)
        return self.ffn_norm(after_attn + self.ffn(after_attn))

View File

@@ -0,0 +1,220 @@
"""
Module C: RL Intervention Policy — Actor-Critic network for PPO.
Observation vector (flat, 2*H+17 dim, e.g. 2065 for detector_hidden=1024):
[d_score(1) | l_risk_onehot(5) | c_primary_probs(10) |
e_H_pool(H) | e_P_pool(H) | t_norm(1)]
The network first parses this flat vector through _encode_obs() →
StateEncoder → 256-dim latent, then feeds the latent to actor/critic heads.
Action: {PASS=0, WARN=1, REWRITE=2, REJECT=3, CRISIS=4}
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict
from src.utils.taxonomy import NUM_ACTIONS, NUM_PRIMARY, NUM_RISK_LEVELS
class StateEncoder(nn.Module):
    """
    Turns the structured RL state into a single latent vector.

    The components are supplied separately (not as one flat vector):
        d_score        : [B, 1]
        l_risk         : [B] — integer level index, looked up in an embedding
        c_primary_probs: [B, NUM_PRIMARY]
        e_H_pool       : [B, detector_hidden]
        e_P_pool       : [B, detector_hidden]
        t_norm         : [B, 1]
    They are concatenated on the last dimension and passed through a
    two-layer ReLU MLP.

    Output: [B, state_hidden]
    """

    def __init__(
        self,
        detector_hidden: int = 1024,
        level_emb_dim: int = 16,
        state_hidden: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.level_emb = nn.Embedding(NUM_RISK_LEVELS, level_emb_dim)

        # Width of the concatenated state, in concatenation order:
        # d_score | level embedding | c_primary_probs | e_H | e_P | t_norm
        pooled_width = detector_hidden * 2
        state_dim = 1 + level_emb_dim + NUM_PRIMARY + pooled_width + 1

        wide = state_hidden * 2
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, wide),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(wide, state_hidden),
            nn.ReLU(),
        )
        self.out_dim = state_hidden

    def forward(
        self,
        d_score: torch.Tensor,  # [B, 1]
        l_risk: torch.Tensor,  # [B] — integer
        c_primary_probs: torch.Tensor,  # [B, NUM_PRIMARY]
        e_H_pool: torch.Tensor,  # [B, H]
        e_P_pool: torch.Tensor,  # [B, H]
        t_norm: torch.Tensor,  # [B, 1]
    ) -> torch.Tensor:
        """Embed the level index, concatenate all parts, return ``mlp(state)``."""
        pieces = [
            d_score,
            self.level_emb(l_risk),  # [B, level_emb_dim]
            c_primary_probs,
            e_H_pool,
            e_P_pool,
            t_norm,
        ]
        flat_state = torch.cat(pieces, dim=-1)
        return self.mlp(flat_state)  # [B, state_hidden]
class InterventionAgent(nn.Module):
    """
    Actor-Critic network for the PPO-based intervention policy.

    Actor:  π(a | s) = softmax(MLP(encoded_state))
    Critic: V(s)     = MLP(encoded_state)

    All public methods (get_action, evaluate_actions, behavior_clone_loss,
    forward) accept the raw flat observation vector (2*H+17 dim) —
    _encode_obs() is called internally to parse and encode it before the
    actor/critic heads. ``H`` is the ``detector_hidden`` given at construction.
    """

    def __init__(
        self,
        detector_hidden: int = 1024,
        state_hidden: int = 256,
        dropout: float = 0.1,
    ):
        """
        Args:
            detector_hidden: width H of the pooled detector embeddings
                contained in the observation.
            state_hidden: width of the encoded state and of the hidden layer
                in the actor / critic heads.
            dropout: dropout probability for the state encoder and both heads.
        """
        super().__init__()
        self.detector_hidden = detector_hidden
        self.state_encoder = StateEncoder(
            detector_hidden=detector_hidden,
            state_hidden=state_hidden,
            dropout=dropout,
        )
        self.actor = nn.Sequential(
            nn.Linear(state_hidden, state_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(state_hidden, NUM_ACTIONS),
        )
        self.critic = nn.Sequential(
            nn.Linear(state_hidden, state_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(state_hidden, 1),
        )

    # ── obs parsing ────────────────────────────────────────────────────────
    def _encode_obs(self, obs: torch.Tensor) -> torch.Tensor:
        """
        Parse a flat observation vector and encode it through StateEncoder.

        Observation layout, in column order:
            d_score          (1)
            l_risk one-hot   (NUM_RISK_LEVELS)  → argmax for the embedding lookup
            c_primary_probs  (NUM_PRIMARY)
            e_H_pool         (H)
            e_P_pool         (H)
            t_norm           (1)                → always the last column

        Args:
            obs: [B, 2*H+17] float tensor on any device
        Returns:
            encoded: [B, state_hidden]
        Raises:
            ValueError: if ``obs`` is not 2-D or its width does not match the
                layout. Without this check an over-wide observation would be
                parsed silently with ``t_norm`` read from the wrong column,
                and a too-narrow one would fail later with an opaque
                shape error inside the MLP.
        """
        H = self.detector_hidden
        field_sizes = [1, NUM_RISK_LEVELS, NUM_PRIMARY, H, H, 1]
        expected_width = sum(field_sizes)
        if obs.dim() != 2 or obs.shape[1] != expected_width:
            raise ValueError(
                f"obs must have shape [B, {expected_width}] for "
                f"detector_hidden={H}, got {tuple(obs.shape)}"
            )

        d_score, l_risk_onehot, c_primary, e_H_pool, e_P_pool, t_norm = torch.split(
            obs, field_sizes, dim=1
        )
        # Convert one-hot → integer index for the embedding lookup
        l_risk_idx = l_risk_onehot.argmax(dim=-1)  # [B]
        return self.state_encoder(d_score, l_risk_idx, c_primary, e_H_pool, e_P_pool, t_norm)

    # ── public API ─────────────────────────────────────────────────────────
    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            obs: raw flat observation [B, 2*H+17]
        Returns:
            (action_logits [B, NUM_ACTIONS], state_value [B, 1])
        """
        encoded = self._encode_obs(obs)
        return self.actor(encoded), self.critic(encoded)

    def encode_state(
        self,
        d_score: torch.Tensor,
        l_risk: torch.Tensor,
        c_primary_probs: torch.Tensor,
        e_H_pool: torch.Tensor,
        e_P_pool: torch.Tensor,
        t_norm: torch.Tensor,
    ) -> torch.Tensor:
        """Direct StateEncoder call with pre-parsed components (kept for external use)."""
        return self.state_encoder(d_score, l_risk, c_primary_probs, e_H_pool, e_P_pool, t_norm)

    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample (or argmax) an action from the policy.

        Args:
            obs: [B, 2*H+17] raw observation
            deterministic: take the argmax of the logits instead of sampling
        Returns:
            (action [B], log_prob [B], entropy [B], value [B])
        """
        logits, value = self.forward(obs)
        dist = torch.distributions.Categorical(logits=logits)
        action = logits.argmax(-1) if deterministic else dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)

    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Re-evaluate stored actions for the PPO update.

        Args:
            obs: [B, 2*H+17] raw observation
            actions: [B] integer action indices
        Returns:
            (log_prob [B], entropy [B], value [B])
        """
        logits, value = self.forward(obs)
        dist = torch.distributions.Categorical(logits=logits)
        return dist.log_prob(actions), dist.entropy(), value.squeeze(-1)

    def behavior_clone_loss(
        self, obs: torch.Tensor, expert_actions: torch.Tensor
    ) -> torch.Tensor:
        """
        Supervised cross-entropy loss for BC warm-up.

        Args:
            obs: [B, 2*H+17] raw observation
            expert_actions: [B] integer expert action indices
        """
        logits, _ = self.forward(obs)
        return F.cross_entropy(logits, expert_actions)