1. 按钢种分组 K_cal:cal_coeffs.json 升级为嵌套结构,
{kcal: {model: {_default, Q235, ...}}, phys: {...}},
旧平铺格式首次加载时自动迁移。
2. 物理参数自适应:EA_R/K0/N_CONC 按钢种网格拟合
(7×5×3=105 组合),每次校准追加样本到
production_samples.jsonl,≥10 条后自动触发拟合。
3. 数据飞轮:新增 POST /retrain 端点,后台子进程跑
train_models.py --use-real-data 混入实绩重训
(10× 权重),完成后 ONNX 热重载,无需重启服务。
新增端点:
GET /calibration/samples 样本数统计
GET /calibration/phys-params 物理参数查询
POST /calibration/fit-phys/{key} 手动触发物理参数拟合
POST /retrain 启动重训
GET /retrain/status 重训进度
模型类签名变更:
TensionModel / QualityPredictionModel 新增 steel_grade 参数
AcidConsumptionModel 新增 fe_conc_avg 参数
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
427 lines
18 KiB
Python
427 lines
18 KiB
Python
import asyncio
|
||
import subprocess
|
||
import sys
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import List, Optional
|
||
|
||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
|
||
from app.schemas.common import Response
|
||
from app.services.auth_service import get_current_user
|
||
from app.services.prediction import (
|
||
AcidSpeedModel,
|
||
TensionModel,
|
||
QualityPredictionModel,
|
||
AcidConsumptionModel,
|
||
_load_cal,
|
||
_save_cal,
|
||
_get_phys,
|
||
append_sample,
|
||
get_sample_stats,
|
||
fit_acid_phys_params,
|
||
fit_quality_phys_params,
|
||
reload_onnx,
|
||
)
|
||
|
||
# Router that every endpoint in this module is registered on.
router = APIRouter()

# Tension zone identifiers. reset_calibration() derives one
# "tension_<zone>" K_cal key from each entry.
TENSION_ZONES = [
    "inlet",
    "s1_roller",
    "acid_entry",
    "acid1",
    "acid2",
    "acid3",
    "rinse",
    "leveler",
    "s2_roller",
    "outlet",
]

# Directory three levels above this file; train_models.py is resolved against it.
_BACKEND_DIR = Path(__file__).parent.parent.parent
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Prediction request schemas
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
class AcidSpeedRequest(BaseModel):
    """Request body for ``POST /acid-speed``.

    ``thickness`` and ``width`` must be strictly positive. All fields are
    forwarded unchanged to ``AcidSpeedModel``.
    """

    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    # Optional inputs with fixed defaults.
    scale_weight: Optional[float] = Field(default=8.5)
    target_pi: Optional[float] = Field(default=95.0)
|
||
|
||
|
||
class TensionRequest(BaseModel):
    """Request body for ``POST /tension``.

    The three geometry/strength fields must be strictly positive;
    ``steel_grade`` falls back to the ``"_default"`` key when omitted.
    """

    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = Field(default=0.25)
    steel_grade: Optional[str] = Field(default="_default")
|
||
|
||
|
||
class QualityRequest(BaseModel):
    """Request body for ``POST /quality``.

    The four required process values must be strictly positive; the
    remaining fields are optional and carry fixed defaults.
    """

    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = Field(default=8.5)
    fe_conc_avg: Optional[float] = Field(default=60.0)
    steel_grade: Optional[str] = Field(default="_default")
|
||
|
||
|
||
class ConsumptionRequest(BaseModel):
    """Request body for ``POST /consumption``.

    ``thickness``, ``width`` and ``coil_weight_kg`` must be strictly
    positive. All fields are forwarded to ``AcidConsumptionModel``.
    """

    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    coil_weight_kg: float = Field(..., gt=0)
    has_regen_station: Optional[bool] = Field(default=True)
    fe_conc_avg: Optional[float] = Field(default=60.0)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Calibration request schemas
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
class AcidCalibRequest(BaseModel):
    """Request body for ``POST /calibration/acid-speed``.

    Carries the same process inputs as a speed prediction (minus
    ``target_pi``) plus the measured outcome used to adjust K_cal.
    """

    # Process inputs handed to AcidSpeedModel.
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    scale_weight: Optional[float] = Field(default=8.5)

    # Measured outcome handed to AcidSpeedModel.calibrate().
    actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
    actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")

    # Free-text remark stored with the history record.
    note: Optional[str] = Field(default=None)
