import asyncio
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, Field

from app.schemas.common import Response
from app.services.auth_service import get_current_user
from app.services.prediction import (
    AcidSpeedModel,
    TensionModel,
    QualityPredictionModel,
    AcidConsumptionModel,
    _load_cal,
    _save_cal,
    _get_phys,
    append_sample,
    get_sample_stats,
    fit_acid_phys_params,
    fit_quality_phys_params,
    reload_onnx,
)

# Router that every endpoint in this module registers itself on.
router = APIRouter()

# Tension zone names. reset_calibration() iterates these to build the
# per-zone calibration keys ("tension_<zone>").
TENSION_ZONES = ["inlet","s1_roller","acid_entry","acid1","acid2","acid3",
                 "rinse","leveler","s2_roller","outlet"]

# Three directory levels above this file; the retrain job looks for
# train_models.py directly inside it.
_BACKEND_DIR = Path(__file__).parent.parent.parent


# ─────────────────────────────────────────────────────────────────────────────
# Prediction request schemas
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): physical units of the numeric fields are not stated anywhere
# in this file — confirm against the model documentation before relying on
# them. `#` comments are used instead of class docstrings on purpose: pydantic
# publishes a model's docstring as its schema description.

# Body of POST /acid-speed; every field is forwarded to AcidSpeedModel.
class AcidSpeedRequest(BaseModel):
    thickness: float = Field(..., gt=0)        # required, must be > 0
    width: float = Field(..., gt=0)            # required, must be > 0
    steel_grade: str                           # required
    acid_conc_list: List[float]                # required; not validated here
    acid_temp_list: List[float]                # required; not validated here
    scale_weight: Optional[float] = 8.5        # 8.5 when omitted
    target_pi: Optional[float] = 95.0          # 95.0 when omitted


# Body of POST /tension; every field is forwarded to TensionModel.
class TensionRequest(BaseModel):
    thickness: float = Field(..., gt=0)        # required, must be > 0
    width: float = Field(..., gt=0)            # required, must be > 0
    yield_strength: float = Field(..., gt=0)   # required, must be > 0
    tension_coef: Optional[float] = 0.25       # 0.25 when omitted
    steel_grade: Optional[str] = "_default"    # "_default" when omitted


# Body of POST /quality; every field is forwarded to QualityPredictionModel.
class QualityRequest(BaseModel):
    thickness: float = Field(..., gt=0)        # required, must be > 0
    avg_speed: float = Field(..., gt=0)        # required, must be > 0
    acid_conc_avg: float = Field(..., gt=0)    # required, must be > 0
    acid_temp_avg: float = Field(..., gt=0)    # required, must be > 0
    scale_weight: Optional[float] = 8.5        # 8.5 when omitted
    fe_conc_avg: Optional[float] = 60.0        # 60.0 when omitted
    steel_grade: Optional[str] = "_default"    # "_default" when omitted


# Body of POST /consumption; every field is forwarded to AcidConsumptionModel.
class ConsumptionRequest(BaseModel):
    thickness: float = Field(..., gt=0)        # required, must be > 0
    width: float = Field(..., gt=0)            # required, must be > 0
    coil_weight_kg: float = Field(..., gt=0)   # required, must be > 0
    has_regen_station: Optional[bool] = True   # True when omitted
    fe_conc_avg: Optional[float] = 60.0        # 60.0 when omitted
# ─────────────────────────────────────────────────────────────────────────────
# Calibration request schemas
# ─────────────────────────────────────────────────────────────────────────────
# `#` comments are used instead of class docstrings on purpose: pydantic
# publishes a model's docstring as its schema description.

# Body of POST /calibration/acid-speed: the model inputs plus the measured
# outcome that the acid-speed calibration coefficient is adjusted against.
class AcidCalibRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    scale_weight: Optional[float] = 8.5
    # Measured maximum speed (m/min) at which quality was still acceptable.
    actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
    # Whether quality was acceptable at that speed.
    actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")
    # Free-text remark stored with the calibration history entry.
    note: Optional[str] = None


