173 lines
6.1 KiB
Python
173 lines
6.1 KiB
Python
|
|
"""Run Phase 7 evaluation using the conda env that has torch/transformers."""
|
||
|
|
import codecs
import json
import os
import time
import warnings

import paramiko
|
||
|
|
|
||
|
|
# Silence DeprecationWarning noise so the live remote output stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Connection and remote-layout settings.
#
# Every value can be overridden through an environment variable; the literals
# below are only fallbacks that keep the script working exactly as before when
# nothing is set.
#
# NOTE(review): a root password is committed here in clear text. Rotate it and
# drop the literal default once CG_SSH_PASS is supplied by the environment (or
# switch to key-based authentication).
HOST = os.environ.get("CG_SSH_HOST", "10.82.3.180")
PORT = int(os.environ.get("CG_SSH_PORT", "20083"))
USER = os.environ.get("CG_SSH_USER", "root")
PASS = os.environ.get("CG_SSH_PASS", "m2dGcwyrhI")
# Project checkout on the remote machine; remote commands `cd` into it.
PROJ = os.environ.get(
    "CG_PROJ",
    "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL",
)
# Interpreter of the remote conda environment used to run the evaluation.
CONDA_PY = os.environ.get(
    "CG_CONDA_PY",
    "/opt/conda/envs/dlapo-py310-cu128/bin/python3",
)
|
||
|
|
|
||
|
|
|
||
|
|
def _pump(recv, decoder, sink, print_live):
    """Read one chunk through *recv*, decode it incrementally and record it.

    Text is appended to *sink* and, when *print_live* is true, echoed to
    stdout immediately. Nothing is recorded while the decoder is still
    waiting for the rest of a multi-byte character.
    """
    text = decoder.decode(recv(8192))
    if text:
        sink.append(text)
        if print_live:
            print(text, end="", flush=True)


def ssh_run(client, cmd, timeout=600, print_live=True):
    """Run *cmd* on the remote host and collect its output.

    Args:
        client: Connected SSH client exposing ``get_transport()``.
        cmd: Shell command line executed on the remote side.
        timeout: Maximum number of seconds to wait for the command to exit.
        print_live: Echo stdout/stderr to the local stdout as it arrives.

    Returns:
        A ``(stdout, stderr, exit_code)`` tuple. ``exit_code`` is ``-1`` when
        the command did not finish before *timeout*; in that case the text
        collected so far is still returned and ``[TIMEOUT]`` is printed.

    The channel is always closed before returning. On timeout this only stops
    waiting locally; whether the remote process is terminated is up to the
    server.
    """
    chan = client.get_transport().open_session()
    # Incremental decoders keep a multi-byte UTF-8 sequence intact when it is
    # split across two reads instead of turning both halves into U+FFFD.
    new_decoder = codecs.getincrementaldecoder("utf-8")
    out_decoder = new_decoder(errors="replace")
    err_decoder = new_decoder(errors="replace")
    out_parts = []
    err_parts = []
    try:
        chan.exec_command(cmd)
        deadline = time.monotonic() + timeout
        while True:
            if chan.recv_ready():
                _pump(chan.recv, out_decoder, out_parts, print_live)
            if chan.recv_stderr_ready():
                _pump(chan.recv_stderr, err_decoder, err_parts, print_live)
            if chan.exit_status_ready():
                # The command has exited: drain whatever is still buffered.
                while chan.recv_ready():
                    _pump(chan.recv, out_decoder, out_parts, print_live)
                while chan.recv_stderr_ready():
                    _pump(chan.recv_stderr, err_decoder, err_parts, print_live)
                exit_code = chan.recv_exit_status()
                break
            if time.monotonic() > deadline:
                print("\n[TIMEOUT]")
                # Do not call recv_exit_status() here: it waits for the remote
                # command to finish, which would defeat the timeout.
                exit_code = -1
                break
            time.sleep(0.2)
    finally:
        chan.close()

    # Flush any truncated trailing byte sequence held back by the decoders.
    for decoder, parts in ((out_decoder, out_parts), (err_decoder, err_parts)):
        tail = decoder.decode(b"", final=True)
        if tail:
            parts.append(tail)
            if print_live:
                print(tail, end="", flush=True)

    return "".join(out_parts), "".join(err_parts), exit_code
|
||
|
|
|
||
|
|
|
||
|
|
def connect():
    """Open a password-authenticated SSH connection to HOST:PORT.

    Returns:
        The connected ``paramiko.SSHClient``. The caller is responsible for
        closing it.

    Raises:
        Whatever ``SSHClient.connect`` raises; the half-opened client is
        closed first so no transport is left behind.

    NOTE(review): ``AutoAddPolicy`` accepts any unknown host key without
    verification. Keep it only on a trusted network.
    """
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(HOST, port=PORT, username=USER, password=PASS, timeout=30)
    except BaseException:
        # A failed connect/auth may already have started a transport; release
        # it before propagating the error.
        client.close()
        raise
    return client
