191 lines
6.0 KiB
Python
191 lines
6.0 KiB
Python
import os
|
|
import base64
|
|
import cv2
|
|
import json
|
|
import numpy as np
|
|
from typing import Dict, List
|
|
from fastapi import FastAPI, HTTPException, File, UploadFile
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import uvicorn
|
|
|
|
from app.models import Detection, DetectionRequest, DetectionResponse, ModelInfo, ModelsResponse
|
|
from app.detector import ModelManager
|
|
|
|
# FastAPI application instance; title/description/version are the metadata
# FastAPI exposes through the generated OpenAPI schema.
app = FastAPI(
    version="1.0.0",
    title="Python Model Inference Service",
    description="API for object detection using Python models",
)
|
|
|
|
# Cross-origin policy: any origin, method and header is accepted, and
# credentialed requests are allowed.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers must not be ["*"] when allow_credentials
# is True. Starlette's CORSMiddleware then appears to echo the caller's Origin
# for cookie-bearing requests, i.e. any site could make credentialed calls.
# Consider an explicit origin list before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
# Process-wide model registry. It starts out as None and is replaced by a
# ModelManager instance in the application's startup hook.
model_manager: "ModelManager | None" = None
|
|
|
|
# Load models from configuration
@app.on_event("startup")
async def startup_event():
    """Create the global ModelManager and load models from ``models.json``.

    The configuration path comes from the ``MODELS_JSON`` environment
    variable, falling back to ``../models.json`` relative to this file.
    A missing or unparsable configuration is reported with ``print`` but does
    not abort startup: the service still starts with the manager created here.
    """
    global model_manager
    model_manager = ModelManager()

    # Look for models.json configuration file
    models_json_path = os.getenv(
        "MODELS_JSON",
        os.path.join(os.path.dirname(__file__), "..", "models.json"),
    )

    if not os.path.exists(models_json_path):
        print(f"Models configuration not found: {models_json_path}")
        return

    try:
        # JSON is UTF-8 by spec; do not depend on the platform's locale
        # encoding (e.g. cp1252/cp936 on Windows), which fails on non-ASCII
        # model names or paths. "utf-8-sig" also accepts a leading BOM, as
        # written by some Windows editors, which json.load would reject.
        with open(models_json_path, "r", encoding="utf-8-sig") as f:
            models_config = json.load(f)
        model_manager.load(models_config)
        print(f"Loaded model configuration from {models_json_path}")
    except Exception as e:
        # Deliberately best-effort: keep the service running so /health and
        # the error responses of the other endpoints stay reachable.
        print(f"Failed to load models from {models_json_path}: {str(e)}")
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Close the global ModelManager on application shutdown, if it was created."""
    global model_manager
    # Test against None explicitly: the intent is "was it initialized".
    # A truthiness test would skip close() if ModelManager defines
    # __len__/__bool__ and currently holds no models.
    if model_manager is not None:
        model_manager.close()
        # Drop the reference so a closed manager is never reused and a
        # repeated shutdown does not close it twice.
        model_manager = None
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Liveness probe only: the payload is constant and does not depend on
    # whether any model has been loaded.
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.get("/api/models", response_model=ModelsResponse)
async def get_models():
    """Get available models"""
    # (The docstring above is kept verbatim: FastAPI publishes it as the
    # endpoint description in the OpenAPI schema.)
    #
    # Lists every detector held by the global ModelManager. Each entry is
    # reported with backend="Python" and loaded=True; `path` falls back to ""
    # for detectors without a `model_path` attribute.
    #
    # Raises HTTPException(500) if the startup hook has not created the manager.

    # Compare with None (PEP 8): "not initialized" is the intended condition.
    # Truthiness would misreport an empty manager as uninitialized if
    # ModelManager defines __len__/__bool__.
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    models = [
        ModelInfo(
            name=detector.model_name,
            path=getattr(detector, "model_path", ""),
            size=[detector.input_width, detector.input_height],
            backend="Python",
            loaded=True,
        )
        for detector in model_manager.all()
    ]
    return ModelsResponse(models=models)
|
|
|
|
@app.post("/api/detect", response_model=DetectionResponse)
async def detect(request: DetectionRequest):
    """Detect objects in an image"""
    # (The docstring above is kept verbatim: FastAPI publishes it in OpenAPI.)
    #
    # request.image_data is either plain base64 or a data URL containing
    # "base64," (e.g. "data:image/png;base64,..."). It is decoded to bytes,
    # then to an image with cv2.imdecode(IMREAD_COLOR), and passed to the
    # detector registered under request.model_name.
    #
    # Raises HTTPException:
    #   500 - model manager not initialized, or the detector raised
    #   404 - no detector registered under request.model_name
    #   400 - image data is not valid base64 or not a decodable image
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    # Get detector for requested model
    detector = model_manager.get(request.model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {request.model_name}")

    try:
        # Keep only the payload after a data-URL prefix, if one is present.
        _, separator, payload = request.image_data.partition("base64,")
        image_bytes = base64.b64decode(payload if separator else request.image_data)
        image = None
        if image_bytes:
            # An empty buffer makes cv2.imdecode raise; treat it as undecodable.
            image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to decode image: {str(e)}") from e

    # Checked outside the try block: previously this HTTPException was caught
    # by the generic handler above and re-wrapped as
    # "Failed to decode image: 400: Invalid image data".
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid image data")

    # Run detection. Failures are reported like /api/detect/file does
    # (traceback to stderr, 500 with a detail message).
    # NOTE(review): detector.detect is called without await, so it runs on the
    # event loop thread and blocks other requests while it executes. Offloading
    # it (e.g. run_in_threadpool) is only safe if detectors are thread-safe —
    # confirm before changing.
    try:
        detections, inference_time = detector.detect(image)
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}") from e

    return DetectionResponse(
        model_name=request.model_name,
        detections=detections,
        inference_time=inference_time,
    )
|
|
|
|
@app.post("/api/detect/file", response_model=DetectionResponse)
async def detect_file(
    model_name: str,
    file: UploadFile = File(...)
):
    """Detect objects in an uploaded image file"""
    # (The docstring above is kept verbatim: FastAPI publishes it in OpenAPI.)
    #
    # model_name is a query parameter; file is the uploaded image, decoded
    # with cv2.imdecode(IMREAD_COLOR). Diagnostic output is printed at each
    # step (the messages are in Chinese and intentionally left unchanged).
    #
    # Raises HTTPException:
    #   500 - model manager not initialized, or the detector raised
    #   404 - no detector registered under model_name
    #   400 - empty upload, or the bytes could not be decoded as an image
    import traceback  # used by the error paths below

    # Log the received model_name, filename and content type.
    print(f"接收到的 model_name: {model_name}")
    print(f"文件名: {file.filename}")
    print(f"文件内容类型: {file.content_type}")

    # Compare with None (PEP 8): "not initialized" is the intended condition.
    if model_manager is None:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    # Get detector for requested model
    detector = model_manager.get(model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    # Read and decode the uploaded file.
    # NOTE(review): the whole upload is read into memory with no size limit;
    # consider capping it if the service is exposed to untrusted clients.
    try:
        contents = await file.read()
        print(f"文件大小: {len(contents)} 字节")  # file size in bytes

        if len(contents) == 0:
            raise HTTPException(status_code=400, detail="Empty file")

        nparr = np.frombuffer(contents, np.uint8)
        print(f"numpy数组形状: {nparr.shape}, dtype: {nparr.dtype}")

        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            print("错误: cv2.imdecode 返回 None")  # "error: cv2.imdecode returned None"
            raise HTTPException(status_code=400, detail="Invalid image data - failed to decode")

        print(f"解码后图像形状: {image.shape}, dtype: {image.dtype}")  # decoded image shape
    except HTTPException:
        # Our own 400s must not be re-wrapped by the generic handler below.
        raise
    except Exception as e:
        print(f"处理图像时出错: {str(e)}")  # "error while processing image"
        traceback.print_exc()
        raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}") from e

    # Run detection
    try:
        detections, inference_time = detector.detect(image)
        # Number of detections and inference time (ms).
        print(f"检测完成: 找到 {len(detections)} 个目标, 耗时 {inference_time:.2f}ms")
    except Exception as e:
        print(f"推理过程中出错: {str(e)}")  # "error during inference"
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}") from e

    return DetectionResponse(
        model_name=model_name,
        detections=detections,
        inference_time=inference_time,
    )
|
|
|
|
if __name__ == "__main__":
    # Development entry point. The app is passed as an import string
    # ("app.main:app") because uvicorn needs one when reload=True.
    # HOST / PORT environment variables override the defaults, which match
    # the previous hard-coded values (0.0.0.0:8000).
    uvicorn.run(
        "app.main:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
        reload=True,
    )