Files
pickling-mes/backend/app/api/prediction.py

427 lines
18 KiB
Python
Raw Normal View History

import asyncio
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, Field
from app.schemas.common import Response
from app.services.auth_service import get_current_user
from app.services.prediction import (
AcidSpeedModel,
TensionModel,
QualityPredictionModel,
AcidConsumptionModel,
_load_cal,
_save_cal,
_get_phys,
append_sample,
get_sample_stats,
fit_acid_phys_params,
fit_quality_phys_params,
reload_onnx,
)
router = APIRouter()

# Zones whose tension K_cal can be calibrated; reset iterates these to build the
# "tension_<zone>" keys. Listed in the same order as the original definition
# (presumably the physical order along the line, inlet -> outlet).
TENSION_ZONES = [
    "inlet",
    "s1_roller",
    "acid_entry",
    "acid1",
    "acid2",
    "acid3",
    "rinse",
    "leveler",
    "s2_roller",
    "outlet",
]

# backend/ directory: three parents up from backend/app/api/prediction.py.
# train_models.py is resolved relative to it.
_BACKEND_DIR = Path(__file__).parent.parent.parent
# ─────────────────────────────────────────────────────────────────────────────
# Prediction request schemas
# ─────────────────────────────────────────────────────────────────────────────
class AcidSpeedRequest(BaseModel):
    """Request body for ``POST /acid-speed``; every field is forwarded to ``AcidSpeedModel``."""
    thickness: float = Field(..., gt=0)  # required, must be > 0 (units not stated here)
    width: float = Field(..., gt=0)  # required, must be > 0
    steel_grade: str  # presumably also selects the per-grade calibration entry
    acid_conc_list: List[float]  # NOTE(review): length/units not validated here; presumably one value per acid tank
    acid_temp_list: List[float]  # same caveat as acid_conc_list
    scale_weight: Optional[float] = 8.5  # default when omitted; an explicit null is passed through as None
    target_pi: Optional[float] = 95.0  # NOTE(review): meaning of "pi" is defined in AcidSpeedModel — confirm
class TensionRequest(BaseModel):
    """Request body for ``POST /tension``; every field is forwarded to ``TensionModel``."""
    thickness: float = Field(..., gt=0)  # required, must be > 0
    width: float = Field(..., gt=0)  # required, must be > 0
    yield_strength: float = Field(..., gt=0)  # required, must be > 0
    tension_coef: Optional[float] = 0.25  # default coefficient when omitted
    # "_default" is the same fallback key the calibration store uses for grade-independent values.
    steel_grade: Optional[str] = "_default"
class QualityRequest(BaseModel):
    """Request body for ``POST /quality``; every field is forwarded to ``QualityPredictionModel``."""
    thickness: float = Field(..., gt=0)  # required, must be > 0
    avg_speed: float = Field(..., gt=0)  # required, must be > 0
    acid_conc_avg: float = Field(..., gt=0)  # required, must be > 0
    acid_temp_avg: float = Field(..., gt=0)  # required, must be > 0
    scale_weight: Optional[float] = 8.5  # default when omitted
    fe_conc_avg: Optional[float] = 60.0  # default when omitted
    steel_grade: Optional[str] = "_default"  # "_default" = calibration-store fallback key
class ConsumptionRequest(BaseModel):
    """Request body for ``POST /consumption``; every field is forwarded to ``AcidConsumptionModel``."""
    thickness: float = Field(..., gt=0)  # required, must be > 0
    width: float = Field(..., gt=0)  # required, must be > 0
    coil_weight_kg: float = Field(..., gt=0)  # required, must be > 0
    has_regen_station: Optional[bool] = True  # default when omitted
    fe_conc_avg: Optional[float] = 60.0  # default when omitted
