from contextlib import contextmanager
from datetime import datetime
from typing import Iterator, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field

from app.schemas.common import Response
from app.services.auth_service import get_current_user
from app.services.prediction import (
    AcidSpeedModel,
    TensionModel,
    QualityPredictionModel,
    AcidConsumptionModel,
    _load_cal,
    _save_cal,
)

router = APIRouter()

# Measurement zones of the tension model. Each zone has its own calibration
# coefficient stored under the key ``tension_zone_<zone>``.
TENSION_ZONES = ["inlet", "s1_roller", "acid_entry", "acid1", "acid2", "acid3", "rinse", "leveler", "s2_roller", "outlet"]

# Quality grades accepted by the quality calibration endpoint (the set listed
# in QualityCalibRequest.actual_grade's description).
QUALITY_GRADES = ("A1", "A2", "B1", "B2", "C")

# Maximum number of calibration history entries kept (newest first).
MAX_HISTORY = 100


# ─────────────────────────────────────────────────────────────────────────────
# Prediction request schemas
# ─────────────────────────────────────────────────────────────────────────────

class AcidSpeedRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    scale_weight: Optional[float] = 8.5
    target_pi: Optional[float] = 95.0


class TensionRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = 0.25


class QualityRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = 8.5


class ConsumptionRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    coil_weight_kg: float = Field(..., gt=0)
    has_regen_station: Optional[bool] = True


# ─────────────────────────────────────────────────────────────────────────────
# Calibration request schemas
# ─────────────────────────────────────────────────────────────────────────────

class AcidCalibRequest(BaseModel):
    # Context parameters needed to rebuild the model instance
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    steel_grade: str
    acid_conc_list: List[float]
    acid_temp_list: List[float]
    scale_weight: Optional[float] = 8.5
    # Calibration inputs
    actual_max_speed: float = Field(..., gt=0, description="实测质量合格时的最高速度 m/min")
    actual_quality_ok: bool = Field(..., description="该速度下质量是否合格")
    note: Optional[str] = None


class TensionCalibRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    width: float = Field(..., gt=0)
    yield_strength: float = Field(..., gt=0)
    tension_coef: Optional[float] = 0.25
    # Must be one of TENSION_ZONES (checked in the endpoint)
    zone: str = Field(..., description="测量位置,如 s1_roller")
    measured_kn: float = Field(..., gt=0, description="实测张力 kN")
    note: Optional[str] = None


class QualityCalibRequest(BaseModel):
    thickness: float = Field(..., gt=0)
    avg_speed: float = Field(..., gt=0)
    acid_conc_avg: float = Field(..., gt=0)
    acid_temp_avg: float = Field(..., gt=0)
    scale_weight: Optional[float] = 8.5
    # Must be one of QUALITY_GRADES (checked in the endpoint)
    actual_grade: str = Field(..., description="实际质检等级 A1/A2/B1/B2/C")
    note: Optional[str] = None


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _append_history(model_key: str, k_before: Optional[float], k_after: float,
                    input_data: dict, note: str = "") -> None:
    """Prepend one calibration event to the persisted history.

    Loads the calibration store, inserts the entry at the front and keeps at
    most MAX_HISTORY entries before saving. Because it does its own
    load/save, callers holding an earlier-loaded copy of the store must save
    that copy BEFORE calling this, or they will overwrite the new entry.
    """
    cal = _load_cal()
    history = cal.get("history", [])
    history.insert(0, {
        "ts": datetime.now().isoformat(timespec="seconds"),
        "model": model_key,
        "k_before": k_before,
        "k_after": k_after,
        "input": input_data,
        "note": note or "",
    })
    cal["history"] = history[:MAX_HISTORY]
    _save_cal(cal)


def _adjustment_pct(k_before: Optional[float], k_after: float) -> Optional[float]:
    """Return the relative change k_before -> k_after in percent (2 d.p.).

    Returns None when k_before is 0/None, where a relative change is undefined.
    """
    if not k_before:
        return None
    return round((k_after / k_before - 1) * 100, 2)


