chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
0
code/src/models/__init__.py
Normal file
0
code/src/models/__init__.py
Normal file
218
code/src/models/detector.py
Normal file
218
code/src/models/detector.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Module B: Context-aware Risk Detector.
|
||||
|
||||
Architecture:
|
||||
1. Encode persona, context (history+user_input), response separately
|
||||
2. Fuse via CrossAttention(response, [persona; context])
|
||||
3. Multi-task classification heads:
|
||||
- Binary risk (sigmoid)
|
||||
- Risk level 0-4 (softmax)
|
||||
- Primary category R1-R10 (softmax)
|
||||
- Fine-grained 14-label (sigmoid multi-label)
|
||||
|
||||
Returns e_P_pool and e_H_pool for downstream RL state construction.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple, Optional
|
||||
|
||||
from src.models.encoder import TextEncoder, ContextAwareFusion
|
||||
from src.utils.taxonomy import NUM_PRIMARY, NUM_FINE, NUM_RISK_LEVELS
|
||||
|
||||
|
||||
class CompanionRiskDetector(nn.Module):
    """Multi-task risk detector that reads a response in light of persona and context.

    Persona, context and response are encoded separately by one shared
    ``TextEncoder``.  The response tokens then attend over the concatenated
    persona+context tokens through ``ContextAwareFusion``, and the mean-pooled
    fused representation feeds four linear heads (binary risk, risk level,
    primary category, fine-grained labels).

    Args:
        model_name: backbone identifier forwarded to ``TextEncoder``.
        hidden_size: width shared by the encoder output, the fusion layer and the heads.
        num_heads: number of attention heads in the fusion layer.
        dropout: dropout probability for the fusion layer and the pooled representation.
        use_lora: forwarded to ``TextEncoder``.
    """

    def __init__(
        self,
        model_name: str = "hfl/chinese-macbert-large",
        hidden_size: int = 768,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_lora: bool = False,
    ):
        super().__init__()

        # One encoder instance is shared by all three text streams.
        self.encoder = TextEncoder(
            model_name=model_name,
            hidden_size=hidden_size,
            use_lora=use_lora,
        )

        self.fusion = ContextAwareFusion(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
        )

        self.dropout = nn.Dropout(dropout)

        # Classification heads, all reading the same pooled fused vector.
        self.binary_head = nn.Linear(hidden_size, 1)
        self.level_head = nn.Linear(hidden_size, NUM_RISK_LEVELS)
        self.primary_head = nn.Linear(hidden_size, NUM_PRIMARY)
        self.fine_head = nn.Linear(hidden_size, NUM_FINE)

    def _mean_pool(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Average token vectors over the positions where ``mask`` is non-zero.

        Args:
            hidden: token representations, [B, seq_len, H].
            mask: attention mask, [B, seq_len]; non-zero marks a real token.

        Returns:
            [B, H] pooled vectors.  A row whose mask is all zero pools to
            zeros because the denominator is clamped to 1e-9.
        """
        m = mask.unsqueeze(-1).float()
        return (hidden * m).sum(1) / m.sum(1).clamp(min=1e-9)

    def _build_context_padding_mask(
        self,
        persona_mask: torch.Tensor,
        context_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Build the boolean key-padding mask for cross-attention (True = ignore).

        The two attention masks are inverted and concatenated along the
        sequence axis, in the same persona-then-context order used for the
        key/value tensor in ``forward``.
        """
        return torch.cat([persona_mask == 0, context_mask == 0], dim=1)

    def forward(
        self,
        persona_input_ids: torch.Tensor,
        persona_attention_mask: torch.Tensor,
        context_input_ids: torch.Tensor,
        context_attention_mask: torch.Tensor,
        response_input_ids: torch.Tensor,
        response_attention_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Encode the three streams, fuse them and return raw head outputs.

        Returns:
            Dict with un-normalised logits ``y_risk`` [B], ``l_risk``,
            ``c_primary`` and ``c_fine`` (one row per sample), the pooled fused
            vector ``e_fused`` [B, H] (after dropout), and the pooled persona
            and context vectors ``e_P_pool`` / ``e_H_pool`` [B, H].
        """
        # Encode all three streams: each is [B, seq_len, H].
        persona_h = self.encoder(persona_input_ids, persona_attention_mask)
        context_h = self.encoder(context_input_ids, context_attention_mask)
        response_h = self.encoder(response_input_ids, response_attention_mask)

        # Separate pooled representations, exposed for RL state construction.
        e_P_pool = self._mean_pool(persona_h, persona_attention_mask)  # [B, H]
        e_H_pool = self._mean_pool(context_h, context_attention_mask)  # [B, H]

        # Cross-attention: response queries [persona; context].
        combined_context = torch.cat([persona_h, context_h], dim=1)
        combined_mask = self._build_context_padding_mask(
            persona_attention_mask, context_attention_mask
        )
        fused = self.fusion(response_h, combined_context, combined_mask)

        # Pool over real response tokens only, then regularise.
        e_fused = self._mean_pool(fused, response_attention_mask)
        e_fused = self.dropout(e_fused)

        return {
            "y_risk": self.binary_head(e_fused).squeeze(-1),  # [B]
            "l_risk": self.level_head(e_fused),               # [B, NUM_RISK_LEVELS]
            "c_primary": self.primary_head(e_fused),          # [B, NUM_PRIMARY]
            "c_fine": self.fine_head(e_fused),                # [B, NUM_FINE]
            "e_fused": e_fused,                               # [B, H]
            "e_P_pool": e_P_pool,                             # [B, H]
            "e_H_pool": e_H_pool,                             # [B, H]
        }

    def compute_loss(
        self,
        logits: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        weights: Optional[Dict[str, float]] = None,
        fine_pos_weight: Optional[torch.Tensor] = None,
        fine_risky_only: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the weighted multi-task loss.

        Args:
            logits: output of ``forward`` (keys ``y_risk``, ``l_risk``,
                ``c_primary``, ``c_fine`` are read).
            targets: dict with the same four keys.  ``c_primary`` is one-hot;
                an all-zero row means "no category" and is excluded from the
                primary loss.
            weights: per-task multipliers keyed by ``binary``, ``level``,
                ``primary``, ``fine``.  Missing keys default to 1.0, so a
                partial dict is accepted.
            fine_pos_weight: shape [NUM_FINE], class-specific positive weights for
                BCEWithLogitsLoss, typically ``(N - pos_i) / pos_i`` from the
                training set.  It is moved to the device/dtype of the
                fine-grained logits before use.
            fine_risky_only: if True, compute the fine-label loss only on
                samples with ``y_risk`` = 1.  Safe samples have empty ``c_fine``
                by design; including them pushes the model toward all-negative
                predictions.

        Returns:
            ``(total, parts)`` where ``parts`` holds the unweighted
            ``loss_binary``, ``loss_level``, ``loss_primary`` and ``loss_fine``.
        """
        task_weights = {"binary": 1.0, "level": 1.0, "primary": 1.0, "fine": 1.0}
        if weights:
            task_weights.update(weights)

        loss_parts = {}

        loss_binary = F.binary_cross_entropy_with_logits(
            logits["y_risk"], targets["y_risk"].float()
        )
        loss_parts["loss_binary"] = loss_binary

        loss_level = F.cross_entropy(logits["l_risk"], targets["l_risk"].long())
        loss_parts["loss_level"] = loss_level

        # Primary category: only samples that actually carry a category.
        primary_target = targets["c_primary"]
        primary_valid_mask = primary_target.sum(-1) > 0  # [B]
        if primary_valid_mask.any():
            loss_primary = F.cross_entropy(
                logits["c_primary"][primary_valid_mask],
                primary_target[primary_valid_mask].argmax(-1),
            )
        else:
            # Nothing to learn from in this batch; contribute a constant zero.
            loss_primary = logits["c_primary"].new_zeros(())
        loss_parts["loss_primary"] = loss_primary

        # Fine-grained multi-label loss, optionally restricted to risky samples.
        fine_logits = logits["c_fine"]
        fine_targets = targets["c_fine"].float()
        if fine_risky_only:
            risky_mask = targets["y_risk"] > 0.5  # [B]
            fine_logits = fine_logits[risky_mask]
            fine_targets = fine_targets[risky_mask]

        if fine_logits.shape[0] == 0:
            # No sample selected: a mean over nothing would be NaN.
            loss_fine = logits["c_fine"].new_zeros(())
        else:
            pos_weight = fine_pos_weight
            if pos_weight is not None:
                # A pos_weight built on CPU from dataset statistics would
                # otherwise clash with logits that live on an accelerator.
                pos_weight = pos_weight.to(
                    device=fine_logits.device, dtype=fine_logits.dtype
                )
            loss_fine = F.binary_cross_entropy_with_logits(
                fine_logits, fine_targets, pos_weight=pos_weight
            )
        loss_parts["loss_fine"] = loss_fine

        total = (
            task_weights["binary"] * loss_binary
            + task_weights["level"] * loss_level
            + task_weights["primary"] * loss_primary
            + task_weights["fine"] * loss_fine
        )

        return total, loss_parts

    @torch.no_grad()
    def predict(
        self,
        persona_input_ids: torch.Tensor,
        persona_attention_mask: torch.Tensor,
        context_input_ids: torch.Tensor,
        context_attention_mask: torch.Tensor,
        response_input_ids: torch.Tensor,
        response_attention_mask: torch.Tensor,
        binary_threshold: float = 0.5,
        fine_threshold: float = 0.4,
    ) -> Dict:
        """Run inference and turn head outputs into decisions and scores.

        The model is switched to eval mode for the duration of the call so
        dropout cannot randomise predictions; every submodule's previous
        train/eval flag is restored afterwards.

        Args:
            binary_threshold: ``d_score`` at or above this value is labelled risky.
            fine_threshold: fine-label probability at or above this value is predicted.

        Returns:
            Dict with ``y_risk`` (0/1), ``l_risk`` and ``c_primary`` (argmax
            indices), ``c_primary_probs`` (softmax), ``c_fine`` (thresholded
            0/1 floats), ``c_fine_probs`` (sigmoid scores), ``d_score``
            (sigmoid of the binary logit) and the pooled embeddings
            ``e_fused``, ``e_P_pool``, ``e_H_pool``.
        """
        # Record per-module flags rather than a single bool so a submodule the
        # caller deliberately keeps in eval mode is not flipped back to train.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            logits = self(
                persona_input_ids, persona_attention_mask,
                context_input_ids, context_attention_mask,
                response_input_ids, response_attention_mask,
            )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        d_score = torch.sigmoid(logits["y_risk"])
        y_risk = (d_score >= binary_threshold).long()
        l_risk = logits["l_risk"].argmax(-1)
        c_primary = logits["c_primary"].argmax(-1)
        c_primary_probs = torch.softmax(logits["c_primary"], dim=-1)
        c_fine_probs = torch.sigmoid(logits["c_fine"])      # [B, NUM_FINE] continuous scores
        c_fine = (c_fine_probs >= fine_threshold).float()   # [B, NUM_FINE] binary predictions

        return {
            "y_risk": y_risk,
            "l_risk": l_risk,
            "c_primary": c_primary,
            "c_primary_probs": c_primary_probs,
            # c_fine: binary predictions (already thresholded), use this in evaluate.py
            "c_fine": c_fine,
            # c_fine_probs: continuous sigmoid scores, use for ranking or custom thresholds
            "c_fine_probs": c_fine_probs,
            "d_score": d_score,
            "e_fused": logits["e_fused"],
            "e_P_pool": logits["e_P_pool"],
            "e_H_pool": logits["e_H_pool"],
        }
|
||||
105
code/src/models/encoder.py
Normal file
105
code/src/models/encoder.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Text encoders for Module B (Context-aware Risk Detector).
|
||||
|
||||
Supports:
|
||||
- MacBERT-large (lightweight Chinese baseline)
|
||||
- Qwen2.5-7B with LoRA (full-scale Chinese)
|
||||
- LLaMA-3.1-8B with LoRA (multilingual)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from peft import get_peft_model, LoraConfig, TaskType
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
    """Backbone wrapper used for the persona, context and response streams.

    Loads a pretrained model by name, optionally wraps it with LoRA adapters
    or freezes its parameters, and projects its hidden states to
    ``hidden_size`` when the backbone's native width differs.

    Args:
        model_name: identifier passed to ``AutoModel.from_pretrained``.
        hidden_size: output width of ``forward``.
        use_lora: wrap the backbone with a PEFT LoRA adapter.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: dropout applied inside the LoRA layers.
        freeze_base: disable gradients on the backbone; only consulted when
            ``use_lora`` is False.
    """

    def __init__(
        self,
        model_name: str,
        hidden_size: int = 768,
        use_lora: bool = False,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        freeze_base: bool = False,
    ):
        super().__init__()

        backbone = AutoModel.from_pretrained(model_name)
        # Native width is read before any adapter wrapping.
        self.actual_hidden = backbone.config.hidden_size

        if use_lora:
            adapter_cfg = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                # Lists both the q_proj/v_proj and query/value naming schemes.
                target_modules=["q_proj", "v_proj", "query", "value"],
            )
            backbone = get_peft_model(backbone, adapter_cfg)
        elif freeze_base:
            for weight in backbone.parameters():
                weight.requires_grad = False

        self.backbone = backbone

        # Project to the uniform hidden_size only when the widths differ.
        if self.actual_hidden == hidden_size:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(self.actual_hidden, hidden_size)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return per-token representations of shape [batch, seq_len, hidden_size]."""
        backbone_out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        return self.proj(backbone_out.last_hidden_state)

    def pool(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return mask-weighted mean of the token representations, [batch, hidden_size]."""
        token_states = self.forward(input_ids, attention_mask)
        token_weights = attention_mask.unsqueeze(-1).float()
        weighted_sum = (token_states * token_weights).sum(1)
        # Clamp so a fully masked row yields zeros instead of dividing by zero.
        token_count = token_weights.sum(1).clamp(min=1e-9)
        return weighted_sum / token_count
