Files
pickling-mes/backend/app/api/prediction.py
wangyu f5c59db92b feat(prediction): 三层校准体系 + 按钢种分组 + 数据飞轮
1. 按钢种分组 K_cal:cal_coeffs.json 升级为嵌套结构,
   {kcal: {model: {_default, Q235, ...}}, phys: {...}},
   旧平铺格式首次加载时自动迁移。

2. 物理参数自适应:EA_R/K0/N_CONC 按钢种网格拟合
   (7×5×3=105 组合),每次校准追加样本到
   production_samples.jsonl,≥10 条后自动触发拟合。

3. 数据飞轮:新增 POST /retrain 端点,后台子进程跑
   train_models.py --use-real-data 混入实绩重训
   (10× 权重),完成后 ONNX 热重载,无需重启服务。

新增端点:
  GET  /calibration/samples         样本数统计
  GET  /calibration/phys-params     物理参数查询
  POST /calibration/fit-phys/{key}  手动触发物理参数拟合
  POST /retrain                     启动重训
  GET  /retrain/status              重训进度

模型类签名变更:
  TensionModel / QualityPredictionModel 新增 steel_grade 参数
  AcidConsumptionModel 新增 fe_conc_avg 参数

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-01 16:13:39 +08:00

427 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, Field
from app.schemas.common import Response
from app.services.auth_service import get_current_user
from app.services.prediction import (
AcidSpeedModel,
TensionModel,
QualityPredictionModel,
AcidConsumptionModel,
_load_cal,
_save_cal,
_get_phys,
append_sample,
get_sample_stats,
fit_acid_phys_params,
fit_quality_phys_params,
reload_onnx,
)
# Router on which every endpoint in this module registers itself.
router = APIRouter()

# Line sections that carry their own tension calibration coefficient.
# reset_calibration iterates this list to build the "tension_<zone>" keys.
TENSION_ZONES = [
    "inlet",
    "s1_roller",
    "acid_entry",
    "acid1",
    "acid2",
    "acid3",
    "rinse",
    "leveler",
    "s2_roller",
    "outlet",
]

# Three directory levels above this file; _run_retrain expects
# train_models.py to live directly inside it.
_BACKEND_DIR = Path(__file__).parent.parent.parent
# ─────────────────────────────────────────────────────────────────────────────
# Prediction request schemas
# ─────────────────────────────────────────────────────────────────────────────
class AcidSpeedRequest(BaseModel):
    """Request body for POST /acid-speed.

    Every field is forwarded by name to ``AcidSpeedModel``.
    """

    # Strip geometry; both values must be strictly positive.
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)

    steel_grade: str

    # Acid concentration / temperature readings, passed through as lists.
    acid_conc_list: List[float]
    acid_temp_list: List[float]

    # Optional tuning inputs with their defaults.
    scale_weight: Optional[float] = 8.5
    target_pi: Optional[float] = 95.0
class TensionRequest(BaseModel):
    """Request body for POST /tension.

    Every field is forwarded by name to ``TensionModel``.
    """

    # Strip geometry and material strength; all must be strictly positive.
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)

    tension_coef: Optional[float] = 0.25

    # "_default" selects the calibration entry shared by unlisted grades.
    steel_grade: Optional[str] = "_default"
class QualityRequest(BaseModel):
    """Request body for POST /quality.

    Every field is forwarded by name to ``QualityPredictionModel``.
    """

    # Process conditions; all four must be strictly positive.
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)

    # Optional inputs with their defaults.
    scale_weight: Optional[float] = 8.5
    fe_conc_avg: Optional[float] = 60.0

    # "_default" selects the calibration entry shared by unlisted grades.
    steel_grade: Optional[str] = "_default"