|
||
|
||
|
||
class TensionCalibRequest(BaseModel):
    """Request body for ``POST /calibration/tension``.

    Carries the tension-prediction inputs plus one measured tension value
    for a single zone.
    """

    # Inputs handed to TensionModel.
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = Field(default=0.25)
    steel_grade: Optional[str] = Field(default="_default")

    # Measurement handed to TensionModel.calibrate().
    zone: str = Field(..., description="测量位置,如 s1_roller")
    measured_kn: float = Field(..., gt=0, description="实测张力 kN")

    # Free-text remark stored with the history record.
    note: Optional[str] = Field(default=None)
|
||
|
||
|
||
class QualityCalibRequest(BaseModel):
    """Request body for ``POST /calibration/quality``.

    Carries the quality-prediction inputs plus the grade actually
    assigned by inspection.
    """

    # Inputs handed to QualityPredictionModel.
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = Field(default=8.5)
    fe_conc_avg: Optional[float] = Field(default=60.0)
    steel_grade: Optional[str] = Field(default="_default")

    # Inspection result handed to QualityPredictionModel.calibrate().
    actual_grade: str = Field(..., description="实际质检等级 A1/A2/B1/B2/C")

    # Free-text remark stored with the history record.
    note: Optional[str] = Field(default=None)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Helper: append calibration history
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
def _append_history(model_key: str, k_before, k_after,
                    input_data: dict, note: str = ""):
    """Prepend one calibration record to the persisted history.

    The record contains the current local timestamp (second resolution),
    the model key, the K_cal values before and after the change, the
    caller-supplied input payload and the note (a falsy note is stored as
    ``""``). Only the 100 newest records are kept. The store is read with
    ``_load_cal`` and written back with ``_save_cal``.
    """
    store = _load_cal()
    record = {
        "ts": datetime.now().isoformat(timespec="seconds"),
        "model": model_key,
        "k_before": k_before,
        "k_after": k_after,
        "input": input_data,
        "note": note or "",
    }
    # Newest entry first, capped at 100 entries.
    store["history"] = ([record] + store.get("history", []))[:100]
    _save_cal(store)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Prediction endpoints
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
@router.post("/acid-speed", response_model=Response[dict])
async def predict_acid_speed(body: AcidSpeedRequest, _=Depends(get_current_user)):
    """Run ``AcidSpeedModel`` on the request and return its result.

    Raises:
        HTTPException: 422 when building or evaluating the model raises
            ``ValueError``; the error text becomes the response detail.
    """
    params = dict(
        thickness=body.thickness,
        width=body.width,
        steel_grade=body.steel_grade,
        acid_conc_list=body.acid_conc_list,
        acid_temp_list=body.acid_temp_list,
        scale_weight=body.scale_weight,
        target_pi=body.target_pi,
    )
    try:
        result = AcidSpeedModel(**params).calculate()
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    return Response.ok(result)
|
||
|
||
|
||
@router.post("/tension", response_model=Response[dict])
async def predict_tension(body: TensionRequest, _=Depends(get_current_user)):
    """Run ``TensionModel`` on the request and return its result."""
    params = dict(
        thickness=body.thickness,
        width=body.width,
        yield_strength=body.yield_strength,
        tension_coef=body.tension_coef,
        steel_grade=body.steel_grade,
    )
    prediction = TensionModel(**params).calculate()
    return Response.ok(prediction)
|
||
|
||
|
||
@router.post("/quality", response_model=Response[dict])
async def predict_quality(body: QualityRequest, _=Depends(get_current_user)):
    """Run ``QualityPredictionModel`` on the request and return its result."""
    params = dict(
        thickness=body.thickness,
        avg_speed=body.avg_speed,
        acid_conc_avg=body.acid_conc_avg,
        acid_temp_avg=body.acid_temp_avg,
        scale_weight=body.scale_weight,
        fe_conc_avg=body.fe_conc_avg,
        steel_grade=body.steel_grade,
    )
    prediction = QualityPredictionModel(**params).calculate()
    return Response.ok(prediction)
|
||
|
||
|
||
@router.post("/consumption", response_model=Response[dict])
async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_user)):
    """Run ``AcidConsumptionModel`` on the request and return its result."""
    params = dict(
        thickness=body.thickness,
        width=body.width,
        coil_weight_kg=body.coil_weight_kg,
        has_regen_station=body.has_regen_station,
        fe_conc_avg=body.fe_conc_avg,
    )
    prediction = AcidConsumptionModel(**params).calculate()
    return Response.ok(prediction)
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# Calibration endpoints
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
@router.get("/calibration", response_model=Response[dict])
async def get_calibration(_=Depends(get_current_user)):
    """Return the stored calibration coefficients, physical parameters and history.

    Missing sections default to an empty dict (``kcal``, ``phys``) or an
    empty list (``history``).
    """
    cal = _load_cal()
    payload = {section: cal.get(section, {}) for section in ("kcal", "phys")}
    payload["history"] = cal.get("history", [])
    return Response.ok(payload)