# Body of POST /calibration/tension: the model inputs plus one measured
# tension value for a single zone.
class TensionCalibRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = 0.25
    steel_grade: Optional[str] = "_default"
    # Measurement position, e.g. "s1_roller".
    zone: str = Field(..., description="测量位置,如 s1_roller")
    # Measured tension in kN.
    measured_kn: float = Field(..., gt=0, description="实测张力 kN")
    # Free-text remark stored with the calibration history entry.
    note: Optional[str] = None


# Body of POST /calibration/quality: the model inputs plus the grade that
# quality inspection actually assigned.
class QualityCalibRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = 8.5
    fe_conc_avg: Optional[float] = 60.0
    steel_grade: Optional[str] = "_default"
    # Actual inspection grade: A1/A2/B1/B2/C.
    actual_grade: str = Field(..., description="实际质检等级 A1/A2/B1/B2/C")
    # Free-text remark stored with the calibration history entry.
    note: Optional[str] = None


# ─────────────────────────────────────────────────────────────────────────────
# Helper: append calibration history
# ─────────────────────────────────────────────────────────────────────────────
def _append_history(model_key: str, k_before, k_after, input_data: dict, note: str = ""):
    """Record one calibration event at the head of the stored history.

    Loads the calibration store, prepends a new entry, trims the history to
    its 100 newest entries and writes the store back.

    Args:
        model_key: Label of the calibrated model, stored under "model".
        k_before: Coefficient before the change (may be None).
        k_after: Coefficient after the change.
        input_data: Extra context stored under "input".
        note: Optional remark; any falsy value is stored as "".
    """
    store = _load_cal()
    entry = {
        "ts": datetime.now().isoformat(timespec="seconds"),
        "model": model_key,
        "k_before": k_before,
        "k_after": k_after,
        "input": input_data,
        "note": note if note else "",
    }
    earlier = store.get("history", [])
    # Newest first; only the 100 most recent entries are retained.
    store["history"] = [entry, *earlier][:100]
    _save_cal(store)


# ─────────────────────────────────────────────────────────────────────────────
# Prediction endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.post("/acid-speed", response_model=Response[dict])
async def predict_acid_speed(body: AcidSpeedRequest, _=Depends(get_current_user)):
    """Run the acid-speed model on the request and return its result.

    Raises:
        HTTPException: 422 when the model rejects the inputs with ValueError.
    """
    try:
        model = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
            target_pi=body.target_pi,
        )
        result = model.calculate()
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    return Response.ok(result)


@router.post("/tension", response_model=Response[dict])
async def predict_tension(body: TensionRequest, _=Depends(get_current_user)):
    """Run the tension model on the request and return its result."""
    model = TensionModel(
        thickness=body.thickness,
        width=body.width,
        yield_strength=body.yield_strength,
        tension_coef=body.tension_coef,
        steel_grade=body.steel_grade,
    )
    return Response.ok(model.calculate())


@router.post("/quality", response_model=Response[dict])
async def predict_quality(body: QualityRequest, _=Depends(get_current_user)):
    """Run the quality-prediction model on the request and return its result."""
    model = QualityPredictionModel(
        thickness=body.thickness,
        avg_speed=body.avg_speed,
        acid_conc_avg=body.acid_conc_avg,
        acid_temp_avg=body.acid_temp_avg,
        scale_weight=body.scale_weight,
        fe_conc_avg=body.fe_conc_avg,
        steel_grade=body.steel_grade,
    )
    return Response.ok(model.calculate())


@router.post("/consumption", response_model=Response[dict])
async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_user)):
    """Run the acid-consumption model on the request and return its result."""
    model = AcidConsumptionModel(
        thickness=body.thickness,
        width=body.width,
        coil_weight_kg=body.coil_weight_kg,
        has_regen_station=body.has_regen_station,
        fe_conc_avg=body.fe_conc_avg,
    )
    return Response.ok(model.calculate())