# ─────────────────────────────────────────────────────────────────────────────
# Calibration request schemas
# ─────────────────────────────────────────────────────────────────────────────
class AcidCalibRequest(BaseModel):
    """Request body for ``POST /calibration/acid-speed``.

    The process fields mirror ``AcidSpeedRequest`` minus ``target_pi`` (the
    calibration endpoint builds the model without it, so the model's own default
    applies). They are followed by the shop-floor measurement used to adjust K_cal.
    """
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str  # selects which per-grade K_cal is updated
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    scale_weight: Optional[float] = 8.5
    # Description: highest measured speed (m/min) at which quality was acceptable.
    actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
    # Description: whether quality was acceptable at that speed.
    actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")
    note: Optional[str] = None  # free text stored in the calibration history
class TensionCalibRequest(BaseModel):
    """Request body for ``POST /calibration/tension``: model inputs plus one measured zone tension."""
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = 0.25
    steel_grade: Optional[str] = "_default"
    # Description: measurement location, e.g. s1_roller. NOTE(review): the schema accepts
    # any string; only names in TENSION_ZONES can later be cleared by the reset endpoint.
    zone: str = Field(..., description="测量位置,如 s1_roller")
    # Description: measured tension in kN.
    measured_kn: float = Field(..., gt=0, description="实测张力 kN")
    note: Optional[str] = None  # free text stored in the calibration history
class QualityCalibRequest(BaseModel):
    """Request body for ``POST /calibration/quality``: model inputs plus the actual inspection grade."""
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = 8.5
    fe_conc_avg: Optional[float] = 60.0
    steel_grade: Optional[str] = "_default"
    # Description: actual inspection grade A1/A2/B1/B2/C. NOTE(review): not restricted to
    # those values at the schema level; QualityPredictionModel.calibrate decides what is valid.
    actual_grade: str = Field(..., description="实际质检等级 A1/A2/B1/B2/C")
    note: Optional[str] = None  # free text stored in the calibration history
# ─────────────────────────────────────────────────────────────────────────────
# Helper: append calibration history
# ─────────────────────────────────────────────────────────────────────────────
def _append_history(model_key: str, k_before, k_after,
                    input_data: dict, note: str = ""):
    """Prepend one calibration event to the persisted history, keeping the newest 100.

    Args:
        model_key: label of the changed model, optionally with grade/zone,
            e.g. ``"acid_speed[Q235]"``.
        k_before: coefficient before the change (callers pass ``None`` for resets).
        k_after: coefficient after the change.
        input_data: request details stored verbatim with the event.
        note: optional free text; any falsy value is stored as ``""``.

    The calibration store is re-read and re-written here. A ``cal`` dict that a
    caller loaded before this call should be treated as stale afterwards.
    """
    store = _load_cal()
    event = dict(
        ts=datetime.now().isoformat(timespec="seconds"),
        model=model_key,
        k_before=k_before,
        k_after=k_after,
        input=input_data,
        note=note or "",
    )
    # Newest first; cap the log at 100 entries.
    store["history"] = [event, *store.get("history", [])][:100]
    _save_cal(store)
# ─────────────────────────────────────────────────────────────────────────────
# Prediction endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.post("/acid-speed", response_model=Response[dict])
async def predict_acid_speed(body: AcidSpeedRequest, _=Depends(get_current_user)):
    """Run the acid-speed model for one coil and return its result dict.

    Raises:
        HTTPException: 422 when ``AcidSpeedModel`` rejects the inputs with ``ValueError``.
    """
    try:
        prediction = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
            target_pi=body.target_pi,
        ).calculate()
    except ValueError as err:
        raise HTTPException(status_code=422, detail=str(err))
    return Response.ok(prediction)
