chore: initial commit — unified project repo
Merged code repo (CompanionGuard-RL) into single project-level git. Reorganized root: docs/, reference/, experiments/, tmp/active|archives/. Gitignored: data/, checkpoints/, .venv, experiment logs, tmp/archives. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
15
tmp/active/copy_deps.sh
Normal file
15
tmp/active/copy_deps.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
MA=/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/old-road-code/envs/multimodal_affect/lib/python3.10/site-packages
|
||||
DST=/opt/conda/envs/dlapo-py310-cu128/lib/python3.10/site-packages
|
||||
|
||||
for pkg in platformdirs sentry_sdk docker pynvml setproctitle; do
|
||||
if [ -d "$MA/$pkg" ]; then
|
||||
cp -r "$MA/$pkg" "$DST/" && echo "ok: $pkg"
|
||||
fi
|
||||
for dist in "$MA/${pkg}"-*.dist-info; do
|
||||
[ -d "$dist" ] && cp -r "$dist" "$DST/" && echo "ok: $(basename $dist)"
|
||||
done
|
||||
done
|
||||
|
||||
echo "--- testing wandb import ---"
|
||||
/opt/conda/envs/dlapo-py310-cu128/bin/python -c "import wandb; print('wandb ok')" 2>&1
|
||||
166
tmp/active/run_phase7.py
Normal file
166
tmp/active/run_phase7.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Phase 7 evaluation runner — connects to server via paramiko and runs evaluations."""
|
||||
import paramiko
|
||||
import warnings
|
||||
import time
|
||||
import sys
|
||||
import json
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
|
||||
HOST = "10.82.3.180"
|
||||
PORT = 20083
|
||||
USER = "root"
|
||||
PASS = "m2dGcwyrhI"
|
||||
PROJ = "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
|
||||
|
||||
def ssh_run(client, cmd, timeout=600, print_live=False):
|
||||
"""Run a command and return (stdout, stderr, exit_code)."""
|
||||
transport = client.get_transport()
|
||||
chan = transport.open_session()
|
||||
chan.exec_command(cmd)
|
||||
|
||||
out_parts = []
|
||||
err_parts = []
|
||||
while True:
|
||||
if chan.recv_ready():
|
||||
chunk = chan.recv(4096).decode("utf-8", errors="replace")
|
||||
out_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
if chan.recv_stderr_ready():
|
||||
chunk = chan.recv_stderr(4096).decode("utf-8", errors="replace")
|
||||
err_parts.append(chunk)
|
||||
if chan.exit_status_ready():
|
||||
# drain remaining
|
||||
while chan.recv_ready():
|
||||
chunk = chan.recv(4096).decode("utf-8", errors="replace")
|
||||
out_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
while chan.recv_stderr_ready():
|
||||
err_parts.append(chan.recv_stderr(4096).decode("utf-8", errors="replace"))
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
exit_code = chan.recv_exit_status()
|
||||
return "".join(out_parts), "".join(err_parts), exit_code
|
||||
|
||||
|
||||
def connect():
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(HOST, port=PORT, username=USER, password=PASS, timeout=30)
|
||||
return client
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Phase 7: CompanionGuard-RL Evaluation")
|
||||
print("=" * 60)
|
||||
|
||||
client = connect()
|
||||
print("SSH connection established.")
|
||||
|
||||
# ── Phase 7-C: check source field distribution ──────────────
|
||||
print("\n--- Phase 7-C: source field check ---")
|
||||
check_script = r"""python3 << 'PYEOF'
|
||||
import json
|
||||
from collections import Counter
|
||||
path = 'data/processed/CompanionRisk-Bench/test.jsonl'
|
||||
samples = [json.loads(l) for l in open(path) if l.strip()]
|
||||
src_counter = Counter(s.get('source', '(no source field)') for s in samples)
|
||||
print("source field distribution:")
|
||||
for k, v in sorted(src_counter.items(), key=lambda x: -x[1]):
|
||||
print(f" {k}: {v}")
|
||||
id_pfx = Counter(s.get('id','?')[:12] for s in samples if not s.get('source'))
|
||||
if id_pfx:
|
||||
print("id prefix distribution (for samples without source field):")
|
||||
for k, v in sorted(id_pfx.items(), key=lambda x: -x[1])[:15]:
|
||||
print(f" {k}: {v}")
|
||||
risky = sum(int(s.get('y_risk', 0)) for s in samples)
|
||||
print(f"Total: {len(samples)}, Risky: {risky}, Safe: {len(samples)-risky}")
|
||||
PYEOF"""
|
||||
|
||||
out, err, code = ssh_run(client, f"cd {PROJ} && {check_script}", timeout=60)
|
||||
print(out)
|
||||
if err.strip():
|
||||
print("STDERR:", err[:500])
|
||||
|
||||
# ── Phase 7-A: full test set ─────────────────────────────────
|
||||
print("\n--- Phase 7-A: running eval --source-filter all ---")
|
||||
cmd_all = (
|
||||
f"cd {PROJ} && "
|
||||
f"python3 scripts/evaluate.py "
|
||||
f"--detector-ckpt checkpoints/detector/best.pt "
|
||||
f"--config configs/detector_config_server.yaml "
|
||||
f"--test-data data/processed/CompanionRisk-Bench/test.jsonl "
|
||||
f"--source-filter all "
|
||||
f"--output experiments/eval_all.json "
|
||||
f"2>&1"
|
||||
)
|
||||
print("Command:", cmd_all[:120], "...")
|
||||
out_all, err_all, code_all = ssh_run(client, cmd_all, timeout=600, print_live=True)
|
||||
print(f"\n[exit code: {code_all}]")
|
||||
if code_all != 0 and err_all.strip():
|
||||
print("STDERR:", err_all[-1000:])
|
||||
|
||||
# ── Phase 7-B: human-annotated subset ───────────────────────
|
||||
print("\n--- Phase 7-B: running eval --source-filter human ---")
|
||||
cmd_human = (
|
||||
f"cd {PROJ} && "
|
||||
f"python3 scripts/evaluate.py "
|
||||
f"--detector-ckpt checkpoints/detector/best.pt "
|
||||
f"--config configs/detector_config_server.yaml "
|
||||
f"--test-data data/processed/CompanionRisk-Bench/test.jsonl "
|
||||
f"--source-filter human "
|
||||
f"--output experiments/eval_human_only.json "
|
||||
f"2>&1"
|
||||
)
|
||||
out_human, err_human, code_human = ssh_run(client, cmd_human, timeout=600, print_live=True)
|
||||
print(f"\n[exit code: {code_human}]")
|
||||
if code_human != 0 and err_human.strip():
|
||||
print("STDERR:", err_human[-1000:])
|
||||
|
||||
# ── Fetch result JSONs ───────────────────────────────────────
|
||||
print("\n--- Fetching result JSON files ---")
|
||||
sftp = client.open_sftp()
|
||||
results = {}
|
||||
for tag, remote_path, local_path in [
|
||||
("all", f"{PROJ}/experiments/eval_all.json", "eval_all.json"),
|
||||
("human", f"{PROJ}/experiments/eval_human_only.json", "eval_human_only.json"),
|
||||
]:
|
||||
try:
|
||||
sftp.get(remote_path, local_path)
|
||||
with open(local_path) as f:
|
||||
results[tag] = json.load(f)
|
||||
print(f" Fetched {tag}: {local_path}")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Could not fetch {tag} results: {e}")
|
||||
sftp.close()
|
||||
|
||||
# ── Print summary table ──────────────────────────────────────
|
||||
print("\n" + "=" * 60)
|
||||
print("RESULTS SUMMARY")
|
||||
print("=" * 60)
|
||||
for tag, data in results.items():
|
||||
ours = data.get("ours_detection", {})
|
||||
bf1 = ours.get("binary_f1", float("nan"))
|
||||
lvlf1 = ours.get("level_macro_f1", float("nan"))
|
||||
finef1 = ours.get("fine_macro_f1", float("nan"))
|
||||
recall = ours.get("high_risk_recall", float("nan"))
|
||||
fnr = ours.get("false_negative_rate", float("nan"))
|
||||
n_filt = data.get("meta", {}).get("n_filtered", "?")
|
||||
print(f"\n source_filter={tag} (n={n_filt})")
|
||||
print(f" binary_f1 = {bf1:.4f}")
|
||||
print(f" level_macro_f1 = {lvlf1:.4f}")
|
||||
print(f" fine_macro_f1 = {finef1:.4f}")
|
||||
print(f" high_risk_recall = {recall:.4f}")
|
||||
print(f" false_neg_rate = {fnr:.4f}")
|
||||
|
||||
client.close()
|
||||
print("\n=== Phase 7 done ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
172
tmp/active/run_phase7_conda.py
Normal file
172
tmp/active/run_phase7_conda.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Run Phase 7 evaluation using the conda env that has torch/transformers."""
|
||||
import paramiko
|
||||
import warnings
|
||||
import time
|
||||
import json
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
|
||||
HOST = "10.82.3.180"
|
||||
PORT = 20083
|
||||
USER = "root"
|
||||
PASS = "m2dGcwyrhI"
|
||||
PROJ = "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
CONDA_PY = "/opt/conda/envs/dlapo-py310-cu128/bin/python3"
|
||||
|
||||
|
||||
def ssh_run(client, cmd, timeout=600, print_live=True):
|
||||
transport = client.get_transport()
|
||||
chan = transport.open_session()
|
||||
chan.exec_command(cmd)
|
||||
out_parts = []
|
||||
err_parts = []
|
||||
deadline = time.time() + timeout
|
||||
while True:
|
||||
if chan.recv_ready():
|
||||
chunk = chan.recv(8192).decode("utf-8", errors="replace")
|
||||
out_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
if chan.recv_stderr_ready():
|
||||
chunk = chan.recv_stderr(8192).decode("utf-8", errors="replace")
|
||||
err_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
if chan.exit_status_ready():
|
||||
# drain
|
||||
while chan.recv_ready():
|
||||
chunk = chan.recv(8192).decode("utf-8", errors="replace")
|
||||
out_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
while chan.recv_stderr_ready():
|
||||
chunk = chan.recv_stderr(8192).decode("utf-8", errors="replace")
|
||||
err_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
break
|
||||
if time.time() > deadline:
|
||||
print("\n[TIMEOUT]")
|
||||
break
|
||||
time.sleep(0.2)
|
||||
exit_code = chan.recv_exit_status()
|
||||
return "".join(out_parts), "".join(err_parts), exit_code
|
||||
|
||||
|
||||
def connect():
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(HOST, port=PORT, username=USER, password=PASS, timeout=30)
|
||||
return client
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Phase 7: CompanionGuard-RL Evaluation (conda env)")
|
||||
print("=" * 60)
|
||||
|
||||
client = connect()
|
||||
print("SSH connection established.\n")
|
||||
|
||||
# Verify conda env has required packages
|
||||
print("--- Verifying conda env packages ---")
|
||||
ssh_run(client,
|
||||
f"{CONDA_PY} -c 'import torch, yaml, transformers, sklearn; "
|
||||
f"print(\"torch:\", torch.__version__, \"| cuda:\", torch.cuda.device_count(), "
|
||||
f"\"| yaml ok | transformers:\", transformers.__version__)' 2>&1",
|
||||
timeout=30)
|
||||
|
||||
# Phase 7-C: source field check
|
||||
print("\n--- Phase 7-C: source field distribution ---")
|
||||
src_check = (
|
||||
f"cd {PROJ} && {CONDA_PY} -c \""
|
||||
"import json; from collections import Counter; "
|
||||
"samples=[json.loads(l) for l in open('data/processed/CompanionRisk-Bench/test.jsonl') if l.strip()]; "
|
||||
"src=Counter(s.get('source','(none)') for s in samples); "
|
||||
"[print(f' {k}: {v}') for k,v in sorted(src.items(),key=lambda x:-x[1])]; "
|
||||
"risky=sum(int(s.get('y_risk',0)) for s in samples); "
|
||||
"print(f'Total: {len(samples)}, Risky: {risky}, Safe: {len(samples)-risky}')"
|
||||
"\""
|
||||
)
|
||||
ssh_run(client, src_check, timeout=30)
|
||||
|
||||
# Phase 7-A: full test set evaluation
|
||||
print("\n--- Phase 7-A: eval --source-filter all ---")
|
||||
cmd_all = (
|
||||
f"cd {PROJ} && PYTHONPATH={PROJ} "
|
||||
f"{CONDA_PY} scripts/evaluate.py "
|
||||
f"--detector-ckpt checkpoints/detector/best.pt "
|
||||
f"--config configs/detector_config_server.yaml "
|
||||
f"--test-data data/processed/CompanionRisk-Bench/test.jsonl "
|
||||
f"--source-filter all "
|
||||
f"--output experiments/eval_all.json "
|
||||
f"2>&1"
|
||||
)
|
||||
out_all, _, code_all = ssh_run(client, cmd_all, timeout=600)
|
||||
print(f"\n[Phase 7-A exit code: {code_all}]")
|
||||
|
||||
# Phase 7-B: human subset evaluation
|
||||
print("\n--- Phase 7-B: eval --source-filter human ---")
|
||||
cmd_human = (
|
||||
f"cd {PROJ} && PYTHONPATH={PROJ} "
|
||||
f"{CONDA_PY} scripts/evaluate.py "
|
||||
f"--detector-ckpt checkpoints/detector/best.pt "
|
||||
f"--config configs/detector_config_server.yaml "
|
||||
f"--test-data data/processed/CompanionRisk-Bench/test.jsonl "
|
||||
f"--source-filter human "
|
||||
f"--output experiments/eval_human_only.json "
|
||||
f"2>&1"
|
||||
)
|
||||
out_human, _, code_human = ssh_run(client, cmd_human, timeout=600)
|
||||
print(f"\n[Phase 7-B exit code: {code_human}]")
|
||||
|
||||
# Fetch result JSONs
|
||||
print("\n--- Fetching result JSONs ---")
|
||||
sftp = client.open_sftp()
|
||||
results = {}
|
||||
for tag, remote, local in [
|
||||
("all", f"{PROJ}/experiments/eval_all.json", "eval_all.json"),
|
||||
("human", f"{PROJ}/experiments/eval_human_only.json", "eval_human_only.json"),
|
||||
]:
|
||||
try:
|
||||
sftp.get(remote, local)
|
||||
with open(local, encoding="utf-8") as f:
|
||||
results[tag] = json.load(f)
|
||||
print(f" Fetched {tag}: {local}")
|
||||
except Exception as e:
|
||||
print(f" [WARN] {tag}: {e}")
|
||||
sftp.close()
|
||||
|
||||
# Summary table
|
||||
print("\n" + "=" * 60)
|
||||
print("KEY METRICS SUMMARY (Ours: CompanionRiskDetector)")
|
||||
print("=" * 60)
|
||||
print(f" {'Metric':<25} {'all (n=605)':>15} {'human (n=119)':>15}")
|
||||
print(f" {'-'*55}")
|
||||
metric_keys = [
|
||||
("binary_f1", "binary_f1"),
|
||||
("level_macro_f1", "level_macro_f1"),
|
||||
("fine_macro_f1", "fine_macro_f1"),
|
||||
("high_risk_recall", "high_risk_recall"),
|
||||
("false_negative_rate","false_negative_rate"),
|
||||
("accuracy", "accuracy"),
|
||||
]
|
||||
for label, key in metric_keys:
|
||||
val_all = results.get("all", {}).get("ours_detection", {}).get(key, float("nan"))
|
||||
val_human = results.get("human", {}).get("ours_detection", {}).get(key, float("nan"))
|
||||
try:
|
||||
a = f"{val_all:.4f}"
|
||||
except Exception:
|
||||
a = str(val_all)
|
||||
try:
|
||||
h = f"{val_human:.4f}"
|
||||
except Exception:
|
||||
h = str(val_human)
|
||||
print(f" {label:<25} {a:>15} {h:>15}")
|
||||
|
||||
client.close()
|
||||
print("\n=== Phase 7 done ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
81
tmp/active/run_phase7_setup.py
Normal file
81
tmp/active/run_phase7_setup.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Check server Python environment and install missing packages."""
|
||||
import paramiko
|
||||
import warnings
|
||||
import time
|
||||
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
|
||||
HOST = "10.82.3.180"
|
||||
PORT = 20083
|
||||
USER = "root"
|
||||
PASS = "m2dGcwyrhI"
|
||||
PROJ = "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL"
|
||||
|
||||
|
||||
def ssh_run(client, cmd, timeout=300, print_live=True):
|
||||
transport = client.get_transport()
|
||||
chan = transport.open_session()
|
||||
chan.exec_command(cmd)
|
||||
out_parts = []
|
||||
err_parts = []
|
||||
while True:
|
||||
if chan.recv_ready():
|
||||
chunk = chan.recv(4096).decode("utf-8", errors="replace")
|
||||
out_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
if chan.recv_stderr_ready():
|
||||
chunk = chan.recv_stderr(4096).decode("utf-8", errors="replace")
|
||||
err_parts.append(chunk)
|
||||
if print_live:
|
||||
print("[ERR]", chunk, end="", flush=True)
|
||||
if chan.exit_status_ready():
|
||||
while chan.recv_ready():
|
||||
chunk = chan.recv(4096).decode("utf-8", errors="replace")
|
||||
out_parts.append(chunk)
|
||||
if print_live:
|
||||
print(chunk, end="", flush=True)
|
||||
while chan.recv_stderr_ready():
|
||||
chunk = chan.recv_stderr(4096).decode("utf-8", errors="replace")
|
||||
err_parts.append(chunk)
|
||||
if print_live:
|
||||
print("[ERR]", chunk, end="", flush=True)
|
||||
break
|
||||
time.sleep(0.1)
|
||||
exit_code = chan.recv_exit_status()
|
||||
return "".join(out_parts), "".join(err_parts), exit_code
|
||||
|
||||
|
||||
def connect():
|
||||
client = paramiko.SSHClient()
|
||||
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.connect(HOST, port=PORT, username=USER, password=PASS, timeout=30)
|
||||
return client
|
||||
|
||||
|
||||
client = connect()
|
||||
print("Connected.\n")
|
||||
|
||||
# Find which python to use and what's installed
|
||||
print("=== Python environments ===")
|
||||
ssh_run(client, "which python3 python 2>/dev/null; python3 --version 2>/dev/null; conda env list 2>/dev/null | head -20")
|
||||
|
||||
print("\n=== Check which python was used for training ===")
|
||||
ssh_run(client, f"head -5 {PROJ}/experiments/train_*.log 2>/dev/null | head -20")
|
||||
|
||||
print("\n=== Check if torch is installed ===")
|
||||
ssh_run(client, "python3 -c 'import torch; print(torch.__version__)' 2>&1")
|
||||
|
||||
print("\n=== List installed packages ===")
|
||||
ssh_run(client, "python3 -m pip list 2>/dev/null | grep -E 'yaml|torch|trans|scikit|peft' 2>&1")
|
||||
|
||||
print("\n=== Install missing packages ===")
|
||||
ssh_run(client,
|
||||
"python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyyaml scikit-learn 2>&1 | tail -10",
|
||||
timeout=120)
|
||||
|
||||
print("\n=== Verify yaml now works ===")
|
||||
ssh_run(client, "python3 -c 'import yaml; print(yaml.__version__)' 2>&1")
|
||||
|
||||
client.close()
|
||||
print("\nDone.")
|
||||
9
tmp/active/start_train_v3.sh
Normal file
9
tmp/active/start_train_v3.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
cd /root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL
|
||||
ACCEL=/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/old-road-code/envs/multimodal_affect/bin/accelerate
|
||||
LOG=experiments/train_v3_$(date +%Y%m%d_%H%M%S).log
|
||||
nohup $ACCEL launch --num_processes=4 --mixed_precision=bf16 \
|
||||
scripts/train_detector.py \
|
||||
--config configs/detector_config_server.yaml \
|
||||
> $LOG 2>&1 &
|
||||
echo "PID: $! LOG: $LOG"
|
||||
6
tmp/active/start_v3.sh
Normal file
6
tmp/active/start_v3.sh
Normal file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
cd /root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL
|
||||
ACCEL=/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/old-road-code/envs/multimodal_affect/bin/accelerate
|
||||
LOG=experiments/train_v3_$(date +%Y%m%d_%H%M%S).log
|
||||
nohup $ACCEL launch --num_processes=4 --mixed_precision=bf16 scripts/train_detector.py --config configs/detector_config_server.yaml > $LOG 2>&1 &
|
||||
echo "PID: $! LOG: $LOG"
|
||||
19
tmp/active/train_v5.sh
Normal file
19
tmp/active/train_v5.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
PROJ=/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/my-reasearch/companionguard-rl
|
||||
PYTHON=/root/siton-data-740d234e02d749f08fe5347b0c74c49f/zsy/env/dlapo-py310-cu128/bin/python
|
||||
cd $PROJ
|
||||
export PYTHONPATH=$PROJ
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
mkdir -p experiments checkpoints/intervention
|
||||
LOG=$PROJ/experiments/train_intervention_v5_$(date +%Y%m%d_%H%M%S).log
|
||||
echo "Starting Module C v5 training (GPU 1, direct python)"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
echo "Log: $LOG"
|
||||
# Run directly without accelerate launcher to avoid CUDA init issues
|
||||
$PYTHON scripts/train_intervention.py \
|
||||
--config configs/intervention_config.yaml \
|
||||
--train-data data/processed/CompanionRisk-Bench/train.jsonl \
|
||||
>> $LOG 2>&1
|
||||
EXIT_CODE=$?
|
||||
echo "v5 training done, exit=$EXIT_CODE" >> $LOG
|
||||
echo "Training finished with exit=$EXIT_CODE, log=$LOG"
|
||||
Reference in New Issue
Block a user