@contextmanager
def _value_error_as_422() -> Iterator[None]:
    """Translate a ValueError raised by model code into an HTTP 422 response."""
    try:
        yield
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e


# ─────────────────────────────────────────────────────────────────────────────
# Prediction endpoints
# ─────────────────────────────────────────────────────────────────────────────

@router.post("/acid-speed", response_model=Response[dict])
async def predict_acid_speed(body: AcidSpeedRequest, _=Depends(get_current_user)):
    """Predict the pickling speed for the given strip and acid-bath conditions."""
    with _value_error_as_422():
        model = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
            target_pi=body.target_pi,
        )
        result = model.calculate()
    return Response.ok(result)


@router.post("/tension", response_model=Response[dict])
async def predict_tension(body: TensionRequest, _=Depends(get_current_user)):
    """Predict per-zone strip tension."""
    with _value_error_as_422():
        model = TensionModel(
            thickness=body.thickness,
            width=body.width,
            yield_strength=body.yield_strength,
            tension_coef=body.tension_coef,
        )
        result = model.calculate()
    return Response.ok(result)


@router.post("/quality", response_model=Response[dict])
async def predict_quality(body: QualityRequest, _=Depends(get_current_user)):
    """Predict surface quality for the given process conditions."""
    with _value_error_as_422():
        model = QualityPredictionModel(
            thickness=body.thickness,
            avg_speed=body.avg_speed,
            acid_conc_avg=body.acid_conc_avg,
            acid_temp_avg=body.acid_temp_avg,
            scale_weight=body.scale_weight,
        )
        result = model.calculate()
    return Response.ok(result)


@router.post("/consumption", response_model=Response[dict])
async def predict_consumption(body: ConsumptionRequest, _=Depends(get_current_user)):
    """Predict acid consumption for a coil."""
    with _value_error_as_422():
        model = AcidConsumptionModel(
            thickness=body.thickness,
            width=body.width,
            coil_weight_kg=body.coil_weight_kg,
            has_regen_station=body.has_regen_station,
        )
        result = model.calculate()
    return Response.ok(result)


# ─────────────────────────────────────────────────────────────────────────────
# Calibration endpoints
# ─────────────────────────────────────────────────────────────────────────────

@router.get("/calibration", response_model=Response[dict])
async def get_calibration(_=Depends(get_current_user)):
    """Return each model's current calibration coefficients and the history."""
    cal = _load_cal()
    tension_zone_kcal = {z: cal.get(f"tension_zone_{z}", 1.0) for z in TENSION_ZONES}
    return Response.ok({
        "acid_speed_kcal": cal.get("acid_speed_kcal", 1.0),
        "tension_zone_kcal": tension_zone_kcal,
        "quality_kcal": cal.get("quality_kcal", 1.0),
        "history": cal.get("history", []),
    })


@router.post("/calibration/acid-speed", response_model=Response[dict])
async def calibrate_acid_speed(body: AcidCalibRequest, _=Depends(get_current_user)):
    """Record measured data and update the acid-speed model's calibration coefficient."""
    with _value_error_as_422():
        model = AcidSpeedModel(
            thickness=body.thickness,
            width=body.width,
            steel_grade=body.steel_grade,
            acid_conc_list=body.acid_conc_list,
            acid_temp_list=body.acid_temp_list,
            scale_weight=body.scale_weight,
        )
        k_before = model.K_cal
        # Prediction with the coefficient in effect BEFORE this calibration.
        predicted_speed = model.calculate()["max_speed"]
        k_after = model.calibrate(
            actual_max_speed=body.actual_max_speed,
            actual_quality_ok=body.actual_quality_ok,
        )

    _append_history(
        "acid_speed", k_before, k_after,
        {"actual_speed": body.actual_max_speed,
         "quality_ok": body.actual_quality_ok,
         "predicted_speed": predicted_speed},
        body.note or "",
    )
    return Response.ok({
        "k_before": k_before,
        "k_after": k_after,
        "predicted_speed": predicted_speed,
        "adjustment": _adjustment_pct(k_before, k_after),
    })