|
||
|
||
|
||
@router.get("/calibration/samples", response_model=Response[dict])
async def get_calibration_samples(_=Depends(get_current_user)):
    """Return production-sample statistics and the fit threshold.

    The response contains the output of ``get_sample_stats()``, the minimum
    sample count required for a physical-parameter fit, and a hint about
    the retrain endpoint.
    """
    # Report the same constant that fit_phys_params_api uses, so the
    # advertised threshold cannot drift from the one actually enforced.
    from app.services.prediction import _FIT_MIN_SAMPLES

    return Response.ok({
        "stats": get_sample_stats(),
        "fit_threshold": _FIT_MIN_SAMPLES,
        "retrain_tip": "样本累积足够后,POST /api/prediction/retrain 可触发 ONNX 重训",
    })
|
||
|
||
|
||
@router.get("/calibration/phys-params", response_model=Response[dict])
async def get_phys_params(steel_grade: Optional[str] = None, _=Depends(get_current_user)):
    """Return stored physical parameters, optionally for one steel grade.

    Without ``steel_grade`` the whole ``phys`` section of the calibration
    store is returned. With it, ``_get_phys`` is queried for the
    ``acid_speed`` and ``quality`` models and the two results are returned
    together with the grade.
    """
    cal = _load_cal()
    all_phys = cal.get("phys", {})
    if not steel_grade:
        return Response.ok(all_phys)

    per_model = {
        name: _get_phys(name, steel_grade)
        for name in ("acid_speed", "quality")
    }
    return Response.ok({"steel_grade": steel_grade, "phys_params": per_model})
|
||
|
||
|
||
@router.post("/calibration/acid-speed", response_model=Response[dict])
async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_user)):
    """Record a measured line speed and update K_cal for the steel grade.

    Builds an ``AcidSpeedModel`` from the request, captures the current
    K_cal and predicted maximum speed, calls ``model.calibrate()`` with the
    measured outcome, and appends a history record.

    NOTE(review): the original docstring states that a physical-parameter
    fit is triggered automatically once >=10 samples exist. That does not
    happen in this function; presumably ``model.calibrate()`` does it —
    confirm.

    Returns:
        Response wrapping the steel grade, K_cal before/after, the predicted
        speed and the relative K_cal change in percent.

    Raises:
        HTTPException: 422 when the model rejects the inputs with a
            ``ValueError``, at construction or during ``calculate()``.
    """
    try:
        model = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
        )
        k_before = model.K_cal
        # calculate() is inside the try, as in predict_acid_speed, so an
        # input error surfaces as 422 rather than an unhandled 500.
        predicted_speed = model.calculate()["max_speed"]
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))

    k_after = model.calibrate(
        actual_max_speed=body.actual_max_speed,
        actual_quality_ok=body.actual_quality_ok,
    )
    _append_history(
        f"acid_speed[{body.steel_grade}]", k_before, k_after,
        {
            "actual_speed": body.actual_max_speed,
            "quality_ok": body.actual_quality_ok,
            "predicted_speed": predicted_speed,
        },
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_speed": predicted_speed,
        "adjustment": round((k_after / k_before - 1) * 100, 2),
    })
|
||
|
||
|
||
@router.post("/calibration/tension", response_model=Response[dict])
async def calibrate_tension(body: TensionCalibRequest, _=Depends(get_current_user)):
    """Record a measured tension and update K_cal for one steel grade and zone.

    Builds a ``TensionModel`` from the request, reads the predicted tension
    and current K_cal for ``body.zone``, calls ``model.calibrate()`` with
    the measurement, and appends a history record.

    Returns:
        Response wrapping the steel grade, zone, K_cal before/after,
        predicted and measured tension, the relative K_cal change in
        percent and the full zone K_cal mapping returned by ``calibrate()``.

    Raises:
        HTTPException: 422 when ``body.zone`` is not one of the zones the
            model reports.
    """
    model = TensionModel(
        thickness=body.thickness,
        width=body.width,
        yield_strength=body.yield_strength,
        tension_coef=body.tension_coef,
        steel_grade=body.steel_grade,
    )
    zones = model.calculate()["zones"]
    # Reject unknown zones up front; previously a typo produced a predicted
    # tension of 0 and was still passed on to calibrate().
    if body.zone not in zones:
        raise HTTPException(status_code=422, detail=f"未知测量位置: {body.zone}")

    predicted_kn = zones[body.zone].get("tension_kN", 0)
    k_before = model.zone_kcal.get(body.zone, 1.0)
    new_zone_kcal = model.calibrate(zone=body.zone, measured_kn=body.measured_kn)
    k_after = new_zone_kcal.get(body.zone, 1.0)
    _append_history(
        f"tension[{body.steel_grade}][{body.zone}]", k_before, k_after,
        {"zone": body.zone, "measured_kn": body.measured_kn, "predicted_kn": predicted_kn},
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "zone": body.zone,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_kn": predicted_kn,
        "measured_kn": body.measured_kn,
        "adjustment": round((k_after / k_before - 1) * 100, 2),
        "zone_kcal": new_zone_kcal,
    })