|
||
|
|
|
||
|
|
|
||
|
|
def _eval_command(source_filter, output):
    """Build the remote shell command running scripts/evaluate.py for one filter."""
    return (
        f"cd {PROJ} && PYTHONPATH={PROJ} "
        f"{CONDA_PY} scripts/evaluate.py "
        "--detector-ckpt checkpoints/detector/best.pt "
        "--config configs/detector_config_server.yaml "
        "--test-data data/processed/CompanionRisk-Bench/test.jsonl "
        f"--source-filter {source_filter} "
        f"--output {output} "
        "2>&1"
    )


def _lookup_metric(results, tag, key):
    """Return results[tag]["ours_detection"][key], or NaN when any level is missing."""
    section = results.get(tag)
    detection = section.get("ours_detection") if isinstance(section, dict) else None
    if not isinstance(detection, dict):
        return float("nan")
    return detection.get(key, float("nan"))


def _format_metric(value):
    """Format *value* with four decimals, falling back to str() for non-numbers."""
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return str(value)


def main():
    """Run the Phase 7 evaluation remotely and print a metrics summary.

    Steps: connect over SSH, check that the conda interpreter can import the
    required packages, print the ``source`` distribution of the test set, run
    ``scripts/evaluate.py`` for the ``all`` and ``human`` source filters,
    download the two result JSON files into the current directory, and print
    the ``ours_detection`` metrics of both runs side by side.

    The SSH client and the SFTP session are closed even when a step raises.
    A result file is only downloaded when its evaluation exited with code 0,
    so a file left over from an earlier run is never reported as current.
    """
    print("=" * 60)
    print("Phase 7: CompanionGuard-RL Evaluation (conda env)")
    print("=" * 60)

    # (phase label, --source-filter value, remote output path relative to
    # PROJ, local file name)
    runs = (
        ("7-A", "all", "experiments/eval_all.json", "eval_all.json"),
        ("7-B", "human", "experiments/eval_human_only.json", "eval_human_only.json"),
    )

    client = connect()
    try:
        print("SSH connection established.\n")

        # Verify conda env has required packages
        print("--- Verifying conda env packages ---")
        _, _, env_code = ssh_run(
            client,
            f"{CONDA_PY} -c 'import torch, yaml, transformers, sklearn; "
            f"print(\"torch:\", torch.__version__, \"| cuda:\", torch.cuda.device_count(), "
            f"\"| yaml ok | transformers:\", transformers.__version__)' 2>&1",
            timeout=30,
        )
        if env_code != 0:
            print(f"[WARN] conda env check exited with code {env_code}")

        # Phase 7-C: source field check
        print("\n--- Phase 7-C: source field distribution ---")
        src_check = (
            f"cd {PROJ} && {CONDA_PY} -c \""
            "import json; from collections import Counter; "
            "samples=[json.loads(l) for l in open('data/processed/CompanionRisk-Bench/test.jsonl') if l.strip()]; "
            "src=Counter(s.get('source','(none)') for s in samples); "
            "[print(f' {k}: {v}') for k,v in sorted(src.items(),key=lambda x:-x[1])]; "
            "risky=sum(int(s.get('y_risk',0)) for s in samples); "
            "print(f'Total: {len(samples)}, Risky: {risky}, Safe: {len(samples)-risky}')"
            "\""
        )
        ssh_run(client, src_check, timeout=30)

        # Phase 7-A / 7-B: full test set, then human-only subset
        exit_codes = {}
        for phase, tag, remote_rel, _ in runs:
            print(f"\n--- Phase {phase}: eval --source-filter {tag} ---")
            _, _, code = ssh_run(client, _eval_command(tag, remote_rel), timeout=600)
            exit_codes[tag] = code
            print(f"\n[Phase {phase} exit code: {code}]")

        # Fetch result JSONs
        print("\n--- Fetching result JSONs ---")
        results = {}
        sftp = client.open_sftp()
        try:
            for _, tag, remote_rel, local in runs:
                remote = f"{PROJ}/{remote_rel}"
                if exit_codes[tag] != 0:
                    # The remote file, if present, would come from an earlier
                    # run; reporting it would be misleading.
                    print(
                        f" [WARN] {tag}: evaluation exited with code "
                        f"{exit_codes[tag]}; not fetching {remote}"
                    )
                    continue
                try:
                    sftp.get(remote, local)
                    with open(local, encoding="utf-8") as f:
                        results[tag] = json.load(f)
                    print(f" Fetched {tag}: {local}")
                except Exception as e:
                    # Best effort: a missing or unreadable file only blanks
                    # that column in the summary.
                    print(f" [WARN] {tag}: {e}")
        finally:
            sftp.close()

        # Summary table
        print("\n" + "=" * 60)
        print("KEY METRICS SUMMARY (Ours: CompanionRiskDetector)")
        print("=" * 60)
        print(f" {'Metric':<25} {'all (n=605)':>15} {'human (n=119)':>15}")
        print(f" {'-'*55}")
        metric_keys = (
            "binary_f1",
            "level_macro_f1",
            "fine_macro_f1",
            "high_risk_recall",
            "false_negative_rate",
            "accuracy",
        )
        for key in metric_keys:
            a = _format_metric(_lookup_metric(results, "all", key))
            h = _format_metric(_lookup_metric(results, "human", key))
            print(f" {key:<25} {a:>15} {h:>15}")
    finally:
        client.close()

    print("\n=== Phase 7 done ===")
|
||
|
|
|
||
|
|
|
||
|
|
# Run the evaluation workflow only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|