@router.post("/tension", response_model=Response[dict])
async def predict_tension(body: TensionRequest, _=Depends(get_current_user)):
    """Run the tension model and return its result (includes a per-zone ``zones`` map).

    Raises:
        HTTPException: 422 when ``TensionModel`` rejects the inputs with ``ValueError``.
            This matches the acid-speed endpoint instead of surfacing as a 500.
    """
    try:
        model = TensionModel(
            thickness=body.thickness,
            width=body.width,
            yield_strength=body.yield_strength,
            tension_coef=body.tension_coef,
            steel_grade=body.steel_grade,
        )
        result = model.calculate()
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return Response.ok(result)
@router.post("/quality", response_model=Response[dict])
async def predict_quality(body: QualityRequest, _=Depends(get_current_user)):
    """Run the quality-prediction model and return its result (includes ``overall_grade``).

    Raises:
        HTTPException: 422 when ``QualityPredictionModel`` rejects the inputs with
            ``ValueError``. This matches the acid-speed endpoint instead of surfacing as a 500.
    """
    try:
        model = QualityPredictionModel(
            thickness=body.thickness,
            avg_speed=body.avg_speed,
            acid_conc_avg=body.acid_conc_avg,
            acid_temp_avg=body.acid_temp_avg,
            scale_weight=body.scale_weight,
            fe_conc_avg=body.fe_conc_avg,
            steel_grade=body.steel_grade,
        )
        result = model.calculate()
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return Response.ok(result)
@router.post("/consumption", response_model=Response[dict])
async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_user)):
    """Run the acid-consumption model and return its result dict.

    Raises:
        HTTPException: 422 when ``AcidConsumptionModel`` rejects the inputs with
            ``ValueError``. This matches the acid-speed endpoint instead of surfacing as a 500.
    """
    try:
        model = AcidConsumptionModel(
            thickness=body.thickness,
            width=body.width,
            coil_weight_kg=body.coil_weight_kg,
            has_regen_station=body.has_regen_station,
            fe_conc_avg=body.fe_conc_avg,
        )
        result = model.calculate()
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    return Response.ok(result)
# ─────────────────────────────────────────────────────────────────────────────
# Calibration endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.get("/calibration", response_model=Response[dict])
async def get_calibration(_=Depends(get_current_user)):
    """Return the current calibration store: per-grade K_cal, physical params, and history.

    Sections missing from the store are reported as empty (``{}`` or ``[]``).
    """
    store = _load_cal()
    sections = (("kcal", {}), ("phys", {}), ("history", []))
    return Response.ok({name: store.get(name, fallback) for name, fallback in sections})
@router.get("/calibration/samples", response_model=Response[dict])
async def get_calibration_samples(_=Depends(get_current_user)):
    """Return production-sample counts per model/steel grade and the fitting threshold.

    ``fit_threshold`` is the minimum sample count the physical-parameter fitter
    requires. It is read from the service constant, so the UI cannot drift from the
    value ``/calibration/fit-phys`` actually enforces.
    """
    # Same constant fit_phys_params_api reports when a fit is refused; previously
    # this endpoint hard-coded 10.
    from app.services.prediction import _FIT_MIN_SAMPLES

    stats = get_sample_stats()
    return Response.ok({
        "stats": stats,
        "fit_threshold": _FIT_MIN_SAMPLES,
        "retrain_tip": "样本累积足够后POST /api/prediction/retrain 可触发 ONNX 重训",
    })
@router.get("/calibration/phys-params", response_model=Response[dict])
async def get_phys_params(steel_grade: Optional[str] = None, _=Depends(get_current_user)):
    """Return the physical parameters (EA_R / K0 / N_CONC), optionally for one steel grade.

    Without ``steel_grade`` the raw ``phys`` section of the store is returned.
    With it, the parameters resolved by ``_get_phys`` for both ``acid_speed`` and
    ``quality`` are returned under ``phys_params``.
    """
    all_phys = _load_cal().get("phys", {})
    if not steel_grade:
        return Response.ok(all_phys)
    resolved = {name: _get_phys(name, steel_grade) for name in ("acid_speed", "quality")}
    return Response.ok({"steel_grade": steel_grade, "phys_params": resolved})
