82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
|
|
"""Check server Python environment and install missing packages."""
|
||
|
|
import paramiko
|
||
|
|
import warnings
|
||
|
|
import time
|
||
|
|
|
||
|
|
import os  # used only for the environment overrides below

# Silence DeprecationWarnings raised from here on.
# NOTE(review): this filter is installed after `import paramiko` at the top of
# the file, so it does not apply to warnings emitted while paramiko itself is
# being imported.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Connection settings for the remote training server. Each value can be
# overridden through an environment variable so credentials need not be kept
# in source control; the literals below are only fallbacks.
# SECURITY: the fallback password is committed in plain text. Rotate it and
# drop the default once every caller sets CGRL_SSH_PASS.
HOST = os.environ.get("CGRL_SSH_HOST", "10.82.3.180")
PORT = int(os.environ.get("CGRL_SSH_PORT", "20083"))
USER = os.environ.get("CGRL_SSH_USER", "root")
PASS = os.environ.get("CGRL_SSH_PASS", "m2dGcwyrhI")

# Remote checkout of the CompanionGuard-RL project (used to locate training logs).
PROJ = os.environ.get(
    "CGRL_PROJ",
    "/root/siton-data-2849d4ce327c4ccfb233ce33868fe7fe/zsy/CompanionGuard-RL",
)
|
||
|
|
|
||
|
|
|
||
|
|
def ssh_run(client, cmd, timeout=300, print_live=True):
    """Run a shell command on the remote host and collect its output.

    A fresh session channel is opened on ``client``'s transport, ``cmd`` is
    executed there, and stdout/stderr are polled until the remote process
    reports an exit status. The channel is always closed before returning.

    Args:
        client: connected SSH client (e.g. ``paramiko.SSHClient``) whose
            transport is used to open the session.
        cmd: command line passed verbatim to ``exec_command``.
        timeout: seconds to wait for the command to finish; ``None`` waits
            indefinitely.
        print_live: if true, echo output to local stdout as it arrives;
            stderr text is prefixed with ``[ERR]``.

    Returns:
        ``(stdout_text, stderr_text, exit_code)``. Text is decoded as UTF-8
        with undecodable bytes replaced by U+FFFD.

    Raises:
        TimeoutError: the command did not exit within ``timeout`` seconds.
    """
    import codecs

    def _new_decoder():
        # Incremental decoding keeps a multi-byte UTF-8 character that is split
        # across two 4096-byte reads intact instead of turning it into U+FFFD.
        return codecs.getincrementaldecoder("utf-8")(errors="replace")

    out_parts, err_parts = [], []
    out_dec, err_dec = _new_decoder(), _new_decoder()

    def _emit(text, parts, is_err):
        """Record decoded text and optionally echo it to local stdout."""
        if not text:
            return
        parts.append(text)
        if print_live:
            if is_err:
                print("[ERR]", text, end="", flush=True)
            else:
                print(text, end="", flush=True)

    chan = client.get_transport().open_session()

    def _pump_stdout():
        """Consume one buffered stdout chunk; return True if one was available."""
        if not chan.recv_ready():
            return False
        _emit(out_dec.decode(chan.recv(4096)), out_parts, False)
        return True

    def _pump_stderr():
        """Consume one buffered stderr chunk; return True if one was available."""
        if not chan.recv_stderr_ready():
            return False
        _emit(err_dec.decode(chan.recv_stderr(4096)), err_parts, True)
        return True

    try:
        chan.exec_command(cmd)
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            got_data = _pump_stdout()
            got_data = _pump_stderr() or got_data  # always poll both streams
            if chan.exit_status_ready():
                # Drain whatever is still buffered after the process exited.
                while _pump_stdout():
                    pass
                while _pump_stderr():
                    pass
                break
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"remote command exceeded {timeout}s timeout: {cmd}"
                )
            if not got_data:
                # Back off only when idle so bulk output is not throttled.
                time.sleep(0.1)
        # Flush any incomplete trailing byte sequence (emitted as U+FFFD).
        _emit(out_dec.decode(b"", final=True), out_parts, False)
        _emit(err_dec.decode(b"", final=True), err_parts, True)
        exit_code = chan.recv_exit_status()
    finally:
        chan.close()
    return "".join(out_parts), "".join(err_parts), exit_code
|
||
|
|
|
||
|
|
|
||
|
|
def connect(host=None, port=None, username=None, password=None, timeout=30):
    """Open a password-authenticated SSH connection to the training server.

    Every connection argument defaults to the module-level settings
    (``HOST``, ``PORT``, ``USER``, ``PASS``), so a bare ``connect()`` behaves
    as before.

    Args:
        host: server address; ``None`` uses ``HOST``.
        port: SSH port; ``None`` uses ``PORT``.
        username: login name; ``None`` uses ``USER``.
        password: login password; ``None`` uses ``PASS``.
        timeout: TCP connect timeout in seconds, passed to ``client.connect``.

    Returns:
        A connected ``paramiko.SSHClient``; the caller is responsible for
        calling ``close()`` on it.

    Raises:
        Whatever ``paramiko.SSHClient.connect`` raises. The client is closed
        before the exception propagates.
    """
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy trusts any unknown host key on first contact,
    # which leaves the connection open to MITM; prefer load_system_host_keys()
    # with RejectPolicy once the server key is known.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(
            HOST if host is None else host,
            port=PORT if port is None else port,
            username=USER if username is None else username,
            password=PASS if password is None else password,
            timeout=timeout,
        )
    except Exception:
        # Do not leak the socket/transport opened before the failure.
        client.close()
        raise
    return client
|
||
|
|
|
||
|
|
|
||
|
|
# Script body: inspect the remote Python environment, then install any
# missing packages. The connection is closed even if a step raises.
client = connect()
print("Connected.\n")

try:
    # Find which python to use and what's installed
    print("=== Python environments ===")
    ssh_run(client, "which python3 python 2>/dev/null; python3 --version 2>/dev/null; conda env list 2>/dev/null | head -20")

    print("\n=== Check which python was used for training ===")
    ssh_run(client, f"head -5 {PROJ}/experiments/train_*.log 2>/dev/null | head -20")

    print("\n=== Check if torch is installed ===")
    ssh_run(client, "python3 -c 'import torch; print(torch.__version__)' 2>&1")

    print("\n=== List installed packages ===")
    ssh_run(client, "python3 -m pip list 2>/dev/null | grep -E 'yaml|torch|trans|scikit|peft' 2>&1")

    print("\n=== Install missing packages ===")
    # pip runs without a remote `| tail -10`: a pipeline's exit status is that
    # of its last command (tail), which would hide a failed install. The output
    # is trimmed to its last 10 lines locally instead.
    install_out, _, install_rc = ssh_run(
        client,
        "python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pyyaml scikit-learn 2>&1",
        timeout=600,  # generous: downloading/installing scikit-learn may take minutes
        print_live=False,
    )
    install_tail = install_out.splitlines()[-10:]
    if install_tail:
        print("\n".join(install_tail))
    if install_rc != 0:
        print(f"[WARN] pip install exited with status {install_rc}")

    print("\n=== Verify yaml now works ===")
    ssh_run(client, "python3 -c 'import yaml; print(yaml.__version__)' 2>&1")
finally:
    client.close()

print("\nDone.")
|