class ConsumptionRequest(BaseModel):
    """Request body for POST /consumption.

    Every field is forwarded by name to ``AcidConsumptionModel``.
    """

    # Coil geometry and weight; all must be strictly positive.
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    coil_weight_kg: float = Field(..., gt=0)

    # Optional inputs with their defaults.
    has_regen_station: Optional[bool] = True
    fe_conc_avg: Optional[float] = 60.0
# ─────────────────────────────────────────────────────────────────────────────
# Calibration request schemas
# ─────────────────────────────────────────────────────────────────────────────
class AcidCalibRequest(BaseModel):
    """Request body for POST /calibration/acid-speed.

    Carries the same model inputs as a speed prediction plus the observed
    production result used to adjust the calibration coefficient.
    """

    # Model inputs (forwarded to AcidSpeedModel).
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    scale_weight: Optional[float] = 8.5

    # Observation: highest measured speed (m/min) at which quality passed.
    actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
    # Observation: whether quality was acceptable at that speed.
    actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")

    # Free-text remark stored with the history entry.
    note: Optional[str] = None
class TensionCalibRequest(BaseModel):
    """Request body for POST /calibration/tension.

    Carries the tension model inputs plus one measured tension value for a
    single zone of the line.
    """

    # Model inputs (forwarded to TensionModel).
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = 0.25
    steel_grade: Optional[str] = "_default"

    # Where the measurement was taken, e.g. "s1_roller".
    zone: str = Field(..., description="测量位置,如 s1_roller")
    # Measured tension in kN.
    measured_kn: float = Field(..., gt=0, description="实测张力 kN")

    # Free-text remark stored with the history entry.
    note: Optional[str] = None
class QualityCalibRequest(BaseModel):
    """Request body for POST /calibration/quality.

    Carries the quality model inputs plus the grade actually assigned by
    inspection.
    """

    # Model inputs (forwarded to QualityPredictionModel).
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = 8.5
    fe_conc_avg: Optional[float] = 60.0
    steel_grade: Optional[str] = "_default"

    # Actual inspection grade; the description lists A1/A2/B1/B2/C.
    actual_grade: str = Field(..., description="实际质检等级 A1/A2/B1/B2/C")

    # Free-text remark stored with the history entry.
    note: Optional[str] = None
# ─────────────────────────────────────────────────────────────────────────────
# Helper: append calibration history
# ─────────────────────────────────────────────────────────────────────────────
def _append_history(model_key: str, k_before, k_after,
                    input_data: dict, note: str = ""):
    """Prepend one calibration event to the persisted history.

    Loads the calibration document, puts a new entry at the front of its
    ``"history"`` list, trims the list to 100 entries and saves it back.

    Args:
        model_key: Label stored under ``"model"`` in the entry.
        k_before: Coefficient before the change (stored as given).
        k_after: Coefficient after the change (stored as given).
        input_data: Payload stored under ``"input"``.
        note: Optional remark; any falsy value is stored as ``""``.
    """
    cal = _load_cal()
    entry = {
        "ts": datetime.now().isoformat(timespec="seconds"),
        "model": model_key,
        "k_before": k_before,
        "k_after": k_after,
        "input": input_data,
        "note": note or "",
    }
    # Newest first; anything beyond the 100 most recent entries is dropped.
    cal["history"] = [entry, *cal.get("history", [])][:100]
    _save_cal(cal)
# ─────────────────────────────────────────────────────────────────────────────
# Prediction endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.post("/acid-speed", response_model=Response[dict])
async def predict_acid_speed(body: AcidSpeedRequest, _=Depends(get_current_user)):
    """Run the acid-speed model on the request and return its result.

    Raises:
        HTTPException: 422 when building or evaluating the model raises
            ``ValueError``; the exception text becomes the detail.
    """
    model_args = {
        "thickness": body.thickness,
        "width": body.width,
        "steel_grade": body.steel_grade,
        "acid_conc_list": body.acid_conc_list,
        "acid_temp_list": body.acid_temp_list,
        "scale_weight": body.scale_weight,
        "target_pi": body.target_pi,
    }
    try:
        prediction = AcidSpeedModel(**model_args).calculate()
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))
    return Response.ok(prediction)