# ─────────────────────────────────────────────────────────────────────────────
# Calibration endpoints
# ─────────────────────────────────────────────────────────────────────────────
def _pct_adjustment(k_before: Optional[float], k_after: float) -> Optional[float]:
    """Return the change from k_before to k_after in percent (2 decimals).

    Returns None when k_before is zero or missing. By the time this is called
    the calibration has already been persisted, so an undefined ratio should
    not turn the response into an error.
    """
    if not k_before:
        return None
    return round((k_after / k_before - 1) * 100, 2)


@router.get("/calibration", response_model=Response[dict])
async def get_calibration(_=Depends(get_current_user)):
    """Return the stored calibration coefficients, physical parameters and history."""
    cal = _load_cal()
    return Response.ok({
        "kcal": cal.get("kcal", {}),
        "phys": cal.get("phys", {}),
        "history": cal.get("history", []),
    })


@router.get("/calibration/samples", response_model=Response[dict])
async def get_calibration_samples(_=Depends(get_current_user)):
    """Return production-sample statistics plus the sample threshold for fitting."""
    # Same constant the fit endpoint reports, so the two can never disagree.
    from app.services.prediction import _FIT_MIN_SAMPLES

    stats = get_sample_stats()
    return Response.ok({
        "stats": stats,
        "fit_threshold": _FIT_MIN_SAMPLES,
        "retrain_tip": "样本累积足够后,POST /api/prediction/retrain 可触发 ONNX 重训",
    })


@router.get("/calibration/phys-params", response_model=Response[dict])
async def get_phys_params(steel_grade: Optional[str] = None, _=Depends(get_current_user)):
    """Return physical parameters, optionally resolved for one steel grade.

    With steel_grade: the parameters of the acid_speed and quality models as
    resolved by _get_phys. Without it: the raw "phys" section of the store.
    """
    if steel_grade:
        result = {m: _get_phys(m, steel_grade) for m in ("acid_speed", "quality")}
        return Response.ok({"steel_grade": steel_grade, "phys_params": result})
    return Response.ok(_load_cal().get("phys", {}))


@router.post("/calibration/acid-speed", response_model=Response[dict])
async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_user)):
    """Calibrate the acid-speed coefficient against a measured speed.

    Computes the current prediction, asks the model to calibrate itself with
    the measurement and records the change in the calibration history.
    NOTE(review): the original note says physical-parameter fitting triggers
    automatically once >= 10 samples exist; that logic is not in this file —
    presumably inside model.calibrate(), confirm.

    Raises:
        HTTPException: 422 when the model rejects the inputs with ValueError.
    """
    try:
        model = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
        )
        k_before = model.K_cal
        # Inside the guard so bad inputs give 422, as in predict_acid_speed.
        predicted_speed = model.calculate()["max_speed"]
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    k_after = model.calibrate(
        actual_max_speed=body.actual_max_speed,
        actual_quality_ok=body.actual_quality_ok,
    )
    _append_history(
        f"acid_speed[{body.steel_grade}]",
        k_before,
        k_after,
        {
            "actual_speed": body.actual_max_speed,
            "quality_ok": body.actual_quality_ok,
            "predicted_speed": predicted_speed,
        },
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_speed": predicted_speed,
        "adjustment": _pct_adjustment(k_before, k_after),
    })


@router.post("/calibration/tension", response_model=Response[dict])
async def calibrate_tension(body: TensionCalibRequest, _=Depends(get_current_user)):
    """Calibrate one tension zone's coefficient against a measured tension.

    Raises:
        HTTPException: 422 when body.zone is not among the zones the model
            reports.
    """
    model = TensionModel(
        thickness=body.thickness,
        width=body.width,
        yield_strength=body.yield_strength,
        tension_coef=body.tension_coef,
        steel_grade=body.steel_grade,
    )
    zones = model.calculate()["zones"]
    if body.zone not in zones:
        # Without this, an unknown zone was calibrated against a predicted
        # tension of 0 and a coefficient was stored for a zone the model
        # never computes.
        raise HTTPException(status_code=422, detail=f"未知张力区段: {body.zone}")

    predicted_kn = zones[body.zone].get("tension_kN", 0)
    k_before = model.zone_kcal.get(body.zone, 1.0)
    new_zone_kcal = model.calibrate(zone=body.zone, measured_kn=body.measured_kn)
    k_after = new_zone_kcal.get(body.zone, 1.0)
    _append_history(
        f"tension[{body.steel_grade}][{body.zone}]",
        k_before,
        k_after,
        {
            "zone": body.zone,
            "measured_kn": body.measured_kn,
            "predicted_kn": predicted_kn,
        },
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "zone": body.zone,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_kn": predicted_kn,
        "measured_kn": body.measured_kn,
        "adjustment": _pct_adjustment(k_before, k_after),
        "zone_kcal": new_zone_kcal,
    })


