Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
167 lines
6.4 KiB
Python
167 lines
6.4 KiB
Python
"""Phase 7 evaluation runner — connects to server via paramiko and runs evaluations."""
|
|
import codecs
import json
import os
import sys
import time
import warnings

import paramiko
|
|
|
|
# Silence DeprecationWarning noise for the whole run.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Connection settings for the evaluation server. Each value can be overridden
# through an environment variable; the defaults are the values this script has
# always used, so running it without any environment set behaves as before.
HOST = os.environ.get("CG_SSH_HOST", "10.82.3.180")
PORT = int(os.environ.get("CG_SSH_PORT", "20083"))
USER = os.environ.get("CG_SSH_USER", "root")
# NOTE(review): a plaintext root password is committed here. The default is
# kept only so existing invocations keep working — rotate it and supply
# CG_SSH_PASS (or switch to key-based auth) instead of relying on this value.
PASS = os.environ.get("CG_SSH_PASS", "m2dGcwyrhI")
# Remote project directory; every remote command is run from here.
PROJ = os.environ.get(
    "CG_REMOTE_PROJ",
    "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL",
)
|
|
|
|
|
|
def _drain_channel(chan, out_decoder, err_decoder, out_parts, err_parts, print_live):
    """Move everything currently buffered on *chan* into the part lists."""
    while chan.recv_ready():
        data = chan.recv(4096)
        if not data:  # empty read: nothing more will arrive on this stream
            break
        text = out_decoder.decode(data)
        out_parts.append(text)
        if print_live and text:
            print(text, end="", flush=True)
    while chan.recv_stderr_ready():
        data = chan.recv_stderr(4096)
        if not data:
            break
        err_parts.append(err_decoder.decode(data))


def ssh_run(client, cmd, timeout=600, print_live=False):
    """Run a command and return (stdout, stderr, exit_code).

    Args:
        client: Connected SSH client; only ``get_transport()`` is used here.
        cmd: Shell command line to execute on the remote side.
        timeout: Maximum number of seconds to wait for the command to exit.
            ``None`` waits indefinitely.
        print_live: When true, echo stdout text locally as it arrives.

    Returns:
        Tuple ``(stdout_text, stderr_text, exit_code)``; both texts are decoded
        as UTF-8 with undecodable bytes replaced.

    Raises:
        TimeoutError: If the command has not exited within *timeout* seconds.
            The channel is closed before the exception propagates.
    """
    # Incremental decoders keep a multi-byte character intact even when its
    # bytes are split across two 4096-byte reads.
    make_decoder = codecs.getincrementaldecoder("utf-8")
    out_decoder = make_decoder(errors="replace")
    err_decoder = make_decoder(errors="replace")
    out_parts = []
    err_parts = []

    deadline = None if timeout is None else time.monotonic() + timeout
    chan = client.get_transport().open_session()
    try:
        chan.exec_command(cmd)
        while True:
            _drain_channel(chan, out_decoder, err_decoder, out_parts, err_parts, print_live)
            if chan.exit_status_ready():
                # Pick up anything that was buffered between the drain above
                # and the exit status becoming available.
                _drain_channel(chan, out_decoder, err_decoder, out_parts, err_parts, print_live)
                break
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"remote command did not finish within {timeout}s: {cmd[:80]!r}"
                )
            time.sleep(0.1)
        exit_code = chan.recv_exit_status()
    finally:
        chan.close()

    # Flush the decoders so a truncated trailing byte sequence is still reported.
    out_tail = out_decoder.decode(b"", final=True)
    if out_tail:
        out_parts.append(out_tail)
        if print_live:
            print(out_tail, end="", flush=True)
    err_tail = err_decoder.decode(b"", final=True)
    if err_tail:
        err_parts.append(err_tail)

    return "".join(out_parts), "".join(err_parts), exit_code
|
|
|
|
|
|
def connect():
    """Open an SSH connection using HOST, PORT, USER and PASS.

    Returns:
        The connected ``paramiko.SSHClient``. The caller is responsible for
        closing it.

    Raises:
        Whatever ``SSHClient.connect`` raises; the client is closed before the
        exception propagates so a failed attempt leaves nothing open.
    """
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts any unknown host key, so the server's
    # identity is not verified. Kept to preserve current behavior; consider
    # loading known host keys and rejecting unknown ones.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(HOST, port=PORT, username=USER, password=PASS, timeout=30)
    except Exception:
        client.close()
        raise
    return client
|
|
|
|
|
|
# Remote snippet (fed to python3 through a heredoc) that prints the
# distribution of the 'source' field in the test set, the id-prefix
# distribution for samples lacking that field, and risky/safe totals.
_SOURCE_CHECK_SCRIPT = r"""python3 << 'PYEOF'
import json
from collections import Counter
path = 'data/processed/CompanionRisk-Bench/test.jsonl'
samples = [json.loads(l) for l in open(path) if l.strip()]
src_counter = Counter(s.get('source', '(no source field)') for s in samples)
print("source field distribution:")
for k, v in sorted(src_counter.items(), key=lambda x: -x[1]):
    print(f" {k}: {v}")
id_pfx = Counter(s.get('id','?')[:12] for s in samples if not s.get('source'))
if id_pfx:
    print("id prefix distribution (for samples without source field):")
    for k, v in sorted(id_pfx.items(), key=lambda x: -x[1])[:15]:
        print(f" {k}: {v}")
risky = sum(int(s.get('y_risk', 0)) for s in samples)
print(f"Total: {len(samples)}, Risky: {risky}, Safe: {len(samples)-risky}")
PYEOF"""