@router.post("/tension", response_model=Response[dict])
async def predict_tension(body: TensionRequest, _=Depends(get_current_user)):
    """Run the tension model on the request and return its result."""
    model_args = {
        "thickness": body.thickness,
        "width": body.width,
        "yield_strength": body.yield_strength,
        "tension_coef": body.tension_coef,
        "steel_grade": body.steel_grade,
    }
    prediction = TensionModel(**model_args).calculate()
    return Response.ok(prediction)
@router.post("/quality", response_model=Response[dict])
async def predict_quality(body: QualityRequest, _=Depends(get_current_user)):
    """Run the quality prediction model on the request and return its result."""
    model_args = {
        "thickness": body.thickness,
        "avg_speed": body.avg_speed,
        "acid_conc_avg": body.acid_conc_avg,
        "acid_temp_avg": body.acid_temp_avg,
        "scale_weight": body.scale_weight,
        "fe_conc_avg": body.fe_conc_avg,
        "steel_grade": body.steel_grade,
    }
    prediction = QualityPredictionModel(**model_args).calculate()
    return Response.ok(prediction)
@router.post("/consumption", response_model=Response[dict])
async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_user)):
    """Run the acid consumption model on the request and return its result."""
    model_args = {
        "thickness": body.thickness,
        "width": body.width,
        "coil_weight_kg": body.coil_weight_kg,
        "has_regen_station": body.has_regen_station,
        "fe_conc_avg": body.fe_conc_avg,
    }
    prediction = AcidConsumptionModel(**model_args).calculate()
    return Response.ok(prediction)
# ─────────────────────────────────────────────────────────────────────────────
# Calibration endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.get("/calibration", response_model=Response[dict])
async def get_calibration(_=Depends(get_current_user)):
    """Return the current calibration coefficients (per steel grade) and history.

    The response carries three sections of the calibration document:
    ``kcal``, ``phys`` and ``history``. A missing section is returned as an
    empty dict (``kcal``, ``phys``) or an empty list (``history``).
    """
    cal = _load_cal()
    section_defaults = (("kcal", {}), ("phys", {}), ("history", []))
    payload = {name: cal.get(name, fallback) for name, fallback in section_defaults}
    return Response.ok(payload)
@router.get("/calibration/samples", response_model=Response[dict])
async def get_calibration_samples(_=Depends(get_current_user)):
    """Return production sample counts and the fitting threshold.

    ``stats`` is whatever ``get_sample_stats()`` reports. ``fit_threshold``
    is the minimum sample count required before physical parameters are
    fitted.
    """
    # Use the service's own constant rather than a literal 10, so this value
    # cannot drift from the threshold that fit_phys_params_api reports.
    from app.services.prediction import _FIT_MIN_SAMPLES

    stats = get_sample_stats()
    return Response.ok({
        "stats": stats,
        "fit_threshold": _FIT_MIN_SAMPLES,
        "retrain_tip": "样本累积足够后POST /api/prediction/retrain 可触发 ONNX 重训",
    })
@router.get("/calibration/phys-params", response_model=Response[dict])
async def get_phys_params(steel_grade: Optional[str] = None, _=Depends(get_current_user)):
    """Return the stored physical parameters, optionally for one steel grade.

    Without ``steel_grade`` the whole ``"phys"`` section of the calibration
    document is returned. With it, ``_get_phys`` is queried for the
    ``acid_speed`` and ``quality`` models and the results are returned under
    ``"phys_params"``.
    """
    # Loaded unconditionally, even though only the unfiltered branch reads it.
    all_phys = _load_cal().get("phys", {})
    if not steel_grade:
        return Response.ok(all_phys)

    per_model = {name: _get_phys(name, steel_grade) for name in ("acid_speed", "quality")}
    return Response.ok({"steel_grade": steel_grade, "phys_params": per_model})
