Files

191 lines
6.0 KiB
Python
Raw Permalink Normal View History

2025-09-30 14:23:33 +08:00
import base64
import json
import logging
import os
from typing import Dict, List

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware

from app.models import Detection, DetectionRequest, DetectionResponse, ModelInfo, ModelsResponse
from app.detector import ModelManager
# Initialize FastAPI app (metadata below is what the generated OpenAPI docs show).
app = FastAPI(
    version="1.0.0",
    title="Python Model Inference Service",
    description="API for object detection using Python models",
)

# Configure CORS: every origin, method and header is accepted, with credentials.
# NOTE(review): the CORS spec forbids a literal "*" origin on credentialed
# requests; depending on the Starlette version this combination either reflects
# the caller's Origin (i.e. credentialed requests from any site) or does not
# behave as expected for cookies. Confirm intent, and list explicit origins if
# credentials are really needed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Process-wide ModelManager; None until startup_event() creates it.
model_manager = None
# Load models from configuration
@app.on_event("startup")
async def startup_event():
    """Create the global ModelManager and load models from a JSON config.

    The config path is taken from the ``MODELS_JSON`` environment variable,
    defaulting to ``models.json`` one directory above this file. Loading is
    best-effort: a missing or unreadable config is reported on stdout and the
    service still starts. In that case the manager may hold no models (or
    only part of them, depending on ``ModelManager.load``, which is not
    visible here).
    """
    global model_manager
    model_manager = ModelManager()

    # Look for models.json configuration file
    models_json_path = os.getenv(
        "MODELS_JSON",
        os.path.join(os.path.dirname(__file__), "..", "models.json"),
    )
    if not os.path.exists(models_json_path):
        print(f"Models configuration not found: {models_json_path}")
        return

    try:
        # JSON is UTF-8 by spec: don't depend on the platform locale encoding
        # (e.g. GBK on Chinese-locale Windows). "utf-8-sig" also accepts files
        # saved with a BOM, which json.load would otherwise reject.
        with open(models_json_path, "r", encoding="utf-8-sig") as f:
            models_config = json.load(f)
        model_manager.load(models_config)
        print(f"Loaded model configuration from {models_json_path}")
    except Exception as e:
        # Deliberately best-effort: keep the service up (e.g. /health) even if
        # the configuration cannot be parsed or the models fail to load.
        print(f"Failed to load models from {models_json_path}: {e}")
@app.on_event("shutdown")
async def shutdown_event():
    """Close the global model manager, if startup created one.

    Checks against the ``None`` sentinel explicitly, so a manager that
    happens to be falsy (e.g. one that defines ``__len__`` and holds no
    models) is still closed.
    """
    if model_manager is not None:
        model_manager.close()
@app.get("/health")
async def health_check():
    """Report service liveness.

    Always answers with status "ok"; it does not check whether the model
    manager or any models are loaded.
    """
    status = "ok"
    return {"status": status}
@app.get("/api/models", response_model=ModelsResponse)
async def get_models():
    """List the models held by the model manager.

    Every detector returned by the manager is reported with backend
    "Python" and loaded=True. The size is [input_width, input_height], and
    the path is empty when the detector has no model_path attribute.

    Raises:
        HTTPException: 500 if startup has not created the model manager.
    """
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    models = [
        ModelInfo(
            name=detector.model_name,
            path=getattr(detector, "model_path", ""),
            size=[detector.input_width, detector.input_height],
            backend="Python",
            loaded=True,
        )
        for detector in model_manager.all()
    ]
    return ModelsResponse(models=models)
@app.post("/api/detect", response_model=DetectionResponse)
async def detect(request: DetectionRequest):
    """Detect objects in a base64-encoded image.

    The image_data field may be bare base64 or a data URL. Everything up to
    and including the first "base64," is stripped before decoding. The bytes
    are decoded with OpenCV (IMREAD_COLOR) and passed to the detector
    registered under model_name.

    Raises:
        HTTPException: 500 if the model manager is not initialized, 404 if
            the model is unknown, 400 if the payload cannot be decoded into
            an image.
    """
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    # Get detector for requested model
    detector = model_manager.get(request.model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {request.model_name}")

    # Decode base64 image. Only the decoding steps live in the try block.
    try:
        # Remove data URL prefix (e.g. "data:image/png;base64,") if present
        _, sep, tail = request.image_data.partition("base64,")
        image_bytes = base64.b64decode(tail if sep else request.image_data)
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to decode image: {str(e)}") from e

    # Checked outside the try: HTTPException is itself an Exception subclass,
    # so raising it inside would get re-wrapped as "Failed to decode image: 400: ...".
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid image data")

    # Run detection.
    # NOTE(review): detector.detect() runs synchronously inside this async
    # handler, so it blocks the event loop during inference. Consider a
    # threadpool only if the detectors are confirmed thread-safe.
    detections, inference_time = detector.detect(image)
    return DetectionResponse(
        model_name=request.model_name,
        detections=detections,
        inference_time=inference_time
    )
@app.post("/api/detect/file", response_model=DetectionResponse)
async def detect_file(
    model_name: str,
    file: UploadFile = File(...)
):
    """Detect objects in an uploaded image file.

    model_name is a plain str parameter, so FastAPI reads it from the query
    string. The file parameter is the multipart upload. The upload is decoded
    with OpenCV (IMREAD_COLOR) and passed to the detector registered under
    model_name.

    Diagnostics (upload metadata, buffer and image shapes, detection count
    and timing) are logged at DEBUG level. Decode and inference failures are
    logged with their traceback.

    Raises:
        HTTPException: 500 if the model manager is not initialized or
            inference fails, 404 if the model is unknown, 400 if the upload
            is empty or cannot be decoded as an image.
    """
    log = logging.getLogger(__name__)
    log.debug("接收到的 model_name: %s", model_name)
    log.debug("文件名: %s", file.filename)
    log.debug("文件内容类型: %s", file.content_type)

    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    # Get detector for requested model
    detector = model_manager.get(model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    # Read and decode the uploaded file
    try:
        contents = await file.read()
        log.debug("文件大小: %s 字节", len(contents))
        if len(contents) == 0:
            raise HTTPException(status_code=400, detail="Empty file")

        nparr = np.frombuffer(contents, np.uint8)
        log.debug("numpy数组形状: %s, dtype: %s", nparr.shape, nparr.dtype)

        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            # Client sent bytes OpenCV cannot decode; a warning, not a server fault.
            log.warning("错误: cv2.imdecode 返回 None")
            raise HTTPException(status_code=400, detail="Invalid image data - failed to decode")
        log.debug("解码后图像形状: %s, dtype: %s", image.shape, image.dtype)
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it below.
        raise
    except Exception as e:
        log.exception("处理图像时出错: %s", e)
        raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}") from e

    # Run detection. Only the detector call sits inside the try, so the
    # handler reports inference errors rather than problems in the log line.
    try:
        detections, inference_time = detector.detect(image)
    except Exception as e:
        log.exception("推理过程中出错: %s", e)
        raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}") from e
    log.debug("检测完成: 找到 %s 个目标, 耗时 %.2fms", len(detections), inference_time)

    return DetectionResponse(
        model_name=model_name,
        detections=detections,
        inference_time=inference_time
    )
if __name__ == "__main__":
    # Local/dev entry point. The app is passed as an import string rather than
    # an object because uvicorn needs that form for reload=True. Presumably
    # run from the project root so that "app.main" is importable.
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )