""" Step 4: Train Module C — RL Intervention Policy (PPO). Two-stage training: Stage 1: Behavior cloning warm-up from a_recommend labels Stage 2: PPO fine-tuning with multi-objective reward Usage: python scripts/train_intervention.py --config configs/intervention_config.yaml """ import argparse import yaml import torch import numpy as np import wandb from pathlib import Path from src.data.dataset import load_jsonl from src.models.detector import CompanionRiskDetector from src.models.intervention_agent import InterventionAgent from src.rl.companion_env import CompanionEnv from src.rl.ppo_trainer import PPOTrainer from src.utils.taxonomy import ( ACTION_NAME_TO_ID, NUM_RISK_LEVELS, NUM_PRIMARY, category_to_index, ) from transformers import AutoTokenizer def preprocess_samples_with_detector(samples, detector, tokenizer, cfg, device): """Run detector on all samples to extract state vectors for RL env.""" from src.data.dataset import format_conversation processed = [] detector.eval() for sample in samples: texts = format_conversation( sample["persona"], sample["history"], sample["user_input"], sample["ai_response"], ) def enc(text, max_len): return tokenizer( text, max_length=max_len, truncation=True, padding="max_length", return_tensors="pt", ) p_enc = enc(texts["persona_text"], 128) c_enc = enc(texts["context_text"], 512) r_enc = enc(texts["response_text"], 256) with torch.no_grad(): preds = detector.predict( p_enc["input_ids"].to(device), p_enc["attention_mask"].to(device), c_enc["input_ids"].to(device), c_enc["attention_mask"].to(device), r_enc["input_ids"].to(device), r_enc["attention_mask"].to(device), ) # Build persona/history pool embeddings (reuse e_fused as approximation) e_fused = preds["e_fused"].squeeze(0).cpu().numpy() processed.append({ **sample, "d_score": preds["d_score"].item(), "l_risk": preds["l_risk"].item(), "c_primary_probs": preds["c_primary_probs"].squeeze(0).cpu().numpy().tolist(), "c_primary_idx": preds["c_primary"].item(), "e_H_pool": e_fused.tolist(), "e_P_pool": e_fused.tolist(), "a_recommend": sample.get("a_recommend", "PASS"), }) return processed def build_bc_tensors(processed_samples, obs_dim, device): """Build observation and expert action tensors for behavior cloning.""" obs_list, action_list = [], [] for s in processed_samples: d_score = np.array([s["d_score"]], dtype=np.float32) l_risk_oh = np.zeros(NUM_RISK_LEVELS, dtype=np.float32) l_risk_oh[int(s["l_risk"])] = 1.0 c_probs = np.array(s["c_primary_probs"], dtype=np.float32) e_H = np.array(s["e_H_pool"], dtype=np.float32) e_P = np.array(s["e_P_pool"], dtype=np.float32) t_norm = np.array([len(s.get("history", [])) / 20.0], dtype=np.float32) obs = np.concatenate([d_score, l_risk_oh, c_probs, e_H, e_P, t_norm]) obs_list.append(obs) action_list.append(ACTION_NAME_TO_ID.get(s["a_recommend"], 0)) obs_tensor = torch.FloatTensor(np.stack(obs_list)).to(device) action_tensor = torch.LongTensor(action_list).to(device) return obs_tensor, action_tensor def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", default="configs/intervention_config.yaml") parser.add_argument("--train-data", default="data/processed/train.jsonl") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") if cfg["logging"]["use_wandb"]: wandb.init( project=cfg["logging"]["project"], name=cfg["logging"]["run_name"], config=cfg, ) # Load detector tokenizer = AutoTokenizer.from_pretrained(cfg["detector"]["model_name"]) detector = CompanionRiskDetector( model_name=cfg["detector"]["model_name"], hidden_size=cfg["detector"]["hidden_size"], ).to(device) detector.load_state_dict(torch.load(cfg["detector"]["checkpoint"], map_location=device)) detector.eval() print("Detector loaded.") # Load and preprocess training data raw_samples = load_jsonl(args.train_data) print(f"Preprocessing {len(raw_samples)} samples with detector...") processed = preprocess_samples_with_detector(raw_samples, detector, tokenizer, cfg, device) detector_hidden = cfg["detector"]["hidden_size"] obs_dim = 1 + NUM_RISK_LEVELS + NUM_PRIMARY + detector_hidden * 2 + 1 # Build RL agent agent = InterventionAgent( detector_hidden=detector_hidden, state_hidden=cfg["agent"]["state_hidden"], dropout=cfg["agent"]["dropout"], ) trainer = PPOTrainer( agent=agent, obs_dim=obs_dim, lr=cfg["ppo"]["lr"], clip_eps=cfg["ppo"]["clip_eps"], entropy_coef=cfg["ppo"]["entropy_coef"], value_coef=cfg["ppo"]["value_coef"], max_grad_norm=cfg["ppo"]["max_grad_norm"], gamma=cfg["ppo"]["gamma"], gae_lambda=cfg["ppo"]["gae_lambda"], n_epochs=cfg["ppo"]["n_epochs"], batch_size=cfg["ppo"]["batch_size"], buffer_size=cfg["ppo"]["n_rollout_steps"], device=device, use_wandb=cfg["logging"]["use_wandb"], ) # Stage 1: Behavior cloning warm-up if cfg["behavior_cloning"]["enabled"]: print("Stage 1: Behavior cloning warm-up...") obs_tensor, action_tensor = build_bc_tensors(processed, obs_dim, device) trainer.behavior_cloning_warmup( obs_tensor, action_tensor, n_epochs=cfg["behavior_cloning"]["epochs"], lr=cfg["behavior_cloning"]["lr"], ) # Stage 2: PPO fine-tuning print("Stage 2: PPO fine-tuning...") env = CompanionEnv( samples=processed, detector_hidden=detector_hidden, reward_weights=cfg["reward"], max_turns=cfg["environment"]["max_turns"], ) Path(cfg["output"]["checkpoint_dir"]).mkdir(parents=True, exist_ok=True) trainer.train( env=env, total_timesteps=cfg["ppo"]["total_timesteps"], n_rollout_steps=cfg["ppo"]["n_rollout_steps"], checkpoint_dir=cfg["output"]["checkpoint_dir"], save_interval=cfg["output"]["save_interval"], ) torch.save(agent.state_dict(), f"{cfg['output']['checkpoint_dir']}/final.pt") print("Training complete.") if __name__ == "__main__": main()