@router.post("/calibration/tension", response_model=Response[dict])
async def calibrate_tension(body: TensionCalibRequest, _=Depends(get_current_user)):
    """Record a measured tension and update only the given zone's coefficient."""
    # Reject unknown zones up front: otherwise the prediction silently falls
    # back to 0 and an arbitrary key would be passed to the model's calibrate().
    if body.zone not in TENSION_ZONES:
        raise HTTPException(status_code=422, detail=f"未知测量位置: {body.zone}")

    with _value_error_as_422():
        model = TensionModel(
            thickness=body.thickness,
            width=body.width,
            yield_strength=body.yield_strength,
            tension_coef=body.tension_coef,
        )
        calc = model.calculate()
        k_before = model.zone_kcal.get(body.zone, 1.0)
        new_zone_kcal = model.calibrate(zone=body.zone, measured_kn=body.measured_kn)

    predicted_kn = calc["zones"].get(body.zone, {}).get("tension_kN", 0)
    k_after = new_zone_kcal.get(body.zone, 1.0)

    _append_history(
        "tension", k_before, k_after,
        {"zone": body.zone, "measured_kn": body.measured_kn, "predicted_kn": predicted_kn},
        body.note or "",
    )
    return Response.ok({
        "zone": body.zone,
        "k_before": k_before,
        "k_after": k_after,
        "predicted_kn": predicted_kn,
        "measured_kn": body.measured_kn,
        "adjustment": _adjustment_pct(k_before, k_after),
        "zone_kcal": new_zone_kcal,
    })


@router.post("/calibration/quality", response_model=Response[dict])
async def calibrate_quality(body: QualityCalibRequest, _=Depends(get_current_user)):
    """Record the actual inspection grade and update the quality model's coefficient."""
    if body.actual_grade not in QUALITY_GRADES:
        raise HTTPException(status_code=422, detail=f"未知质检等级: {body.actual_grade}")

    with _value_error_as_422():
        model = QualityPredictionModel(
            thickness=body.thickness,
            avg_speed=body.avg_speed,
            acid_conc_avg=body.acid_conc_avg,
            acid_temp_avg=body.acid_temp_avg,
            scale_weight=body.scale_weight,
        )
        k_before = model.K_cal
        predicted_grade = model.calculate()["overall_grade"]
        k_after = model.calibrate(actual_grade=body.actual_grade)

    _append_history(
        "quality", k_before, k_after,
        {"actual_grade": body.actual_grade, "predicted_grade": predicted_grade},
        body.note or "",
    )
    return Response.ok({
        "k_before": k_before,
        "k_after": k_after,
        "predicted_grade": predicted_grade,
        "actual_grade": body.actual_grade,
        "adjustment": _adjustment_pct(k_before, k_after),
    })


@router.post("/calibration/reset/{model_key}", response_model=Response[dict])
async def reset_calibration(model_key: str, _=Depends(get_current_user)):
    """Reset all calibration coefficients of the given model to 1.0.

    ``model_key`` is "tension" (resets every zone), "acid_speed" or "quality";
    anything else yields 404.
    """
    if model_key == "tension":
        cal = _load_cal()
        for z in TENSION_ZONES:
            cal[f"tension_zone_{z}"] = 1.0
        # Save BEFORE appending history: _append_history reloads and saves the
        # store itself, so saving this (older) copy afterwards would drop the
        # new history entry.
        _save_cal(cal)
        _append_history("tension", None, 1.0, {"action": "reset_all_zones"})
    elif model_key in ("acid_speed", "quality"):
        cal = _load_cal()
        key = f"{model_key}_kcal"
        k_before = cal.get(key, 1.0)
        cal[key] = 1.0
        _save_cal(cal)  # see ordering note above
        _append_history(model_key, k_before, 1.0, {"action": "reset"})
    else:
        raise HTTPException(status_code=404, detail="未知模型")

    return Response.ok({"model": model_key, "reset": True})