将检测任务迁移到 Python

This commit is contained in:
2025-09-30 14:23:33 +08:00
parent 3fe5f8083d
commit 39d39a7a24
69 changed files with 7921 additions and 1836 deletions

View File

@@ -0,0 +1 @@
# Python Inference Service package

View File

@@ -0,0 +1,311 @@
import os
import cv2
import numpy as np
import time
from typing import List, Dict, Tuple, Optional
import importlib.util
import sys
from app.models import Detection
class PythonModelDetector:
    """Object detector backed by a dynamically loaded Python model module.

    The model file must define a ``Model`` class exposing a ``predict``
    method. Optional hooks on the instance: ``preprocess`` (custom input
    normalization), ``applies_nms`` (truthy when the model already removes
    overlapping boxes) and ``close`` (resource cleanup).
    """

    def __init__(self, model_name: str, model_path: str, input_width: int, input_height: int, color: int = 0x00FF00):
        """
        Initialize detector with Python model

        Args:
            model_name: Name of the model
            model_path: Path to the Python model file (.py)
            input_width: Input width for the model
            input_height: Input height for the model
            color: RGB color for detection boxes (default: green)
        """
        self.model_name = model_name
        # Keep the source path so callers (e.g. the models listing endpoint,
        # which reads ``detector.model_path``) can report where it came from.
        self.model_path = model_path
        self.input_width = input_width
        self.input_height = input_height
        self.color = color
        # Swap the R and B bytes (RGB -> BGR) for OpenCV drawing helpers.
        self.color_bgr = ((color & 0xFF) << 16) | (color & 0xFF00) | ((color >> 16) & 0xFF)
        # Default confidence thresholds
        self.conf_threshold = 0.25
        self.nms_threshold = 0.45
        # Load the Python model dynamically
        self._load_python_model(model_path)
        # Load class names if available: one label per line in classes.txt
        # placed next to the model file.
        self.classes = []
        classes_path = os.path.join(os.path.dirname(model_path), "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r') as f:
                self.classes = [line.strip() for line in f if line.strip()]

    def _load_python_model(self, model_path: str):
        """Import the module at *model_path* and instantiate its Model class.

        Raises:
            FileNotFoundError: if the file does not exist.
            ImportError: if an import spec cannot be created for the file.
            AttributeError: if the module lacks a ``Model`` class or the
                instance lacks a ``predict`` method.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        model_dir = os.path.dirname(model_path)
        module_name = os.path.splitext(os.path.basename(model_path))[0]
        # Make sibling files importable by the model module itself.
        if model_dir not in sys.path:
            sys.path.append(model_dir)
        spec = importlib.util.spec_from_file_location(module_name, model_path)
        if spec is None:
            raise ImportError(f"Failed to load model specification: {model_path}")
        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)
        # Validate the required interface before using the module.
        if not hasattr(model_module, "Model"):
            raise AttributeError(f"Model module must define a 'Model' class: {model_path}")
        self.model = model_module.Model()
        if not hasattr(self.model, "predict"):
            raise AttributeError(f"Model must implement 'predict' method: {model_path}")

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Convert *img* to BGR, resize to the model input size and normalize."""
        # Ensure a 3-channel BGR image.
        if len(img.shape) == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        resized = cv2.resize(img, (self.input_width, self.input_height))
        # Prefer the model's own preprocessing when it provides one.
        if hasattr(self.model, "preprocess"):
            return self.model.preprocess(resized)
        # Default preprocessing: normalize to [0, 1]
        return resized / 255.0

    def detect(self, img: np.ndarray) -> "Tuple[List[Detection], float]":
        """
        Detect objects in an image

        Args:
            img: Input image in BGR format (OpenCV)

        Returns:
            List of Detection objects and inference time in milliseconds
        """
        if img is None or img.size == 0:
            return [], 0.0
        img_height, img_width = img.shape[:2]
        processed_img = self.preprocess(img)
        start_time = time.time()
        try:
            # Expected return format from the model's predict():
            # list of dicts with keys 'bbox' ((x, y, w, h) normalized to
            # [0-1]), 'class_id' and 'confidence'.
            model_results = self.model.predict(processed_img)
            inference_time = (time.time() - start_time) * 1000
            detections = []
            for result in model_results:
                confidence = result.get('confidence', 0)
                if confidence < self.conf_threshold:
                    continue  # skip low-confidence detections
                bbox = result.get('bbox', [0, 0, 0, 0])
                # Denormalize bbox to original image coordinates.
                x = int(bbox[0] * img_width)
                y = int(bbox[1] * img_height)
                w = int(bbox[2] * img_width)
                h = int(bbox[3] * img_height)
                if w <= 0 or h <= 0:
                    continue  # skip degenerate boxes
                class_id = result.get('class_id', 0)
                class_name = f"cls{class_id}"
                if 0 <= class_id < len(self.classes):
                    class_name = self.classes[class_id]
                label = f"[{self.model_name}] {class_name}"
                detections.append(Detection(
                    label=label,
                    confidence=confidence,
                    x=x,
                    y=y,
                    width=w,
                    height=h,
                    color=self.color
                ))
            # Apply NMS unless the model already did it internally.
            if getattr(self.model, "applies_nms", False):
                return detections, inference_time
            boxes = [(d.x, d.y, d.width, d.height) for d in detections]
            scores = [d.confidence for d in detections]
            if boxes:
                indices = self._non_max_suppression(boxes, scores, self.nms_threshold)
                detections = [detections[i] for i in indices]
            return detections, inference_time
        except Exception as e:
            # Best-effort: a failing model must not crash the service.
            print(f"Error during detection: {str(e)}")
            return [], (time.time() - start_time) * 1000

    def _non_max_suppression(self, boxes: List[Tuple], scores: List[float], threshold: float) -> List[int]:
        """Greedy NMS: keep highest-scoring boxes, drop overlaps above *threshold* IoU.

        Boxes are (x, y, w, h); returns indices of kept boxes ordered by
        descending score.
        """
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        suppressed = [False] * len(scores)
        keep = []
        for pos, current in enumerate(order):
            if suppressed[current]:
                continue
            keep.append(current)
            x1, y1, w1, h1 = boxes[current]
            x1_max = x1 + w1
            y1_max = y1 + h1
            area1 = w1 * h1
            for other in order[pos + 1:]:
                if suppressed[other]:
                    continue
                x2, y2, w2, h2 = boxes[other]
                # Intersection rectangle (clamped to zero when disjoint).
                iw = max(0, min(x1_max, x2 + w2) - max(x1, x2))
                ih = max(0, min(y1_max, y2 + h2) - max(y1, y2))
                intersection = iw * ih
                # Tiny epsilon avoids division by zero for zero-area boxes.
                union = area1 + w2 * h2 - intersection + 1e-9
                if intersection / union > threshold:
                    suppressed[other] = True
        return keep

    def close(self):
        """Release model resources (delegates to the model's close hook, if any)."""
        if hasattr(self.model, "close"):
            self.model.close()
        self.model = None
class ModelManager:
    """Registry that loads, stores and disposes of Python model detectors."""

    def __init__(self):
        # Mapping of model name -> PythonModelDetector instance.
        self.models = {}

    def load(self, models_config: "List[Dict]"):
        """
        Load models from configuration

        Args:
            models_config: List of model configurations; each entry is a
                dict with "name", "path" and optional "size" ([width,
                height], default [640, 640]) keys. Invalid entries are
                skipped with a message rather than aborting the whole load.
        """
        # Basic color palette so each model gets a distinct box color.
        palette = [0x00FF00, 0xFF8000, 0x00A0FF, 0xFF00FF, 0x00FFFF, 0xFF0000, 0x80FF00]
        for i, model_config in enumerate(models_config):
            name = model_config.get("name")
            path = model_config.get("path")
            size = model_config.get("size", [640, 640])
            if not name or not path or not os.path.exists(path):
                print(f"Skipping model: {name} - Invalid configuration")
                continue
            try:
                # Cycle through the palette when there are many models.
                color = palette[i % len(palette)]
                detector = PythonModelDetector(
                    model_name=name,
                    model_path=path,
                    input_width=size[0],
                    input_height=size[1],
                    color=color
                )
                self.models[name] = detector
                print(f"Model loaded: {name} ({path})")
            except Exception as e:
                print(f"Failed to load model {name}: {str(e)}")

    def get(self, name: str) -> "Optional[PythonModelDetector]":
        """Get detector by name, or None when no such model is loaded."""
        return self.models.get(name)

    def all(self) -> "List[PythonModelDetector]":
        """Get all loaded detectors."""
        return list(self.models.values())

    def close(self):
        """Close all detectors and clear the registry."""
        for detector in self.models.values():
            try:
                detector.close()
            except Exception:
                # Best-effort cleanup: one failing detector must not stop
                # the others from being closed. (Was a bare `except:`, which
                # also swallowed KeyboardInterrupt/SystemExit.)
                pass
        self.models.clear()

View File

@@ -0,0 +1,164 @@
import os
import base64
import cv2
import json
import numpy as np
from typing import Dict, List
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from app.models import Detection, DetectionRequest, DetectionResponse, ModelInfo, ModelsResponse
from app.detector import ModelManager
# Initialize FastAPI app
app = FastAPI(
    title="Python Model Inference Service",
    description="API for object detection using Python models",
    version="1.0.0"
)
# Configure CORS
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# wide open; acceptable for an internal service, confirm before exposing it
# publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize model manager
# Populated by the startup handler; stays None until the app has started.
model_manager = None
# Load models from configuration
@app.on_event("startup")
async def startup_event():
    """Create the model manager and load models from models.json."""
    global model_manager
    model_manager = ModelManager()
    # Resolve the configuration path: env override first, then the default
    # location next to the package directory.
    default_path = os.path.join(os.path.dirname(__file__), "..", "models", "models.json")
    models_json_path = os.getenv("MODELS_JSON", default_path)
    if not os.path.exists(models_json_path):
        print(f"Models configuration not found: {models_json_path}")
        return
    try:
        with open(models_json_path, "r") as f:
            models_config = json.load(f)
            model_manager.load(models_config)
            print(f"Loaded model configuration from {models_json_path}")
    except Exception as e:
        # A broken config must not prevent the service from starting.
        print(f"Failed to load models from {models_json_path}: {str(e)}")
@app.on_event("shutdown")
async def shutdown_event():
    """Release all model resources when the service stops."""
    global model_manager
    if model_manager:
        model_manager.close()
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service is up."""
    payload = {"status": "ok"}
    return payload
@app.get("/api/models", response_model=ModelsResponse)
async def get_models():
    """List every loaded model with its name, path, input size and backend."""
    global model_manager
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    # One ModelInfo entry per loaded detector.
    models = [
        ModelInfo(
            name=detector.model_name,
            path=getattr(detector, 'model_path', ''),
            size=[detector.input_width, detector.input_height],
            backend="Python",
            loaded=True,
        )
        for detector in model_manager.all()
    ]
    return ModelsResponse(models=models)
@app.post("/api/detect", response_model=DetectionResponse)
async def detect(request: DetectionRequest):
    """Detect objects in a base64-encoded image.

    Args:
        request: model name plus the image as a base64 string (an optional
            data-URL prefix such as "data:image/png;base64," is stripped).

    Returns:
        DetectionResponse with detections and inference time in milliseconds.

    Raises:
        HTTPException: 500 if the service is not initialized, 404 if the
            model is unknown, 400 if the image cannot be decoded.
    """
    global model_manager
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    # Get detector for requested model
    detector = model_manager.get(request.model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {request.model_name}")
    # Decode base64 image
    try:
        # Remove data URL prefix if present
        if "base64," in request.image_data:
            image_data = request.image_data.split("base64,")[1]
        else:
            image_data = request.image_data
        image_bytes = base64.b64decode(image_data)
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to decode image: {str(e)}")
    # cv2.imdecode returns None for undecodable bytes. This check must sit
    # OUTSIDE the try block: previously the HTTPException raised here was
    # caught by the `except Exception` above and re-wrapped, hiding the
    # intended "Invalid image data" detail.
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid image data")
    # Run detection
    detections, inference_time = detector.detect(image)
    return DetectionResponse(
        model_name=request.model_name,
        detections=detections,
        inference_time=inference_time
    )
@app.post("/api/detect/file", response_model=DetectionResponse)
async def detect_file(
    model_name: str,
    file: UploadFile = File(...)
):
    """Detect objects in an uploaded image file.

    Args:
        model_name: name of a loaded model.
        file: the image file to run detection on.

    Returns:
        DetectionResponse with detections and inference time in milliseconds.

    Raises:
        HTTPException: 500 if the service is not initialized, 404 if the
            model is unknown, 400 if the file is not a decodable image.
    """
    global model_manager
    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")
    # Get detector for requested model
    detector = model_manager.get(model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
    # Read uploaded file
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}")
    # cv2.imdecode returns None for undecodable bytes. This check must sit
    # OUTSIDE the try block: previously the HTTPException raised here was
    # caught by the `except Exception` above and re-wrapped, hiding the
    # intended "Invalid image data" detail.
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid image data")
    # Run detection
    detections, inference_time = detector.detect(image)
    return DetectionResponse(
        model_name=model_name,
        detections=detections,
        inference_time=inference_time
    )
# Allow running the service directly (e.g. `python app/main.py`); in
# production the app is typically served by an external uvicorn process.
if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)

View File

@@ -0,0 +1,40 @@
from pydantic import BaseModel
from typing import List, Optional
class Detection(BaseModel):
    """Object detection result"""
    label: str  # display label, formatted by the detector as "[model_name] class_name"
    confidence: float  # detection score; detector filters below conf_threshold
    x: int  # top-left corner x, in pixels of the original image
    y: int  # top-left corner y, in pixels of the original image
    width: int  # box width in pixels
    height: int  # box height in pixels
    color: int = 0x00FF00  # box color as 0xRRGGBB; default green
class DetectionRequest(BaseModel):
    """Request for model inference on image data"""
    model_name: str  # name of a loaded model
    image_data: str  # Base64 encoded image; an optional data-URL prefix is tolerated
class DetectionResponse(BaseModel):
    """Response with detection results"""
    model_name: str  # echoes the requested model name
    detections: List[Detection]  # may be empty when nothing passes the thresholds
    inference_time: float  # Time in milliseconds
class ModelInfo(BaseModel):
    """Model information"""
    name: str  # model name used as the lookup key
    path: str  # path to the model file on disk
    size: List[int]  # [width, height]
    # NOTE(review): default is "ONNX" but the Python service always reports
    # "Python" when building this model — confirm the default is still needed.
    backend: str = "ONNX"
    loaded: bool = False  # whether the detector was successfully created
class ModelsResponse(BaseModel):
    """Response with available models"""
    models: List[ModelInfo]  # one entry per loaded detector