|
||
|
||
|
||
@router.post("/calibration/quality", response_model=Response[dict])
async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_user)):
    """Record an inspected quality grade and update K_cal for the steel grade.

    Builds a ``QualityPredictionModel`` from the request, captures the
    current K_cal and predicted overall grade, calls ``model.calibrate()``
    with the actual grade, and appends a history record.

    NOTE(review): the original docstring states that a physical-parameter
    fit is triggered automatically once >=10 samples exist. That does not
    happen in this function; presumably ``model.calibrate()`` does it —
    confirm.
    """
    grade_tag = body.steel_grade
    model = QualityPredictionModel(
        thickness=body.thickness,
        avg_speed=body.avg_speed,
        acid_conc_avg=body.acid_conc_avg,
        acid_temp_avg=body.acid_temp_avg,
        scale_weight=body.scale_weight,
        fe_conc_avg=body.fe_conc_avg,
        steel_grade=grade_tag,
    )

    # Snapshot the pre-calibration state before calibrate() changes it.
    k_before = model.K_cal
    predicted_grade = model.calculate()["overall_grade"]
    k_after = model.calibrate(actual_grade=body.actual_grade)

    history_input = {
        "actual_grade": body.actual_grade,
        "predicted_grade": predicted_grade,
    }
    _append_history(f"quality[{grade_tag}]", k_before, k_after,
                    history_input, body.note or "")

    percent_change = round((k_after / k_before - 1) * 100, 2)
    return Response.ok({
        "steel_grade": grade_tag,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_grade": predicted_grade,
        "actual_grade": body.actual_grade,
        "adjustment": percent_change,
    })
|
||
|
||
|
||
@router.post("/calibration/fit-phys/{model_key}", response_model=Response[dict])
async def fit_phys_params_api(model_key: str, steel_grade: str, _=Depends(get_current_user)):
    """Manually trigger a physical-parameter fit for one model and steel grade.

    ``model_key`` must be ``acid_speed`` or ``quality``. When the fit
    function returns ``None`` the response reports ``fitted: False`` with
    the minimum-sample reason; otherwise it carries the fitted parameters.

    Raises:
        HTTPException: 404 for any other ``model_key``.
    """
    fitters = {
        "acid_speed": fit_acid_phys_params,
        "quality": fit_quality_phys_params,
    }
    fitter = fitters.get(model_key)
    if fitter is None:
        raise HTTPException(status_code=404, detail="model_key 仅支持 acid_speed / quality")

    result = fitter(steel_grade)
    if result is not None:
        return Response.ok({"fitted": True, "steel_grade": steel_grade, "phys_params": result})

    # Imported lazily, only on the not-enough-samples path.
    from app.services.prediction import _FIT_MIN_SAMPLES
    return Response.ok({
        "fitted": False,
        "reason": f"样本不足,需 ≥{_FIT_MIN_SAMPLES} 条,请继续录入校准数据",
    })