|
||||
|
||||
|
||||
class ContextAwareFusion(nn.Module):
    """Cross-attention fusion: response as query, [persona; history] as key/value.

    One post-norm transformer-style block: multi-head cross-attention with a
    residual connection and LayerNorm, followed by a GELU feed-forward
    network with its own residual connection and LayerNorm.

    Args:
        hidden_size: embedding width of queries, keys and values.
        num_heads: number of attention heads.
        dropout: dropout probability inside the attention and the feed-forward network.
    """

    def __init__(self, hidden_size: int = 768, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.ffn_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        response_hidden: torch.Tensor,   # [B, R_len, H]
        context_hidden: torch.Tensor,    # [B, C_len, H] (persona + history concat)
        context_key_padding_mask: Optional[torch.Tensor] = None,  # [B, C_len]
    ) -> torch.Tensor:
        """Return [B, R_len, H]: the response enriched with context signals.

        Args:
            response_hidden: query tokens.
            context_hidden: key/value tokens.
            context_key_padding_mask: True marks a key position to ignore.

        A batch row whose boolean mask ignores *every* key has nothing to
        attend to; its attention contribution is set to zero so that row
        passes through the residual path only, instead of producing NaN.
        """
        no_context = None
        if (
            context_key_padding_mask is not None
            and context_key_padding_mask.dtype == torch.bool
        ):
            no_context = context_key_padding_mask.all(dim=1)  # [B]
            # Un-mask fully padded rows so the softmax sees finite scores;
            # their attention output is discarded just below.
            context_key_padding_mask = context_key_padding_mask & ~no_context.unsqueeze(1)

        attn_out, _ = self.cross_attn(
            query=response_hidden,
            key=context_hidden,
            value=context_hidden,
            key_padding_mask=context_key_padding_mask,
        )
        if no_context is not None:
            attn_out = attn_out.masked_fill(no_context.view(-1, 1, 1), 0.0)

        response_hidden = self.layer_norm(response_hidden + attn_out)
        ffn_out = self.ffn(response_hidden)
        return self.ffn_norm(response_hidden + ffn_out)