@router.post("/calibration/acid-speed", response_model=Response[dict])
async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_user)):
    """Record a measured speed and update K_cal for the given steel grade.

    The model is built from the request, its current prediction is taken,
    then ``model.calibrate`` is called with the observed result and the
    change is written to the calibration history.

    NOTE(review): the original docstring says physical-parameter fitting is
    triggered automatically once >= 10 samples exist. That logic is not
    visible in this endpoint; presumably it lives inside ``model.calibrate``.

    Raises:
        HTTPException: 422 when building the model or computing its
            prediction raises ``ValueError``.
    """
    try:
        model = AcidSpeedModel(
            thickness=body.thickness, width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
        )
        k_before = model.K_cal
        # calculate() sits inside the try so that invalid input is reported
        # as 422, matching predict_acid_speed, instead of surfacing as a 500.
        predicted_speed = model.calculate()["max_speed"]
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    k_after = model.calibrate(
        actual_max_speed=body.actual_max_speed,
        actual_quality_ok=body.actual_quality_ok,
    )
    _append_history(
        f"acid_speed[{body.steel_grade}]", k_before, k_after,
        {"actual_speed": body.actual_max_speed,
         "quality_ok": body.actual_quality_ok,
         "predicted_speed": predicted_speed},
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_speed": predicted_speed,
        # Relative change of the coefficient, in percent.
        "adjustment": round((k_after / k_before - 1) * 100, 2),
    })
@router.post("/calibration/tension", response_model=Response[dict])
async def calibrate_tension(body: TensionCalibRequest, _=Depends(get_current_user)):
    """Record a measured tension and update K_cal for one steel grade and zone.

    Raises:
        HTTPException: 422 when ``body.zone`` is not one of the zones the
            tension model reports.
    """
    model = TensionModel(
        thickness=body.thickness, width=body.width,
        yield_strength=body.yield_strength, tension_coef=body.tension_coef,
        steel_grade=body.steel_grade,
    )
    zones = model.calculate()["zones"]
    # Reject unknown zones up front. Without this check the prediction
    # silently became 0 and a coefficient was calibrated for a zone that
    # does not exist.
    if body.zone not in zones:
        raise HTTPException(
            status_code=422,
            detail=f"未知张力区段: {body.zone}",
        )
    predicted_kn = zones[body.zone].get("tension_kN", 0)

    k_before = model.zone_kcal.get(body.zone, 1.0)
    new_zone_kcal = model.calibrate(zone=body.zone, measured_kn=body.measured_kn)
    k_after = new_zone_kcal.get(body.zone, 1.0)

    _append_history(
        f"tension[{body.steel_grade}][{body.zone}]", k_before, k_after,
        {"zone": body.zone, "measured_kn": body.measured_kn, "predicted_kn": predicted_kn},
        body.note or "",
    )
    return Response.ok({
        "steel_grade": body.steel_grade,
        "zone": body.zone,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_kn": predicted_kn,
        "measured_kn": body.measured_kn,
        # Relative change of the coefficient, in percent.
        "adjustment": round((k_after / k_before - 1) * 100, 2),
        "zone_kcal": new_zone_kcal,
    })