@router.post("/calibration/acid-speed", response_model=Response[dict])
async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_user)):
    """Record a measured speed and update the acid-speed K_cal for the steel grade.

    The original docstring states that physical-parameter fitting triggers
    automatically once >= 10 samples exist. That logic is not in this file;
    presumably it lives in ``AcidSpeedModel.calibrate``.

    Returns:
        Old/new K_cal, the model's predicted max speed before calibration, and
        ``adjustment``: the relative change in percent, or ``None`` if ``k_before`` is 0.

    Raises:
        HTTPException: 422 when the model raises ``ValueError`` while being built,
            predicting, or calibrating (the predict endpoint also guards ``calculate``).
    """
    try:
        model = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
        )
        k_before = model.K_cal
        # Predict with the pre-calibration coefficient so the history records the gap.
        predicted_speed = model.calculate()["max_speed"]
        k_after = model.calibrate(
            actual_max_speed=body.actual_max_speed,
            actual_quality_ok=body.actual_quality_ok,
        )
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    _append_history(
        f"acid_speed[{body.steel_grade}]", k_before, k_after,
        {"actual_speed": body.actual_max_speed,
         "quality_ok": body.actual_quality_ok,
         "predicted_speed": predicted_speed},
        body.note or "",
    )
    # Guard the ratio: a zero coefficient would otherwise raise ZeroDivisionError (500).
    adjustment = round((k_after / k_before - 1) * 100, 2) if k_before else None
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_speed": predicted_speed,
        "adjustment": adjustment,
    })
@router.post("/calibration/tension", response_model=Response[dict])
async def calibrate_tension(body: TensionCalibRequest, _=Depends(get_current_user)):
    """Record a measured zone tension and update K_cal for that (steel grade, zone).

    Returns:
        Old/new zone K_cal, the predicted and measured tension (kN), the updated
        ``zone_kcal`` map, and ``adjustment``: the relative change in percent, or
        ``None`` if ``k_before`` is 0.

    Raises:
        HTTPException: 422 if ``body.zone`` is not in ``TENSION_ZONES``, or if the
            model raises ``ValueError``.
    """
    # Reject unknown zones up front. Otherwise the prediction silently falls back to
    # 0 kN, and a coefficient would be calibrated for a zone that the reset endpoint
    # (which only iterates TENSION_ZONES) can never clear.
    if body.zone not in TENSION_ZONES:
        raise HTTPException(
            status_code=422,
            detail=f"zone 仅支持 {' / '.join(TENSION_ZONES)}",
        )
    try:
        model = TensionModel(
            thickness=body.thickness,
            width=body.width,
            yield_strength=body.yield_strength,
            tension_coef=body.tension_coef,
            steel_grade=body.steel_grade,
        )
        calc = model.calculate()
        predicted_kn = calc["zones"].get(body.zone, {}).get("tension_kN", 0)
        k_before = model.zone_kcal.get(body.zone, 1.0)
        new_zone_kcal = model.calibrate(zone=body.zone, measured_kn=body.measured_kn)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
    k_after = new_zone_kcal.get(body.zone, 1.0)

    _append_history(
        f"tension[{body.steel_grade}][{body.zone}]", k_before, k_after,
        {"zone": body.zone, "measured_kn": body.measured_kn, "predicted_kn": predicted_kn},
        body.note or "",
    )
    # Guard the ratio: a zero coefficient would otherwise raise ZeroDivisionError (500).
    adjustment = round((k_after / k_before - 1) * 100, 2) if k_before else None
    return Response.ok({
        "steel_grade": body.steel_grade,
        "zone": body.zone,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_kn": predicted_kn,
        "measured_kn": body.measured_kn,
        "adjustment": adjustment,
        "zone_kcal": new_zone_kcal,
    })