|
||
|
||
|
||
@router.post("/calibration/reset/{model_key}", response_model=Response[dict])
async def reset_calibration(model_key: str, steel_grade: Optional[str] = None,
                            _=Depends(get_current_user)):
    """Reset calibration data for one model.

    Without ``steel_grade`` every grade of the model is replaced by a single
    ``_default`` entry (K_cal 1.0, and ``_DEFAULT_PHYS`` for the physical
    parameters of ``acid_speed`` / ``quality``). With ``steel_grade`` only
    that grade's entries are removed. For ``tension`` this is applied to
    each ``tension_<zone>`` key in ``TENSION_ZONES``.

    Raises:
        HTTPException: 404 when ``model_key`` is not ``tension``,
            ``acid_speed`` or ``quality``.
    """
    cal = _load_cal()
    if model_key == "tension":
        kcal = cal.setdefault("kcal", {})
        for zone in TENSION_ZONES:
            key = f"tension_{zone}"
            if steel_grade:
                kcal.setdefault(key, {}).pop(steel_grade, None)
            else:
                kcal[key] = {"_default": 1.0}
    elif model_key in ("acid_speed", "quality"):
        kcal = cal.setdefault("kcal", {})
        phys = cal.setdefault("phys", {})
        if steel_grade:
            kcal.setdefault(model_key, {}).pop(steel_grade, None)
            phys.setdefault(model_key, {}).pop(steel_grade, None)
        else:
            from app.services.prediction import _DEFAULT_PHYS
            kcal[model_key] = {"_default": 1.0}
            phys[model_key] = {"_default": _DEFAULT_PHYS.copy()}
    else:
        raise HTTPException(status_code=404, detail="未知模型")

    # Persist the reset BEFORE logging it. _append_history() loads and saves
    # its own copy of the store, so saving this (older) copy afterwards
    # would overwrite the history record that was just written.
    _save_cal(cal)
    _append_history(model_key, None, 1.0,
                    {"action": "reset", "steel_grade": steel_grade or "all"})
    return Response.ok({"model": model_key, "steel_grade": steel_grade or "all", "reset": True})
|
||
|
||
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
# 数据飞轮:ONNX 重训端点
|
||
# ─────────────────────────────────────────────────────────────────────────────
|
||
|
||
# NOTE(review): this lock is not acquired anywhere in the visible part of
# this module — confirm whether it is still needed.
_retrain_lock = asyncio.Lock()

# Shared retrain state: written by _run_retrain(), read by the retrain endpoints.
_retrain_status: dict = {
    "running": False,
    "last_ts": None,
    "last_result": None,
}
|
||
|
||
|
||
def _run_retrain():
    """Run ``train_models.py --use-real-data`` in a subprocess and record the outcome.

    Marks the job as running and stamps the start time, then runs the
    script with the current interpreter under a 600-second timeout. The
    exit status and the tails of stdout (last 2000 chars) and stderr (last
    1000 chars) are stored in ``_retrain_status["last_result"]``. On a zero
    exit code ``reload_onnx()`` is called. A timeout or any other exception
    is recorded as a failed result. The running flag is always cleared.
    """
    _retrain_status["running"] = True
    _retrain_status["last_ts"] = datetime.now().isoformat(timespec="seconds")
    try:
        command = [
            sys.executable,
            str(_BACKEND_DIR / "train_models.py"),
            "--use-real-data",
        ]
        proc = subprocess.run(command, capture_output=True, text=True, timeout=600)
        succeeded = proc.returncode == 0
        _retrain_status["last_result"] = {
            "success": succeeded,
            "stdout": (proc.stdout or "")[-2000:],
            "stderr": (proc.stderr or "")[-1000:],
        }
        if succeeded:
            reload_onnx()
    except subprocess.TimeoutExpired:
        _retrain_status["last_result"] = {"success": False, "stderr": "训练超时(>600s)"}
    except Exception as e:
        # Any other failure, including one raised by reload_onnx(), replaces
        # the recorded result with a failure entry.
        _retrain_status["last_result"] = {"success": False, "stderr": str(e)}
    finally:
        _retrain_status["running"] = False
|
||
|
||
|
||
@router.post("/retrain", response_model=Response[dict])
async def trigger_retrain(background_tasks: BackgroundTasks, _=Depends(get_current_user)):
    """Schedule an ONNX retrain as a background task.

    The job itself is ``_run_retrain``. The response carries a confirmation
    message and the current sample statistics.

    Raises:
        HTTPException: 409 when a retrain is already marked as running.
    """
    if _retrain_status["running"]:
        raise HTTPException(status_code=409, detail="重训任务正在进行中,请稍后再试")
    stats = get_sample_stats()
    # Claim the slot before scheduling. The background task only starts
    # after the response is sent, so leaving the flag to _run_retrain()
    # would let a second request pass the check above in the meantime.
    # There is no await between the check and this assignment, so the two
    # cannot interleave on the event loop. _run_retrain() clears the flag.
    _retrain_status["running"] = True
    background_tasks.add_task(_run_retrain)
    return Response.ok({
        "message": "重训任务已启动,完成后 ONNX 将自动热重载",
        "sample_stats": stats,
    })
|
||
|
||
|
||
@router.get("/retrain/status", response_model=Response[dict])
async def retrain_status(_=Depends(get_current_user)):
    """Return the current retrain state (running flag, last start time, last result)."""
    current_state = _retrain_status
    return Response.ok(current_state)
|