@router.post("/calibration/quality", response_model=Response[dict])
async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_user)):
    """Calibrate the quality coefficient against an actual inspection grade.

    NOTE(review): the original note says physical-parameter fitting triggers
    automatically once >= 10 samples exist; that logic is not in this file —
    presumably inside model.calibrate(), confirm.
    """
    model = QualityPredictionModel(
        thickness=body.thickness,
        avg_speed=body.avg_speed,
        acid_conc_avg=body.acid_conc_avg,
        acid_temp_avg=body.acid_temp_avg,
        scale_weight=body.scale_weight,
        fe_conc_avg=body.fe_conc_avg,
        steel_grade=body.steel_grade,
    )
    k_before = model.K_cal
    predicted_grade = model.calculate()["overall_grade"]
    k_after = model.calibrate(actual_grade=body.actual_grade)
    _append_history(
        f"quality[{body.steel_grade}]",
        k_before,
        k_after,
        {"actual_grade": body.actual_grade, "predicted_grade": predicted_grade},
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_grade": predicted_grade,
        "actual_grade": body.actual_grade,
        "adjustment": _pct_adjustment(k_before, k_after),
    })


@router.post("/calibration/fit-phys/{model_key}", response_model=Response[dict])
async def fit_phys_params_api(model_key: str, steel_grade: str, _=Depends(get_current_user)):
    """Manually trigger physical-parameter fitting for one model and steel grade.

    Args:
        model_key: "acid_speed" or "quality".
        steel_grade: Steel grade whose samples are fitted.

    Returns:
        {"fitted": True, ...} with the fitted parameters, or
        {"fitted": False, "reason": ...} when the fit function returns None.

    Raises:
        HTTPException: 404 for any other model_key.
    """
    if model_key == "acid_speed":
        result = fit_acid_phys_params(steel_grade)
    elif model_key == "quality":
        result = fit_quality_phys_params(steel_grade)
    else:
        raise HTTPException(status_code=404, detail="model_key 仅支持 acid_speed / quality")

    if result is None:
        from app.services.prediction import _FIT_MIN_SAMPLES

        return Response.ok({
            "fitted": False,
            "reason": f"样本不足,需 ≥{_FIT_MIN_SAMPLES} 条,请继续录入校准数据",
        })
    return Response.ok({"fitted": True, "steel_grade": steel_grade, "phys_params": result})


@router.post("/calibration/reset/{model_key}", response_model=Response[dict])
async def reset_calibration(model_key: str, steel_grade: Optional[str] = None,
                            _=Depends(get_current_user)):
    """Reset calibration for a model.

    With steel_grade, only that grade's entries are removed; without it, every
    grade of the model is replaced by a single "_default" entry (coefficient
    1.0, and default physical parameters for acid_speed / quality).

    Args:
        model_key: "tension", "acid_speed" or "quality".
        steel_grade: Grade to reset; omit to reset all grades.

    Raises:
        HTTPException: 404 for any other model_key.
    """
    cal = _load_cal()
    kcal = cal.setdefault("kcal", {})

    if model_key == "tension":
        for z in TENSION_ZONES:
            key = f"tension_{z}"
            if steel_grade:
                kcal.setdefault(key, {}).pop(steel_grade, None)
            else:
                kcal[key] = {"_default": 1.0}
    elif model_key in ("acid_speed", "quality"):
        phys = cal.setdefault("phys", {})
        if steel_grade:
            kcal.setdefault(model_key, {}).pop(steel_grade, None)
            phys.setdefault(model_key, {}).pop(steel_grade, None)
        else:
            from app.services.prediction import _DEFAULT_PHYS

            kcal[model_key] = {"_default": 1.0}
            phys[model_key] = {"_default": _DEFAULT_PHYS.copy()}
    else:
        raise HTTPException(status_code=404, detail="未知模型")

    # Persist the reset BEFORE logging it. _append_history() does its own
    # load and save; saving this snapshot afterwards would overwrite the
    # history entry it had just written.
    _save_cal(cal)
    _append_history(
        model_key,
        None,
        1.0,
        {"action": "reset", "steel_grade": steel_grade or "all"},
    )
    return Response.ok({"model": model_key, "steel_grade": steel_grade or "all", "reset": True})