@router.post("/calibration/quality", response_model=Response[dict])
async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_user)):
    """Record an actual inspection grade and update the quality K_cal for the steel grade.

    The original docstring states that physical-parameter fitting triggers
    automatically once >= 10 samples exist. That logic is not in this file;
    presumably it lives in ``QualityPredictionModel.calibrate``.

    Returns:
        Old/new K_cal, predicted vs. actual grade, and ``adjustment``: the relative
        change in percent, or ``None`` if ``k_before`` is 0.

    Raises:
        HTTPException: 422 when the model raises ``ValueError`` (e.g. for inputs or a
            grade it does not accept). This is consistent with the acid-speed endpoints.
    """
    try:
        model = QualityPredictionModel(
            thickness=body.thickness,
            avg_speed=body.avg_speed,
            acid_conc_avg=body.acid_conc_avg,
            acid_temp_avg=body.acid_temp_avg,
            scale_weight=body.scale_weight,
            fe_conc_avg=body.fe_conc_avg,
            steel_grade=body.steel_grade,
        )
        k_before = model.K_cal
        # Predict with the pre-calibration coefficient so the history records the gap.
        predicted_grade = model.calculate()["overall_grade"]
        k_after = model.calibrate(actual_grade=body.actual_grade)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    _append_history(
        f"quality[{body.steel_grade}]", k_before, k_after,
        {"actual_grade": body.actual_grade, "predicted_grade": predicted_grade},
        body.note or "",
    )
    # Guard the ratio: a zero coefficient would otherwise raise ZeroDivisionError (500).
    adjustment = round((k_after / k_before - 1) * 100, 2) if k_before else None
    return Response.ok({
        "steel_grade": body.steel_grade,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_grade": predicted_grade,
        "actual_grade": body.actual_grade,
        "adjustment": adjustment,
    })
@router.post("/calibration/fit-phys/{model_key}", response_model=Response[dict])
async def fit_phys_params_api(model_key: str, steel_grade: str, _=Depends(get_current_user)):
    """Manually fit physical parameters for one model and steel grade.

    ``model_key`` must be ``acid_speed`` or ``quality``; anything else yields 404.
    If the fitter returns ``None`` (too few samples), the response has
    ``fitted: False`` and states the required minimum sample count.
    """
    fitters = {
        "acid_speed": fit_acid_phys_params,
        "quality": fit_quality_phys_params,
    }
    fitter = fitters.get(model_key)
    if fitter is None:
        raise HTTPException(status_code=404, detail="model_key 仅支持 acid_speed / quality")

    fitted = fitter(steel_grade)
    if fitted is not None:
        return Response.ok({"fitted": True, "steel_grade": steel_grade, "phys_params": fitted})

    from app.services.prediction import _FIT_MIN_SAMPLES
    return Response.ok({
        "fitted": False,
        "reason": f"样本不足,需 ≥{_FIT_MIN_SAMPLES} 条,请继续录入校准数据",
    })
@router.post("/calibration/reset/{model_key}", response_model=Response[dict])
async def reset_calibration(model_key: str, steel_grade: Optional[str] = None,
                            _=Depends(get_current_user)):
    """Reset calibration coefficients for one model.

    Without ``steel_grade``, every grade is wiped and ``_default`` is set to 1.0.
    For ``acid_speed`` / ``quality``, the physical params are also reset to
    ``_DEFAULT_PHYS``. With ``steel_grade``, only that grade's entries are removed,
    so lookups fall back to whatever the service uses for missing grades.
    ``tension`` covers every zone in ``TENSION_ZONES``.

    Raises:
        HTTPException: 404 for an unknown ``model_key``.
    """
    cal = _load_cal()
    kcal = cal.setdefault("kcal", {})
    if model_key == "tension":
        for zone in TENSION_ZONES:
            key = f"tension_{zone}"
            if steel_grade:
                # .get(): don't create empty per-zone dicts just to pop from them.
                kcal.get(key, {}).pop(steel_grade, None)
            else:
                kcal[key] = {"_default": 1.0}
    elif model_key in ("acid_speed", "quality"):
        phys = cal.setdefault("phys", {})
        if steel_grade:
            kcal.get(model_key, {}).pop(steel_grade, None)
            phys.get(model_key, {}).pop(steel_grade, None)
        else:
            from app.services.prediction import _DEFAULT_PHYS
            kcal[model_key] = {"_default": 1.0}
            phys[model_key] = {"_default": _DEFAULT_PHYS.copy()}
    else:
        raise HTTPException(status_code=404, detail="未知模型")

    # Persist first, then log. _append_history reloads and re-saves the store itself.
    # Saving this `cal` afterwards (as before) could overwrite the history entry it
    # just wrote whenever _load_cal returns a fresh copy; this order is correct either way.
    _save_cal(cal)
    # Record "all" for whole-model resets in every branch, matching the response below.
    _append_history(model_key, None, 1.0,
                    {"action": "reset", "steel_grade": steel_grade or "all"})
    return Response.ok({"model": model_key, "steel_grade": steel_grade or "all", "reset": True})