# (tag / --source-filter value, result file name under experiments/)
_EVAL_RUNS = (
    ("all", "eval_all.json"),
    ("human", "eval_human_only.json"),
)


def _build_eval_cmd(source_filter, output_name):
    """Return the remote shell command that runs evaluate.py for one filter."""
    return (
        f"cd {PROJ} && "
        "python3 scripts/evaluate.py "
        "--detector-ckpt checkpoints/detector/best.pt "
        "--config configs/detector_config_server.yaml "
        "--test-data data/processed/CompanionRisk-Bench/test.jsonl "
        f"--source-filter {source_filter} "
        f"--output experiments/{output_name} "
        # stderr is merged into stdout so live printing shows both streams.
        "2>&1"
    )


def _run_eval(client, source_filter, output_name):
    """Run one evaluation remotely with live output and report its exit code."""
    cmd = _build_eval_cmd(source_filter, output_name)
    print("Command:", cmd[:120], "...")
    _, err, code = ssh_run(client, cmd, timeout=600, print_live=True)
    print(f"\n[exit code: {code}]")
    # Because of the 2>&1 redirect this is normally empty; kept as a safety net.
    if code != 0 and err.strip():
        print("STDERR:", err[-1000:])
    return code


def _fetch_results(client):
    """Download each result JSON into the working directory and parse it.

    Returns a dict mapping tag -> parsed JSON. A file that cannot be fetched
    or parsed is reported with a warning and skipped.
    """
    results = {}
    sftp = client.open_sftp()
    try:
        for tag, output_name in _EVAL_RUNS:
            remote_path = f"{PROJ}/experiments/{output_name}"
            local_path = output_name
            try:
                sftp.get(remote_path, local_path)
                with open(local_path, encoding="utf-8") as f:
                    results[tag] = json.load(f)
                print(f" Fetched {tag}: {local_path}")
            except Exception as e:  # best-effort: one missing file must not abort the run
                print(f" [WARN] Could not fetch {tag} results: {e}")
    finally:
        sftp.close()
    return results


def _fmt_metric(value):
    """Format a metric with four decimals; fall back to str() for non-numbers."""
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return str(value)


def _print_summary(results):
    """Print the headline detection metrics for every fetched result set."""
    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)
    nan = float("nan")
    for tag, data in results.items():
        ours = data.get("ours_detection") or {}
        n_filt = (data.get("meta") or {}).get("n_filtered", "?")
        print(f"\n source_filter={tag} (n={n_filt})")
        print(f" binary_f1 = {_fmt_metric(ours.get('binary_f1', nan))}")
        print(f" level_macro_f1 = {_fmt_metric(ours.get('level_macro_f1', nan))}")
        print(f" fine_macro_f1 = {_fmt_metric(ours.get('fine_macro_f1', nan))}")
        print(f" high_risk_recall = {_fmt_metric(ours.get('high_risk_recall', nan))}")
        print(f" false_neg_rate = {_fmt_metric(ours.get('false_negative_rate', nan))}")


def main():
    """Run the Phase 7 evaluation on the remote server and print a summary.

    Steps: connect over SSH, print the test set's source-field distribution
    (7-C), run evaluate.py once per entry in ``_EVAL_RUNS`` (7-A, 7-B),
    download the result JSON files, and print the headline metrics. The SSH
    client is closed even when a step raises.
    """
    print("=" * 60)
    print("Phase 7: CompanionGuard-RL Evaluation")
    print("=" * 60)

    client = connect()
    try:
        print("SSH connection established.")

        # ── Phase 7-C: check source field distribution ──────────────
        print("\n--- Phase 7-C: source field check ---")
        out, err, code = ssh_run(
            client, f"cd {PROJ} && {_SOURCE_CHECK_SCRIPT}", timeout=60
        )
        print(out)
        if err.strip():
            print("STDERR:", err[:500])
        if code != 0:
            print(f"[exit code: {code}]")

        # ── Phase 7-A: full test set ─────────────────────────────────
        print("\n--- Phase 7-A: running eval --source-filter all ---")
        _run_eval(client, *_EVAL_RUNS[0])

        # ── Phase 7-B: human-annotated subset ───────────────────────
        print("\n--- Phase 7-B: running eval --source-filter human ---")
        _run_eval(client, *_EVAL_RUNS[1])

        # ── Fetch result JSONs ───────────────────────────────────────
        print("\n--- Fetching result JSON files ---")
        results = _fetch_results(client)

        # ── Print summary table ──────────────────────────────────────
        _print_summary(results)
    finally:
        client.close()

    print("\n=== Phase 7 done ===")
|
|
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|