@router.post("/calibration/quality", response_model=Response[dict])
async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_user)):
    """Record the actual inspection grade and update K_cal for the steel grade.

    NOTE(review): the original docstring says physical-parameter fitting is
    triggered automatically once >= 10 samples exist. That logic is not
    visible in this endpoint; presumably it lives inside ``model.calibrate``.
    """
    model_args = {
        "thickness": body.thickness,
        "avg_speed": body.avg_speed,
        "acid_conc_avg": body.acid_conc_avg,
        "acid_temp_avg": body.acid_temp_avg,
        "scale_weight": body.scale_weight,
        "fe_conc_avg": body.fe_conc_avg,
        "steel_grade": body.steel_grade,
    }
    model = QualityPredictionModel(**model_args)

    # Snapshot the coefficient and prediction before calibrating.
    k_before = model.K_cal
    predicted_grade = model.calculate()["overall_grade"]
    k_after = model.calibrate(actual_grade=body.actual_grade)

    history_input = {
        "actual_grade": body.actual_grade,
        "predicted_grade": predicted_grade,
    }
    _append_history(
        f"quality[{body.steel_grade}]",
        k_before,
        k_after,
        history_input,
        body.note or "",
    )

    # Relative change of the coefficient, in percent.
    adjustment = round((k_after / k_before - 1) * 100, 2)
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_grade": predicted_grade,
        "actual_grade": body.actual_grade,
        "adjustment": adjustment,
    })
@router.post("/calibration/fit-phys/{model_key}", response_model=Response[dict])
async def fit_phys_params_api(model_key: str, steel_grade: str, _=Depends(get_current_user)):
    """Manually trigger physical-parameter fitting for one model and steel grade.

    ``model_key`` must be ``acid_speed`` or ``quality``. When the fitter
    returns ``None`` the response reports ``fitted: False`` together with the
    minimum sample count required.

    Raises:
        HTTPException: 404 for any other ``model_key``.
    """
    fitters = {
        "acid_speed": fit_acid_phys_params,
        "quality": fit_quality_phys_params,
    }
    fitter = fitters.get(model_key)
    if fitter is None:
        raise HTTPException(status_code=404, detail="model_key 仅支持 acid_speed / quality")

    fitted_params = fitter(steel_grade)
    if fitted_params is not None:
        return Response.ok({
            "fitted": True,
            "steel_grade": steel_grade,
            "phys_params": fitted_params,
        })

    # Imported here because only the "not enough samples" message needs it.
    from app.services.prediction import _FIT_MIN_SAMPLES
    return Response.ok({
        "fitted": False,
        "reason": f"样本不足,需 ≥{_FIT_MIN_SAMPLES} 条,请继续录入校准数据",
    })
@router.post("/calibration/reset/{model_key}", response_model=Response[dict])
async def reset_calibration(model_key: str, steel_grade: Optional[str] = None,
                            _=Depends(get_current_user)):
    """Reset calibration coefficients to 1.0.

    With ``steel_grade`` empty, every grade of the model is reset: its
    ``kcal`` entry becomes ``{"_default": 1.0}`` and, for ``acid_speed`` /
    ``quality``, its ``phys`` entry becomes a copy of the default physical
    parameters. Otherwise only the entry for the given grade is removed.

    Args:
        model_key: ``tension``, ``acid_speed`` or ``quality``.
        steel_grade: Grade to reset, or ``None`` for all grades.

    Raises:
        HTTPException: 404 for an unknown ``model_key``.
    """
    if model_key not in ("tension", "acid_speed", "quality"):
        raise HTTPException(status_code=404, detail="未知模型")

    cal = _load_cal()
    kcal = cal.setdefault("kcal", {})

    if model_key == "tension":
        # Tension keeps one coefficient table per zone.
        for z in TENSION_ZONES:
            key = f"tension_{z}"
            if steel_grade:
                kcal.setdefault(key, {}).pop(steel_grade, None)
            else:
                kcal[key] = {"_default": 1.0}
    else:
        phys = cal.setdefault("phys", {})
        if steel_grade:
            kcal.setdefault(model_key, {}).pop(steel_grade, None)
            phys.setdefault(model_key, {}).pop(steel_grade, None)
        else:
            from app.services.prediction import _DEFAULT_PHYS
            kcal[model_key] = {"_default": 1.0}
            phys[model_key] = {"_default": _DEFAULT_PHYS.copy()}

    # Persist the reset BEFORE logging it. _append_history loads and saves
    # its own copy of the document, so saving this snapshot afterwards would
    # overwrite the freshly written history entry.
    _save_cal(cal)
    _append_history(model_key, None, 1.0,
                    {"action": "reset", "steel_grade": steel_grade or "all"})
    return Response.ok({"model": model_key, "steel_grade": steel_grade or "all", "reset": True})