# ─────────────────────────────────────────────────────────────────────────────
# Data flywheel: ONNX retraining endpoints
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): _retrain_lock is not used by anything visible in this file;
# kept because other modules may reference it.
_retrain_lock = asyncio.Lock()
# Shared job state: written by trigger_retrain() and _run_retrain(), read by
# retrain_status().
_retrain_status: dict = {"running": False, "last_ts": None, "last_result": None}


def _run_retrain():
    """Run train_models.py --use-real-data in a subprocess, then hot-reload ONNX.

    Records the outcome in _retrain_status["last_result"] (tail of stdout and
    stderr) and always clears the "running" flag when done. The subprocess is
    limited to 600 seconds. reload_onnx() is only called after a zero exit
    code; if it raises, the outcome is recorded as a failure.
    """
    _retrain_status["running"] = True
    _retrain_status["last_ts"] = datetime.now().isoformat(timespec="seconds")
    try:
        result = subprocess.run(
            [sys.executable, str(_BACKEND_DIR / "train_models.py"), "--use-real-data"],
            capture_output=True,
            text=True,
            timeout=600,
        )
        ok = result.returncode == 0
        _retrain_status["last_result"] = {
            "success": ok,
            # Keep only the tail so the status payload stays small.
            "stdout": result.stdout[-2000:] if result.stdout else "",
            "stderr": result.stderr[-1000:] if result.stderr else "",
        }
        if ok:
            reload_onnx()
    except subprocess.TimeoutExpired:
        _retrain_status["last_result"] = {"success": False, "stderr": "训练超时(>600s)"}
    except Exception as e:
        # Boundary of a background job: record the failure instead of losing it.
        _retrain_status["last_result"] = {"success": False, "stderr": str(e)}
    finally:
        _retrain_status["running"] = False


@router.post("/retrain", response_model=Response[dict])
async def trigger_retrain(background_tasks: BackgroundTasks, _=Depends(get_current_user)):
    """Schedule an ONNX retrain as a background task.

    The training itself runs in a subprocess (see _run_retrain) and the models
    are hot-reloaded when it succeeds.

    Raises:
        HTTPException: 409 when a retrain is already running or scheduled.
    """
    if _retrain_status["running"]:
        raise HTTPException(status_code=409, detail="重训任务正在进行中,请稍后再试")

    stats = get_sample_stats()
    # Claim the job here rather than when the task starts. There is no await
    # between the check above and this assignment, so a second request cannot
    # slip in and schedule a duplicate training run.
    # NOTE(review): if the background task never starts (e.g. the response
    # cannot be sent), the flag stays set until the process restarts.
    _retrain_status["running"] = True
    background_tasks.add_task(_run_retrain)
    return Response.ok({
        "message": "重训任务已启动,完成后 ONNX 将自动热重载",
        "sample_stats": stats,
    })


@router.get("/retrain/status", response_model=Response[dict])
async def retrain_status(_=Depends(get_current_user)):
    """Return a snapshot of the retrain job state."""
    # Shallow copy so the worker cannot change the mapping while the response
    # is being serialised.
    return Response.ok(dict(_retrain_status))