import os
import base64
import cv2
import json
import logging
import numpy as np
from typing import Dict, List, Optional
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from app.models import Detection, DetectionRequest, DetectionResponse, ModelInfo, ModelsResponse
from app.detector import ModelManager

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Python Model Inference Service",
    description="API for object detection using Python models",
    version="1.0.0",
)

# Configure CORS.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site issue credentialed requests, depending on how the middleware treats
# "*". Restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide model manager: created in startup_event, closed in shutdown_event.
model_manager: Optional[ModelManager] = None


@app.on_event("startup")
async def startup_event():
    """Create the model manager and load models from the JSON configuration.

    The config path is taken from the ``MODELS_JSON`` environment variable and
    defaults to ``models.json`` one directory above this file. Loading is
    best-effort: a missing or unreadable config is logged and the service
    still starts.
    """
    global model_manager
    model_manager = ModelManager()

    models_json_path = os.getenv(
        "MODELS_JSON", os.path.join(os.path.dirname(__file__), "..", "models.json")
    )
    if not os.path.exists(models_json_path):
        logger.warning("Models configuration not found: %s", models_json_path)
        return

    try:
        with open(models_json_path, "r", encoding="utf-8") as f:
            models_config = json.load(f)
        model_manager.load(models_config)
    except Exception:
        # Deliberately best-effort: keep the service up, but keep the traceback.
        logger.exception("Failed to load models from %s", models_json_path)
    else:
        logger.info("Loaded model configuration from %s", models_json_path)


@app.on_event("shutdown")
async def shutdown_event():
    """Close the model manager if startup created one."""
    if model_manager is not None:
        model_manager.close()


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "ok"}


def _require_model_manager() -> ModelManager:
    """Return the global model manager, raising HTTP 500 if it was never created."""
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    return model_manager


def _get_detector(model_name: str):
    """Return the detector registered as *model_name*, raising HTTP 404 if absent."""
    detector = _require_model_manager().get(model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
    return detector


def _decode_image(image_bytes: bytes, *, invalid_detail: str, error_prefix: str) -> np.ndarray:
    """Decode encoded image bytes into a colour image with ``cv2.imdecode``.

    Args:
        image_bytes: Raw encoded image file contents.
        invalid_detail: HTTP error detail used when OpenCV returns ``None``.
        error_prefix: Prefix of the HTTP error detail used when decoding
            raises (e.g. a ``cv2.error``).

    Raises:
        HTTPException: 400 in both failure cases above.
    """
    try:
        image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    except Exception as e:
        logger.exception("处理图像时出错: %s", e)
        raise HTTPException(status_code=400, detail=f"{error_prefix}: {str(e)}") from e
    if image is None:
        logger.debug("错误: cv2.imdecode 返回 None")
        raise HTTPException(status_code=400, detail=invalid_detail)
    return image


def _run_detection(detector, image: np.ndarray):
    """Run ``detector.detect(image)``, mapping any failure to HTTP 500.

    Returns:
        The ``(detections, inference_time)`` pair returned by the detector.
    """
    # NOTE(review): detect() runs synchronously on the event-loop thread, so
    # concurrent requests are serialized. Offloading it to a thread pool would
    # require the detector to be thread-safe — confirm before changing.
    try:
        detections, inference_time = detector.detect(image)
    except Exception as e:
        logger.exception("推理过程中出错: %s", e)
        raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}") from e
    logger.debug("检测完成: 找到 %d 个目标, 耗时 %.2fms", len(detections), inference_time)
    return detections, inference_time


@app.get("/api/models", response_model=ModelsResponse)
async def get_models():
    """Describe every detector returned by the model manager."""
    manager = _require_model_manager()
    models = [
        ModelInfo(
            name=detector.model_name,
            path=getattr(detector, "model_path", ""),
            size=[detector.input_width, detector.input_height],
            backend="Python",
            loaded=True,
        )
        for detector in manager.all()
    ]
    return ModelsResponse(models=models)


@app.post("/api/detect", response_model=DetectionResponse)
async def detect(request: DetectionRequest):
    """Detect objects in a base64-encoded image.

    ``request.image_data`` may be raw base64 or a data URL; everything up to
    and including the first ``base64,`` marker is stripped before decoding.

    Raises:
        HTTPException: 500 if the model manager is missing or detection fails,
            404 for an unknown model, 400 for undecodable image data.
    """
    detector = _get_detector(request.model_name)

    # Remove data URL prefix if present
    _, marker, payload = request.image_data.partition("base64,")
    image_data = payload if marker else request.image_data

    try:
        image_bytes = base64.b64decode(image_data)
    except ValueError as e:  # binascii.Error (bad padding) or non-ASCII input
        raise HTTPException(status_code=400, detail=f"Failed to decode image: {str(e)}") from e

    image = _decode_image(
        image_bytes,
        invalid_detail="Invalid image data",
        error_prefix="Failed to decode image",
    )
    detections, inference_time = _run_detection(detector, image)

    return DetectionResponse(
        model_name=request.model_name,
        detections=detections,
        inference_time=inference_time,
    )


@app.post("/api/detect/file", response_model=DetectionResponse)
async def detect_file(
    model_name: str,
    file: UploadFile = File(...),
):
    """Detect objects in an uploaded image file.

    ``model_name`` is a plain scalar parameter, so FastAPI reads it from the
    query string; the image arrives as a multipart upload.

    Raises:
        HTTPException: 500 if the model manager is missing or detection fails,
            404 for an unknown model, 400 for an empty or undecodable file.
    """
    logger.debug("接收到的 model_name: %s", model_name)
    logger.debug("文件名: %s", file.filename)
    logger.debug("文件内容类型: %s", file.content_type)

    detector = _get_detector(model_name)

    # Read uploaded file
    try:
        contents = await file.read()
    except Exception as e:
        logger.exception("处理图像时出错: %s", e)
        raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}") from e

    logger.debug("文件大小: %d 字节", len(contents))
    if not contents:
        raise HTTPException(status_code=400, detail="Empty file")

    image = _decode_image(
        contents,
        invalid_detail="Invalid image data - failed to decode",
        error_prefix="Failed to process image",
    )
    logger.debug("解码后图像形状: %s, dtype: %s", image.shape, image.dtype)

    detections, inference_time = _run_detection(detector, image)

    return DetectionResponse(
        model_name=model_name,
        detections=detections,
        inference_time=inference_time,
    )


if __name__ == "__main__":
    # Development entry point (auto-reload enabled).
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)