feat(prediction): 三层校准体系 + 按钢种分组 + 数据飞轮
1. 按钢种分组 K_cal:cal_coeffs.json 升级为嵌套结构,
{kcal: {model: {_default, Q235, ...}}, phys: {...}},
旧平铺格式首次加载时自动迁移。
2. 物理参数自适应:EA_R/K0/N_CONC 按钢种网格拟合
(7×5×3=105 组合),每次校准追加样本到
production_samples.jsonl,≥10 条后自动触发拟合。
3. 数据飞轮:新增 POST /retrain 端点,后台子进程跑
train_models.py --use-real-data 混入实绩重训
(10× 权重),完成后 ONNX 热重载,无需重启服务。
新增端点:
GET /calibration/samples 样本数统计
GET /calibration/phys-params 物理参数查询
POST /calibration/fit-phys/{key} 手动触发物理参数拟合
POST /retrain 启动重训
GET /retrain/status 重训进度
模型类签名变更:
TensionModel / QualityPredictionModel 新增 steel_grade 参数
AcidConsumptionModel 新增 fe_conc_avg 参数
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,12 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.common import Response
|
||||
from app.services.auth_service import get_current_user
|
||||
@@ -12,10 +17,20 @@ from app.services.prediction import (
|
||||
AcidConsumptionModel,
|
||||
_load_cal,
|
||||
_save_cal,
|
||||
_get_phys,
|
||||
append_sample,
|
||||
get_sample_stats,
|
||||
fit_acid_phys_params,
|
||||
fit_quality_phys_params,
|
||||
reload_onnx,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
TENSION_ZONES = ["inlet","s1_roller","acid_entry","acid1","acid2","acid3",
|
||||
"rinse","leveler","s2_roller","outlet"]
|
||||
_BACKEND_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Prediction request schemas
|
||||
@@ -36,6 +51,7 @@ class TensionRequest(BaseModel):
|
||||
width: float = Field(..., gt=0)
|
||||
yield_strength: float = Field(..., gt=0)
|
||||
tension_coef: Optional[float] = 0.25
|
||||
steel_grade: Optional[str] = "_default"
|
||||
|
||||
|
||||
class QualityRequest(BaseModel):
|
||||
@@ -44,13 +60,16 @@ class QualityRequest(BaseModel):
|
||||
acid_conc_avg: float = Field(..., gt=0)
|
||||
acid_temp_avg: float = Field(..., gt=0)
|
||||
scale_weight: Optional[float] = 8.5
|
||||
fe_conc_avg: Optional[float] = 60.0
|
||||
steel_grade: Optional[str] = "_default"
|
||||
|
||||
|
||||
class ConsumptionRequest(BaseModel):
|
||||
thickness: float = Field(..., gt=0)
|
||||
width: float = Field(..., gt=0)
|
||||
coil_weight_kg: float = Field(..., gt=0)
|
||||
has_regen_station: Optional[bool] = True
|
||||
has_regen_station: Optional[bool] = True
|
||||
fe_conc_avg: Optional[float] = 60.0
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
@@ -58,17 +77,15 @@ class ConsumptionRequest(BaseModel):
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class AcidCalibRequest(BaseModel):
|
||||
# 需要重建模型实例的上下文参数
|
||||
thickness: float = Field(..., gt=0)
|
||||
width: float = Field(..., gt=0)
|
||||
steel_grade: str
|
||||
acid_conc_list: List[float]
|
||||
acid_temp_list: List[float]
|
||||
scale_weight: Optional[float] = 8.5
|
||||
# 校准输入
|
||||
actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
|
||||
actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")
|
||||
note: Optional[str] = None
|
||||
actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
|
||||
actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")
|
||||
note: Optional[str] = None
|
||||
|
||||
|
||||
class TensionCalibRequest(BaseModel):
|
||||
@@ -76,6 +93,7 @@ class TensionCalibRequest(BaseModel):
|
||||
width: float = Field(..., gt=0)
|
||||
yield_strength: float = Field(..., gt=0)
|
||||
tension_coef: Optional[float] = 0.25
|
||||
steel_grade: Optional[str] = "_default"
|
||||
zone: str = Field(..., description="测量位置,如 s1_roller")
|
||||
measured_kn: float = Field(..., gt=0, description="实测张力 kN")
|
||||
note: Optional[str] = None
|
||||
@@ -87,6 +105,8 @@ class QualityCalibRequest(BaseModel):
|
||||
acid_conc_avg: float = Field(..., gt=0)
|
||||
acid_temp_avg: float = Field(..., gt=0)
|
||||
scale_weight: Optional[float] = 8.5
|
||||
fe_conc_avg: Optional[float] = 60.0
|
||||
steel_grade: Optional[str] = "_default"
|
||||
actual_grade: str = Field(..., description="实际质检等级 A1/A2/B1/B2/C")
|
||||
note: Optional[str] = None
|
||||
|
||||
@@ -95,9 +115,9 @@ class QualityCalibRequest(BaseModel):
|
||||
# Helper: append calibration history
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _append_history(model_key: str, k_before: float, k_after: float,
|
||||
def _append_history(model_key: str, k_before, k_after,
|
||||
input_data: dict, note: str = ""):
|
||||
cal = _load_cal()
|
||||
cal = _load_cal()
|
||||
history = cal.get("history", [])
|
||||
history.insert(0, {
|
||||
"ts": datetime.now().isoformat(timespec="seconds"),
|
||||
@@ -136,6 +156,7 @@ async def predict_tension(body: TensionRequest, _=Depends(get_current_user)):
|
||||
model = TensionModel(
|
||||
thickness=body.thickness, width=body.width,
|
||||
yield_strength=body.yield_strength, tension_coef=body.tension_coef,
|
||||
steel_grade=body.steel_grade,
|
||||
)
|
||||
return Response.ok(model.calculate())
|
||||
|
||||
@@ -145,7 +166,8 @@ async def predict_quality(body: QualityRequest, _=Depends(get_current_user)):
|
||||
model = QualityPredictionModel(
|
||||
thickness=body.thickness, avg_speed=body.avg_speed,
|
||||
acid_conc_avg=body.acid_conc_avg, acid_temp_avg=body.acid_temp_avg,
|
||||
scale_weight=body.scale_weight,
|
||||
scale_weight=body.scale_weight, fe_conc_avg=body.fe_conc_avg,
|
||||
steel_grade=body.steel_grade,
|
||||
)
|
||||
return Response.ok(model.calculate())
|
||||
|
||||
@@ -156,6 +178,7 @@ async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_us
|
||||
thickness=body.thickness, width=body.width,
|
||||
coil_weight_kg=body.coil_weight_kg,
|
||||
has_regen_station=body.has_regen_station,
|
||||
fe_conc_avg=body.fe_conc_avg,
|
||||
)
|
||||
return Response.ok(model.calculate())
|
||||
|
||||
@@ -164,27 +187,44 @@ async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_us
|
||||
# Calibration endpoints
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
TENSION_ZONES = ["inlet","s1_roller","acid_entry","acid1","acid2","acid3","rinse","leveler","s2_roller","outlet"]
|
||||
|
||||
|
||||
@router.get("/calibration", response_model=Response[dict])
|
||||
async def get_calibration(_=Depends(get_current_user)):
|
||||
"""返回各模型当前校准系数和历史记录"""
|
||||
"""返回各模型当前校准系数(按钢种)和历史记录。"""
|
||||
cal = _load_cal()
|
||||
tension_zone_kcal = {
|
||||
z: cal.get(f"tension_zone_{z}", 1.0) for z in TENSION_ZONES
|
||||
}
|
||||
return Response.ok({
|
||||
"acid_speed_kcal": cal.get("acid_speed_kcal", 1.0),
|
||||
"tension_zone_kcal": tension_zone_kcal,
|
||||
"quality_kcal": cal.get("quality_kcal", 1.0),
|
||||
"history": cal.get("history", []),
|
||||
"kcal": cal.get("kcal", {}),
|
||||
"phys": cal.get("phys", {}),
|
||||
"history": cal.get("history", []),
|
||||
})
|
||||
|
||||
|
||||
@router.get("/calibration/samples", response_model=Response[dict])
|
||||
async def get_calibration_samples(_=Depends(get_current_user)):
|
||||
"""返回各模型 + 钢种的生产样本数量统计,以及重训所需的样本阈值。"""
|
||||
stats = get_sample_stats()
|
||||
return Response.ok({
|
||||
"stats": stats,
|
||||
"fit_threshold": 10,
|
||||
"retrain_tip": "样本累积足够后,POST /api/prediction/retrain 可触发 ONNX 重训",
|
||||
})
|
||||
|
||||
|
||||
@router.get("/calibration/phys-params", response_model=Response[dict])
|
||||
async def get_phys_params(steel_grade: Optional[str] = None, _=Depends(get_current_user)):
|
||||
"""查询物理参数(EA_R / K0 / N_CONC),可按钢种过滤。"""
|
||||
cal = _load_cal()
|
||||
phys = cal.get("phys", {})
|
||||
if steel_grade:
|
||||
result = {}
|
||||
for m in ("acid_speed", "quality"):
|
||||
result[m] = _get_phys(m, steel_grade)
|
||||
return Response.ok({"steel_grade": steel_grade, "phys_params": result})
|
||||
return Response.ok(phys)
|
||||
|
||||
|
||||
@router.post("/calibration/acid-speed", response_model=Response[dict])
|
||||
async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_user)):
|
||||
"""录入实测数据,更新酸洗速度模型校准系数"""
|
||||
"""录入实测速度,更新对应钢种 K_cal;样本 ≥10 后自动触发物理参数拟合。"""
|
||||
try:
|
||||
model = AcidSpeedModel(
|
||||
thickness=body.thickness, width=body.width,
|
||||
@@ -196,76 +236,78 @@ async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_use
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e))
|
||||
|
||||
k_before = model.K_cal
|
||||
k_before = model.K_cal
|
||||
predicted_speed = model.calculate()["max_speed"]
|
||||
k_after = model.calibrate(
|
||||
k_after = model.calibrate(
|
||||
actual_max_speed=body.actual_max_speed,
|
||||
actual_quality_ok=body.actual_quality_ok,
|
||||
)
|
||||
_append_history(
|
||||
"acid_speed", k_before, k_after,
|
||||
f"acid_speed[{body.steel_grade}]", k_before, k_after,
|
||||
{"actual_speed": body.actual_max_speed,
|
||||
"quality_ok": body.actual_quality_ok,
|
||||
"quality_ok": body.actual_quality_ok,
|
||||
"predicted_speed": predicted_speed},
|
||||
body.note or "",
|
||||
)
|
||||
return Response.ok({
|
||||
"k_before": k_before,
|
||||
"k_after": k_after,
|
||||
"steel_grade": body.steel_grade,
|
||||
"k_before": k_before,
|
||||
"k_after": k_after,
|
||||
"predicted_speed": predicted_speed,
|
||||
"adjustment": round((k_after / k_before - 1) * 100, 2),
|
||||
"adjustment": round((k_after / k_before - 1) * 100, 2),
|
||||
})
|
||||
|
||||
|
||||
@router.post("/calibration/tension", response_model=Response[dict])
|
||||
async def calibrate_tension(body: TensionCalibRequest, _=Depends(get_current_user)):
|
||||
"""录入实测张力,仅更新指定区段的校准系数"""
|
||||
"""录入实测张力,更新对应钢种 + 区段的 K_cal。"""
|
||||
model = TensionModel(
|
||||
thickness=body.thickness, width=body.width,
|
||||
yield_strength=body.yield_strength, tension_coef=body.tension_coef,
|
||||
steel_grade=body.steel_grade,
|
||||
)
|
||||
calc = model.calculate()
|
||||
predicted_kn = calc["zones"].get(body.zone, {}).get("tension_kN", 0)
|
||||
k_before = model.zone_kcal.get(body.zone, 1.0)
|
||||
calc = model.calculate()
|
||||
predicted_kn= calc["zones"].get(body.zone, {}).get("tension_kN", 0)
|
||||
k_before = model.zone_kcal.get(body.zone, 1.0)
|
||||
new_zone_kcal = model.calibrate(zone=body.zone, measured_kn=body.measured_kn)
|
||||
k_after = new_zone_kcal.get(body.zone, 1.0)
|
||||
k_after = new_zone_kcal.get(body.zone, 1.0)
|
||||
_append_history(
|
||||
"tension", k_before, k_after,
|
||||
{"zone": body.zone,
|
||||
"measured_kn": body.measured_kn,
|
||||
"predicted_kn": predicted_kn},
|
||||
f"tension[{body.steel_grade}][{body.zone}]", k_before, k_after,
|
||||
{"zone": body.zone, "measured_kn": body.measured_kn, "predicted_kn": predicted_kn},
|
||||
body.note or "",
|
||||
)
|
||||
return Response.ok({
|
||||
"zone": body.zone,
|
||||
"k_before": k_before,
|
||||
"k_after": k_after,
|
||||
"predicted_kn": predicted_kn,
|
||||
"measured_kn": body.measured_kn,
|
||||
"adjustment": round((k_after / k_before - 1) * 100, 2),
|
||||
"zone_kcal": new_zone_kcal,
|
||||
"steel_grade": body.steel_grade,
|
||||
"zone": body.zone,
|
||||
"k_before": k_before,
|
||||
"k_after": k_after,
|
||||
"predicted_kn": predicted_kn,
|
||||
"measured_kn": body.measured_kn,
|
||||
"adjustment": round((k_after / k_before - 1) * 100, 2),
|
||||
"zone_kcal": new_zone_kcal,
|
||||
})
|
||||
|
||||
|
||||
@router.post("/calibration/quality", response_model=Response[dict])
|
||||
async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_user)):
|
||||
"""录入实际质检等级,更新质量模型校准系数"""
|
||||
"""录入实际质检等级,更新对应钢种 K_cal;样本 ≥10 后自动触发物理参数拟合。"""
|
||||
model = QualityPredictionModel(
|
||||
thickness=body.thickness, avg_speed=body.avg_speed,
|
||||
acid_conc_avg=body.acid_conc_avg, acid_temp_avg=body.acid_temp_avg,
|
||||
scale_weight=body.scale_weight,
|
||||
scale_weight=body.scale_weight, fe_conc_avg=body.fe_conc_avg,
|
||||
steel_grade=body.steel_grade,
|
||||
)
|
||||
k_before = model.K_cal
|
||||
calc = model.calculate()
|
||||
k_before = model.K_cal
|
||||
calc = model.calculate()
|
||||
predicted_grade = calc["overall_grade"]
|
||||
k_after = model.calibrate(actual_grade=body.actual_grade)
|
||||
k_after = model.calibrate(actual_grade=body.actual_grade)
|
||||
_append_history(
|
||||
"quality", k_before, k_after,
|
||||
{"actual_grade": body.actual_grade,
|
||||
"predicted_grade": predicted_grade},
|
||||
f"quality[{body.steel_grade}]", k_before, k_after,
|
||||
{"actual_grade": body.actual_grade, "predicted_grade": predicted_grade},
|
||||
body.note or "",
|
||||
)
|
||||
return Response.ok({
|
||||
"steel_grade": body.steel_grade,
|
||||
"k_before": k_before,
|
||||
"k_after": k_after,
|
||||
"predicted_grade": predicted_grade,
|
||||
@@ -274,21 +316,111 @@ async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_use
|
||||
})
|
||||
|
||||
|
||||
@router.post("/calibration/fit-phys/{model_key}", response_model=Response[dict])
|
||||
async def fit_phys_params_api(model_key: str, steel_grade: str, _=Depends(get_current_user)):
|
||||
"""
|
||||
手动触发指定模型 + 钢种的物理参数拟合(自动触发也会调用此逻辑)。
|
||||
model_key: acid_speed | quality
|
||||
"""
|
||||
if model_key == "acid_speed":
|
||||
result = fit_acid_phys_params(steel_grade)
|
||||
elif model_key == "quality":
|
||||
result = fit_quality_phys_params(steel_grade)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="model_key 仅支持 acid_speed / quality")
|
||||
|
||||
if result is None:
|
||||
from app.services.prediction import _FIT_MIN_SAMPLES
|
||||
return Response.ok({
|
||||
"fitted": False,
|
||||
"reason": f"样本不足,需 ≥{_FIT_MIN_SAMPLES} 条,请继续录入校准数据",
|
||||
})
|
||||
return Response.ok({"fitted": True, "steel_grade": steel_grade, "phys_params": result})
|
||||
|
||||
|
||||
@router.post("/calibration/reset/{model_key}", response_model=Response[dict])
|
||||
async def reset_calibration(model_key: str, _=Depends(get_current_user)):
|
||||
"""将指定模型的校准系数全部重置为 1.0"""
|
||||
async def reset_calibration(model_key: str, steel_grade: Optional[str] = None,
|
||||
_=Depends(get_current_user)):
|
||||
"""
|
||||
重置校准系数为 1.0。
|
||||
steel_grade 为空时重置该模型所有钢种;否则只重置指定钢种。
|
||||
"""
|
||||
cal = _load_cal()
|
||||
if model_key == "tension":
|
||||
# 重置所有区段
|
||||
for z in TENSION_ZONES:
|
||||
cal[f"tension_zone_{z}"] = 1.0
|
||||
_append_history("tension", None, 1.0, {"action": "reset_all_zones"})
|
||||
key = f"tension_{z}"
|
||||
if steel_grade:
|
||||
cal.setdefault("kcal", {}).setdefault(key, {}).pop(steel_grade, None)
|
||||
else:
|
||||
cal.setdefault("kcal", {})[key] = {"_default": 1.0}
|
||||
_append_history("tension", None, 1.0, {"action": "reset", "steel_grade": steel_grade})
|
||||
elif model_key in ("acid_speed", "quality"):
|
||||
key = f"{model_key}_kcal"
|
||||
k_before = cal.get(key, 1.0)
|
||||
cal[key] = 1.0
|
||||
_append_history(model_key, k_before, 1.0, {"action": "reset"})
|
||||
if steel_grade:
|
||||
cal.setdefault("kcal", {}).setdefault(model_key, {}).pop(steel_grade, None)
|
||||
cal.setdefault("phys", {}).setdefault(model_key, {}).pop(steel_grade, None)
|
||||
else:
|
||||
cal.setdefault("kcal", {})[model_key] = {"_default": 1.0}
|
||||
from app.services.prediction import _DEFAULT_PHYS
|
||||
cal.setdefault("phys", {})[model_key] = {"_default": _DEFAULT_PHYS.copy()}
|
||||
_append_history(model_key, None, 1.0,
|
||||
{"action": "reset", "steel_grade": steel_grade or "all"})
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="未知模型")
|
||||
_save_cal(cal)
|
||||
return Response.ok({"model": model_key, "reset": True})
|
||||
return Response.ok({"model": model_key, "steel_grade": steel_grade or "all", "reset": True})
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 数据飞轮:ONNX 重训端点
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
_retrain_lock = asyncio.Lock()
|
||||
_retrain_status: dict = {"running": False, "last_ts": None, "last_result": None}
|
||||
|
||||
|
||||
def _run_retrain():
|
||||
"""在子进程中运行 train_models.py --use-real-data,完成后热重载 ONNX。"""
|
||||
global _retrain_status
|
||||
_retrain_status["running"] = True
|
||||
_retrain_status["last_ts"] = datetime.now().isoformat(timespec="seconds")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(_BACKEND_DIR / "train_models.py"), "--use-real-data"],
|
||||
capture_output=True, text=True, timeout=600,
|
||||
)
|
||||
ok = result.returncode == 0
|
||||
_retrain_status["last_result"] = {
|
||||
"success": ok,
|
||||
"stdout": result.stdout[-2000:] if result.stdout else "",
|
||||
"stderr": result.stderr[-1000:] if result.stderr else "",
|
||||
}
|
||||
if ok:
|
||||
reload_onnx()
|
||||
except subprocess.TimeoutExpired:
|
||||
_retrain_status["last_result"] = {"success": False, "stderr": "训练超时(>600s)"}
|
||||
except Exception as e:
|
||||
_retrain_status["last_result"] = {"success": False, "stderr": str(e)}
|
||||
finally:
|
||||
_retrain_status["running"] = False
|
||||
|
||||
|
||||
@router.post("/retrain", response_model=Response[dict])
|
||||
async def trigger_retrain(background_tasks: BackgroundTasks, _=Depends(get_current_user)):
|
||||
"""
|
||||
触发 ONNX 重训(使用 production_samples.jsonl 中的生产实绩)。
|
||||
训练在后台子进程运行,完成后自动热重载模型。
|
||||
"""
|
||||
if _retrain_status["running"]:
|
||||
raise HTTPException(status_code=409, detail="重训任务正在进行中,请稍后再试")
|
||||
stats = get_sample_stats()
|
||||
background_tasks.add_task(_run_retrain)
|
||||
return Response.ok({
|
||||
"message": "重训任务已启动,完成后 ONNX 将自动热重载",
|
||||
"sample_stats": stats,
|
||||
})
|
||||
|
||||
|
||||
@router.get("/retrain/status", response_model=Response[dict])
|
||||
async def retrain_status(_=Depends(get_current_user)):
|
||||
"""查询重训任务状态。"""
|
||||
return Response.ok(_retrain_status)
|
||||
|
||||
Reference in New Issue
Block a user