|
||||
220
code/src/models/intervention_agent.py
Normal file
220
code/src/models/intervention_agent.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Module C: RL Intervention Policy — Actor-Critic network for PPO.
|
||||
|
||||
Observation vector (flat, 2*H+17 dim, e.g. 2065 for detector_hidden=1024):
|
||||
[d_score(1) | l_risk_onehot(5) | c_primary_probs(10) |
|
||||
e_H_pool(H) | e_P_pool(H) | t_norm(1)]
|
||||
|
||||
The network first parses this flat vector through _encode_obs() →
|
||||
StateEncoder → 256-dim latent, then feeds the latent to actor/critic heads.
|
||||
|
||||
Action: {PASS=0, WARN=1, REWRITE=2, REJECT=3, CRISIS=4}
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple, Dict
|
||||
|
||||
from src.utils.taxonomy import NUM_ACTIONS, NUM_PRIMARY, NUM_RISK_LEVELS
|
||||
|
||||
|
||||
class StateEncoder(nn.Module):
    """Encode the structured state components for the RL policy.

    The risk level is embedded, concatenated with the remaining components
    and passed through a two-layer ReLU MLP.

    Input components (passed separately, not as a flat vector):
        d_score        : [B, 1] or [B]
        l_risk         : [B], level index, embedded
        c_primary_probs: [B, NUM_PRIMARY]
        e_H_pool       : [B, detector_hidden]
        e_P_pool       : [B, detector_hidden]
        t_norm         : [B, 1] or [B]
    Output: [B, state_hidden]

    Args:
        detector_hidden: width of each pooled detector embedding.
        level_emb_dim: width of the risk-level embedding.
        state_hidden: output width (also stored as ``out_dim``).
        dropout: dropout probability after the first MLP layer.
    """

    def __init__(
        self,
        detector_hidden: int = 1024,
        level_emb_dim: int = 16,
        state_hidden: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.level_emb = nn.Embedding(NUM_RISK_LEVELS, level_emb_dim)

        # d_score(1) + level_emb + c_primary_probs + e_H(H) + e_P(H) + t_norm(1)
        state_dim = 1 + level_emb_dim + NUM_PRIMARY + detector_hidden * 2 + 1

        self.mlp = nn.Sequential(
            nn.Linear(state_dim, state_hidden * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(state_hidden * 2, state_hidden),
            nn.ReLU(),
        )
        self.out_dim = state_hidden

    def forward(
        self,
        d_score: torch.Tensor,          # [B, 1] or [B]
        l_risk: torch.Tensor,           # [B], level index
        c_primary_probs: torch.Tensor,  # [B, NUM_PRIMARY]
        e_H_pool: torch.Tensor,         # [B, H]
        e_P_pool: torch.Tensor,         # [B, H]
        t_norm: torch.Tensor,           # [B, 1] or [B]
    ) -> torch.Tensor:
        """Return the encoded state, [B, state_hidden].

        Scalar-per-sample inputs may be given as [B]; they are expanded to
        [B, 1] so that per-sample scores of shape [B] (the shape
        ``CompanionRiskDetector.predict`` produces for ``d_score``) can be
        passed without reshaping.  ``l_risk`` is cast to an integer dtype
        before the embedding lookup.
        """
        if d_score.dim() == 1:
            d_score = d_score.unsqueeze(-1)
        if t_norm.dim() == 1:
            t_norm = t_norm.unsqueeze(-1)

        level_emb = self.level_emb(l_risk.long())  # [B, level_emb_dim]
        state = torch.cat(
            [d_score, level_emb, c_primary_probs, e_H_pool, e_P_pool, t_norm], dim=-1
        )
        return self.mlp(state)  # [B, state_hidden]
|
||||
|
||||
|
||||
class InterventionAgent(nn.Module):
    """Actor-critic network for the PPO-based intervention policy.

    Actor:  pi(a | s) = softmax(MLP(encoded_state))
    Critic: V(s)      = MLP(encoded_state)

    ``forward``, ``get_action``, ``evaluate_actions`` and
    ``behavior_clone_loss`` accept the raw flat observation vector
    (2*H+17 dim per the module docstring); ``_encode_obs`` parses and encodes
    it before the actor/critic heads.

    Args:
        detector_hidden: width H of each pooled detector embedding in the observation.
        state_hidden: width of the encoded state and of the hidden layers in both heads.
        dropout: dropout probability in the state encoder and both heads.
    """

    def __init__(
        self,
        detector_hidden: int = 1024,
        state_hidden: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.detector_hidden = detector_hidden

        self.state_encoder = StateEncoder(
            detector_hidden=detector_hidden,
            state_hidden=state_hidden,
            dropout=dropout,
        )

        self.actor = nn.Sequential(
            nn.Linear(state_hidden, state_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(state_hidden, NUM_ACTIONS),
        )

        self.critic = nn.Sequential(
            nn.Linear(state_hidden, state_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(state_hidden, 1),
        )

    # ── obs parsing ────────────────────────────────────────────────────────

    def _encode_obs(self, obs: torch.Tensor) -> torch.Tensor:
        """Parse a flat observation vector and encode it through StateEncoder.

        Observation layout along the last axis, in order:
            d_score          (1)
            l_risk one-hot   (NUM_RISK_LEVELS), argmax gives the embedding index
            c_primary_probs  (NUM_PRIMARY)
            e_H_pool         (H)
            e_P_pool         (H)
            t_norm           (1)

        Args:
            obs: [B, 1 + NUM_RISK_LEVELS + NUM_PRIMARY + 2*H + 1] float tensor.

        Returns:
            encoded: [B, state_hidden]

        Raises:
            ValueError: if ``obs`` is not 2-D or its width does not match the
                layout for this agent's ``detector_hidden``.  Without this
                check an observation that is too wide is sliced at the wrong
                offsets without any error.
        """
        H = self.detector_hidden
        field_sizes = [1, NUM_RISK_LEVELS, NUM_PRIMARY, H, H, 1]
        expected_width = sum(field_sizes)
        if obs.dim() != 2 or obs.shape[-1] != expected_width:
            raise ValueError(
                f"expected observation of shape [B, {expected_width}] "
                f"for detector_hidden={H}, got {tuple(obs.shape)}"
            )

        d_score, l_risk_onehot, c_primary, e_H_pool, e_P_pool, t_norm = torch.split(
            obs, field_sizes, dim=-1
        )

        # Convert one-hot to an integer index for the embedding lookup.
        l_risk_idx = l_risk_onehot.argmax(dim=-1)  # [B]

        return self.state_encoder(d_score, l_risk_idx, c_primary, e_H_pool, e_P_pool, t_norm)

    # ── public API ─────────────────────────────────────────────────────────

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both heads on one batch of observations.

        Args:
            obs: raw flat observation [B, 2*H+17]
        Returns:
            (action_logits [B, NUM_ACTIONS], state_value [B, 1])
        """
        encoded = self._encode_obs(obs)
        return self.actor(encoded), self.critic(encoded)

    def encode_state(
        self,
        d_score: torch.Tensor,
        l_risk: torch.Tensor,
        c_primary_probs: torch.Tensor,
        e_H_pool: torch.Tensor,
        e_P_pool: torch.Tensor,
        t_norm: torch.Tensor,
    ) -> torch.Tensor:
        """Direct StateEncoder call with pre-parsed components (kept for external use)."""
        return self.state_encoder(d_score, l_risk, c_primary_probs, e_H_pool, e_P_pool, t_norm)

    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample (or argmax) an action from the policy.

        Args:
            obs: [B, 2*H+17] raw observation
            deterministic: take the highest-logit action instead of sampling.
        Returns:
            (action [B], log_prob [B], entropy [B], value [B])
        """
        logits, value = self(obs)
        dist = torch.distributions.Categorical(logits=logits)
        action = logits.argmax(-1) if deterministic else dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)

    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Re-evaluate stored actions for the PPO update.

        Args:
            obs:     [B, 2*H+17] raw observation
            actions: [B] integer action indices
        Returns:
            (log_prob [B], entropy [B], value [B])
        """
        logits, value = self(obs)
        dist = torch.distributions.Categorical(logits=logits)
        return dist.log_prob(actions), dist.entropy(), value.squeeze(-1)

    def behavior_clone_loss(
        self, obs: torch.Tensor, expert_actions: torch.Tensor
    ) -> torch.Tensor:
        """Supervised cross-entropy loss for behaviour-cloning warm-up.

        Args:
            obs:            [B, 2*H+17] raw observation
            expert_actions: [B] integer expert action indices
        """
        logits, _ = self(obs)
        return F.cross_entropy(logits, expert_actions)
|
||||
Reference in New Issue
Block a user