# ─────────────────────────────────────────────────────────────────────────────
# Data flywheel: ONNX retraining endpoints
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): presumably meant to serialise retrain runs, but nothing in
# the visible part of this file acquires it.
_retrain_lock = asyncio.Lock()

# Progress record shared between _run_retrain (writer) and the
# /retrain and /retrain/status endpoints (readers).
_retrain_status: dict = {
    "running": False,
    "last_ts": None,
    "last_result": None,
}
def _run_retrain(timeout_s: int = 600):
    """Run ``train_models.py --use-real-data`` in a subprocess, then hot-reload ONNX.

    Progress and outcome are written to the module-level ``_retrain_status``
    dict. ``last_result`` always carries ``success``, ``stdout`` and
    ``stderr``. ``reload_onnx()`` is called only when the subprocess exits
    with return code 0.

    Args:
        timeout_s: Seconds the training subprocess may run before it is
            aborted and reported as timed out.
    """
    status = _retrain_status  # mutated in place, so no ``global`` is needed
    status["running"] = True
    status["last_ts"] = datetime.now().isoformat(timespec="seconds")
    try:
        proc = subprocess.run(
            [sys.executable, str(_BACKEND_DIR / "train_models.py"), "--use-real-data"],
            capture_output=True, text=True, timeout=timeout_s,
        )
        succeeded = proc.returncode == 0
        status["last_result"] = {
            "success": succeeded,
            # Keep only the tail so the status payload stays small.
            "stdout": proc.stdout[-2000:] if proc.stdout else "",
            "stderr": proc.stderr[-1000:] if proc.stderr else "",
        }
        if succeeded:
            reload_onnx()
    except subprocess.TimeoutExpired:
        status["last_result"] = {
            "success": False,
            "stdout": "",
            "stderr": f"训练超时(>{timeout_s}s)",
        }
    except Exception as e:
        # Top-level boundary of a background task: record the failure rather
        # than let it disappear. This also covers errors from reload_onnx().
        status["last_result"] = {"success": False, "stdout": "", "stderr": str(e)}
    finally:
        # Always clear the flag so a failed run never blocks later retrains.
        status["running"] = False
@router.post("/retrain", response_model=Response[dict])
async def trigger_retrain(background_tasks: BackgroundTasks, _=Depends(get_current_user)):
    """Start an ONNX retrain using the recorded production samples.

    Training runs in a background subprocess (see ``_run_retrain``); the
    model is hot-reloaded when it succeeds.

    Raises:
        HTTPException: 409 when a retrain is already running.
    """
    if _retrain_status["running"]:
        raise HTTPException(status_code=409, detail="重训任务正在进行中,请稍后再试")
    stats = get_sample_stats()
    # Claim the slot here, not only inside the background task. The task
    # starts after the response is sent, so a second request arriving in
    # between would otherwise pass the check above and launch a concurrent
    # retrain. There is no ``await`` between the check and this assignment,
    # so the pair cannot be interleaved on the event loop.
    # _run_retrain clears the flag in its ``finally`` block.
    _retrain_status["running"] = True
    background_tasks.add_task(_run_retrain)
    return Response.ok({
        "message": "重训任务已启动,完成后 ONNX 将自动热重载",
        "sample_stats": stats,
    })
@router.get("/retrain/status", response_model=Response[dict])
async def retrain_status(_=Depends(get_current_user)):
    """Report the state of the retrain job.

    Returns the module-level ``_retrain_status`` dict: ``running``,
    ``last_ts`` and ``last_result``.
    """
    current = _retrain_status
    return Response.ok(current)