# ─────────────────────────────────────────────────────────────────────────────
# 数据飞轮ONNX 重训端点
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): not referenced by any code visible in this file; presumably reserved
# for serialising retrains. Confirm before relying on it or removing it.
_retrain_lock = asyncio.Lock()

# Shared retrain state. _run_retrain mutates it in place; /retrain checks it and
# /retrain/status returns it.
#   running     - True while a retrain is in progress
#   last_ts     - ISO timestamp (seconds) of the last retrain start
#   last_result - None until a retrain finishes, then a dict with at least "success"
_retrain_status: dict = dict(running=False, last_ts=None, last_result=None)
def _run_retrain():
    """Run ``train_models.py --use-real-data`` in a subprocess, then hot-reload ONNX.

    Scheduled through ``BackgroundTasks``. As a plain ``def`` it presumably runs off
    the event loop (Starlette runs sync background tasks in a threadpool — confirm
    for the version in use).

    Progress is written into the module-level ``_retrain_status`` dict in place:
    ``running``, ``last_ts`` (start time), and ``last_result``. ``last_result``
    always carries ``success``, ``stdout`` (tail, <= 2000 chars) and ``stderr``
    (tail, <= 1000 chars).
    """
    # Only keys of the shared dict are mutated, so no `global` statement is needed.
    _retrain_status["running"] = True
    _retrain_status["last_ts"] = datetime.now().isoformat(timespec="seconds")
    try:
        proc = subprocess.run(
            [sys.executable, str(_BACKEND_DIR / "train_models.py"), "--use-real-data"],
            capture_output=True, text=True, timeout=600,
        )
        outcome = {
            "success": proc.returncode == 0,
            "stdout": (proc.stdout or "")[-2000:],
            "stderr": (proc.stderr or "")[-1000:],
        }
        if outcome["success"]:
            try:
                reload_onnx()
            except Exception as exc:
                # Training succeeded but the new models could not be loaded. Report
                # failure but keep the training logs instead of discarding them.
                outcome["success"] = False
                outcome["stderr"] = f"{outcome['stderr']}\nreload_onnx: {exc}".strip()
        _retrain_status["last_result"] = outcome
    except subprocess.TimeoutExpired:
        _retrain_status["last_result"] = {"success": False, "stdout": "", "stderr": "训练超时(>600s"}
    except Exception as e:
        # Top-level boundary of a background job: record the error for /retrain/status.
        _retrain_status["last_result"] = {"success": False, "stdout": "", "stderr": str(e)}
    finally:
        _retrain_status["running"] = False
@router.post("/retrain", response_model=Response[dict])
async def trigger_retrain(background_tasks: BackgroundTasks, _=Depends(get_current_user)):
    """Start an ONNX retrain on the production samples (production_samples.jsonl).

    Training runs in a background subprocess, and the models are hot-reloaded when
    it succeeds. The response carries the current sample statistics.

    Raises:
        HTTPException: 409 if a retrain is already in progress.
    """
    if _retrain_status["running"]:
        raise HTTPException(status_code=409, detail="重训任务正在进行中,请稍后再试")
    stats = get_sample_stats()
    # Claim the slot before scheduling. Background tasks only start after the
    # response is sent, so a second request arriving in that window would otherwise
    # also pass the check above and launch a concurrent training subprocess. There
    # is no `await` between the check and this assignment, so no other request can
    # interleave on the event loop.
    _retrain_status["running"] = True
    background_tasks.add_task(_run_retrain)
    return Response.ok({
        "message": "重训任务已启动,完成后 ONNX 将自动热重载",
        "sample_stats": stats,
    })
@router.get("/retrain/status", response_model=Response[dict])
async def retrain_status(_=Depends(get_current_user)):
    """Report the state of the most recent ONNX retrain.

    The payload is the shared status dict: ``running``, ``last_ts`` (ISO start time
    of the last run), and ``last_result`` (``None`` until a run has finished).
    """
    current = _retrain_status
    return Response.ok(current)