"""Run Phase 7 evaluation using the conda env that has torch/transformers.""" import paramiko import warnings import time import json warnings.filterwarnings("ignore", category=DeprecationWarning) HOST = "10.82.3.180" PORT = 20083 USER = "root" PASS = "m2dGcwyrhI" PROJ = "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL" CONDA_PY = "/opt/conda/envs/dlapo-py310-cu128/bin/python3" def ssh_run(client, cmd, timeout=600, print_live=True): transport = client.get_transport() chan = transport.open_session() chan.exec_command(cmd) out_parts = [] err_parts = [] deadline = time.time() + timeout while True: if chan.recv_ready(): chunk = chan.recv(8192).decode("utf-8", errors="replace") out_parts.append(chunk) if print_live: print(chunk, end="", flush=True) if chan.recv_stderr_ready(): chunk = chan.recv_stderr(8192).decode("utf-8", errors="replace") err_parts.append(chunk) if print_live: print(chunk, end="", flush=True) if chan.exit_status_ready(): # drain while chan.recv_ready(): chunk = chan.recv(8192).decode("utf-8", errors="replace") out_parts.append(chunk) if print_live: print(chunk, end="", flush=True) while chan.recv_stderr_ready(): chunk = chan.recv_stderr(8192).decode("utf-8", errors="replace") err_parts.append(chunk) if print_live: print(chunk, end="", flush=True) break if time.time() > deadline: print("\n[TIMEOUT]") break time.sleep(0.2) exit_code = chan.recv_exit_status() return "".join(out_parts), "".join(err_parts), exit_code def connect(): client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) client.connect(HOST, port=PORT, username=USER, password=PASS, timeout=30) return client def main(): print("=" * 60) print("Phase 7: CompanionGuard-RL Evaluation (conda env)") print("=" * 60) client = connect() print("SSH connection established.\n") # Verify conda env has required packages print("--- Verifying conda env packages ---") ssh_run(client, f"{CONDA_PY} -c 'import torch, yaml, transformers, sklearn; " f"print(\"torch:\", torch.__version__, \"| cuda:\", torch.cuda.device_count(), " f"\"| yaml ok | transformers:\", transformers.__version__)' 2>&1", timeout=30) # Phase 7-C: source field check print("\n--- Phase 7-C: source field distribution ---") src_check = ( f"cd {PROJ} && {CONDA_PY} -c \"" "import json; from collections import Counter; " "samples=[json.loads(l) for l in open('data/processed/CompanionRisk-Bench/test.jsonl') if l.strip()]; " "src=Counter(s.get('source','(none)') for s in samples); " "[print(f' {k}: {v}') for k,v in sorted(src.items(),key=lambda x:-x[1])]; " "risky=sum(int(s.get('y_risk',0)) for s in samples); " "print(f'Total: {len(samples)}, Risky: {risky}, Safe: {len(samples)-risky}')" "\"" ) ssh_run(client, src_check, timeout=30) # Phase 7-A: full test set evaluation print("\n--- Phase 7-A: eval --source-filter all ---") cmd_all = ( f"cd {PROJ} && PYTHONPATH={PROJ} " f"{CONDA_PY} scripts/evaluate.py " f"--detector-ckpt checkpoints/detector/best.pt " f"--config configs/detector_config_server.yaml " f"--test-data data/processed/CompanionRisk-Bench/test.jsonl " f"--source-filter all " f"--output experiments/eval_all.json " f"2>&1" ) out_all, _, code_all = ssh_run(client, cmd_all, timeout=600) print(f"\n[Phase 7-A exit code: {code_all}]") # Phase 7-B: human subset evaluation print("\n--- Phase 7-B: eval --source-filter human ---") cmd_human = ( f"cd {PROJ} && PYTHONPATH={PROJ} " f"{CONDA_PY} scripts/evaluate.py " f"--detector-ckpt checkpoints/detector/best.pt " f"--config configs/detector_config_server.yaml " f"--test-data data/processed/CompanionRisk-Bench/test.jsonl " f"--source-filter human " f"--output experiments/eval_human_only.json " f"2>&1" ) out_human, _, code_human = ssh_run(client, cmd_human, timeout=600) print(f"\n[Phase 7-B exit code: {code_human}]") # Fetch result JSONs print("\n--- Fetching result JSONs ---") sftp = client.open_sftp() results = {} for tag, remote, local in [ ("all", f"{PROJ}/experiments/eval_all.json", "eval_all.json"), ("human", f"{PROJ}/experiments/eval_human_only.json", "eval_human_only.json"), ]: try: sftp.get(remote, local) with open(local, encoding="utf-8") as f: results[tag] = json.load(f) print(f" Fetched {tag}: {local}") except Exception as e: print(f" [WARN] {tag}: {e}") sftp.close() # Summary table print("\n" + "=" * 60) print("KEY METRICS SUMMARY (Ours: CompanionRiskDetector)") print("=" * 60) print(f" {'Metric':<25} {'all (n=605)':>15} {'human (n=119)':>15}") print(f" {'-'*55}") metric_keys = [ ("binary_f1", "binary_f1"), ("level_macro_f1", "level_macro_f1"), ("fine_macro_f1", "fine_macro_f1"), ("high_risk_recall", "high_risk_recall"), ("false_negative_rate","false_negative_rate"), ("accuracy", "accuracy"), ] for label, key in metric_keys: val_all = results.get("all", {}).get("ours_detection", {}).get(key, float("nan")) val_human = results.get("human", {}).get("ours_detection", {}).get(key, float("nan")) try: a = f"{val_all:.4f}" except Exception: a = str(val_all) try: h = f"{val_human:.4f}" except Exception: h = str(val_human) print(f" {label:<25} {a:>15} {h:>15}") client.close() print("\n=== Phase 7 done ===") if __name__ == "__main__": main()