将检测任务迁移python

This commit is contained in:
2025-09-30 14:23:33 +08:00
parent 3fe5f8083d
commit 39d39a7a24
69 changed files with 7921 additions and 1836 deletions

View File

@@ -0,0 +1,311 @@
import os
import cv2
import numpy as np
import time
from typing import List, Dict, Tuple, Optional
import importlib.util
import sys
from app.models import Detection
class PythonModelDetector:
"""Object detector using native Python models"""
def __init__(self, model_name: str, model_path: str, input_width: int, input_height: int, color: int = 0x00FF00):
"""
Initialize detector with Python model
Args:
model_name: Name of the model
model_path: Path to the Python model file (.py)
input_width: Input width for the model
input_height: Input height for the model
color: RGB color for detection boxes (default: green)
"""
self.model_name = model_name
self.input_width = input_width
self.input_height = input_height
self.color = color
# Convert color from RGB to BGR (OpenCV uses BGR)
self.color_bgr = ((color & 0xFF) << 16) | (color & 0xFF00) | ((color >> 16) & 0xFF)
# Default confidence thresholds
self.conf_threshold = 0.25
self.nms_threshold = 0.45
# Load the Python model dynamically
self._load_python_model(model_path)
# Load class names if available
self.classes = []
model_dir = os.path.dirname(model_path)
classes_path = os.path.join(model_dir, "classes.txt")
if os.path.exists(classes_path):
with open(classes_path, 'r') as f:
self.classes = [line.strip() for line in f.readlines() if line.strip()]
def _load_python_model(self, model_path: str):
"""Load Python model dynamically"""
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
# Get model directory and file name
model_dir = os.path.dirname(model_path)
model_file = os.path.basename(model_path)
model_name = os.path.splitext(model_file)[0]
# Add model directory to system path
if model_dir not in sys.path:
sys.path.append(model_dir)
# Import the model module
spec = importlib.util.spec_from_file_location(model_name, model_path)
if spec is None:
raise ImportError(f"Failed to load model specification: {model_path}")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
# Check if the module has the required interface
if not hasattr(model_module, "Model"):
raise AttributeError(f"Model module must define a 'Model' class: {model_path}")
# Create model instance
self.model = model_module.Model()
# Check if model has the required methods
if not hasattr(self.model, "predict"):
raise AttributeError(f"Model must implement 'predict' method: {model_path}")
def preprocess(self, img: np.ndarray) -> np.ndarray:
"""Preprocess image for model input"""
# Ensure BGR image
if len(img.shape) == 2: # Grayscale
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4: # BGRA
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
# Resize to model input size
resized = cv2.resize(img, (self.input_width, self.input_height))
# Use model's preprocess method if available
if hasattr(self.model, "preprocess"):
return self.model.preprocess(resized)
# Default preprocessing: normalize to [0, 1]
return resized / 255.0
def detect(self, img: np.ndarray) -> Tuple[List[Detection], float]:
"""
Detect objects in an image
Args:
img: Input image in BGR format (OpenCV)
Returns:
List of Detection objects and inference time in milliseconds
"""
if img is None or img.size == 0:
return [], 0.0
# Original image dimensions
img_height, img_width = img.shape[:2]
# Preprocess image
processed_img = self.preprocess(img)
# Measure inference time
start_time = time.time()
try:
# Run inference using model's predict method
# Expected return format from model's predict:
# List of dicts with keys: 'bbox', 'class_id', 'confidence'
# bbox: (x, y, w, h) normalized [0-1]
model_results = self.model.predict(processed_img)
# Calculate inference time in milliseconds
inference_time = (time.time() - start_time) * 1000
# Convert model results to Detection objects
detections = []
for result in model_results:
# Skip low confidence detections
confidence = result.get('confidence', 0)
if confidence < self.conf_threshold:
continue
# Get bounding box (normalized coordinates)
bbox = result.get('bbox', [0, 0, 0, 0])
# Denormalize bbox to image coordinates
x = int(bbox[0] * img_width)
y = int(bbox[1] * img_height)
w = int(bbox[2] * img_width)
h = int(bbox[3] * img_height)
# Skip invalid boxes
if w <= 0 or h <= 0:
continue
# Get class ID and name
class_id = result.get('class_id', 0)
class_name = f"cls{class_id}"
if 0 <= class_id < len(self.classes):
class_name = self.classes[class_id]
# Create Detection object
label = f"[{self.model_name}] {class_name}"
detection = Detection(
label=label,
confidence=confidence,
x=x,
y=y,
width=w,
height=h,
color=self.color
)
detections.append(detection)
# Apply NMS if model doesn't do it internally
if hasattr(self.model, "applies_nms") and self.model.applies_nms:
return detections, inference_time
else:
# Convert detections to boxes and scores
boxes = [(d.x, d.y, d.width, d.height) for d in detections]
scores = [d.confidence for d in detections]
if boxes:
# Apply NMS
indices = self._non_max_suppression(boxes, scores, self.nms_threshold)
detections = [detections[i] for i in indices]
return detections, inference_time
except Exception as e:
print(f"Error during detection: {str(e)}")
return [], (time.time() - start_time) * 1000
def _non_max_suppression(self, boxes: List[Tuple], scores: List[float], threshold: float) -> List[int]:
"""Apply Non-Maximum Suppression to remove overlapping boxes"""
# Sort by score in descending order
indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
keep = []
while indices:
# Get index with highest score
current = indices.pop(0)
keep.append(current)
# No more indices to process
if not indices:
break
# Get current box
x1, y1, w1, h1 = boxes[current]
x2_1 = x1 + w1
y2_1 = y1 + h1
area1 = w1 * h1
# Check remaining boxes
i = 0
while i < len(indices):
# Get box to compare
idx = indices[i]
x2, y2, w2, h2 = boxes[idx]
x2_2 = x2 + w2
y2_2 = y2 + h2
area2 = w2 * h2
# Calculate intersection
xx1 = max(x1, x2)
yy1 = max(y1, y2)
xx2 = min(x2_1, x2_2)
yy2 = min(y2_1, y2_2)
# Calculate intersection area
w = max(0, xx2 - xx1)
h = max(0, yy2 - yy1)
intersection = w * h
# Calculate IoU
union = area1 + area2 - intersection + 1e-9 # Avoid division by zero
iou = intersection / union
# Remove box if IoU is above threshold
if iou > threshold:
indices.pop(i)
else:
i += 1
return keep
def close(self):
"""Close the model resources"""
if hasattr(self.model, "close"):
self.model.close()
self.model = None
class ModelManager:
"""Model manager for detectors"""
def __init__(self):
self.models = {}
def load(self, models_config: List[Dict]):
"""
Load models from configuration
Args:
models_config: List of model configurations
"""
# Basic color palette for different models
palette = [0x00FF00, 0xFF8000, 0x00A0FF, 0xFF00FF, 0x00FFFF, 0xFF0000, 0x80FF00]
for i, model_config in enumerate(models_config):
name = model_config.get("name")
path = model_config.get("path")
size = model_config.get("size", [640, 640])
if not name or not path or not os.path.exists(path):
print(f"Skipping model: {name} - Invalid configuration")
continue
try:
# Use color from palette
color = palette[i % len(palette)]
# Create detector for Python model
detector = PythonModelDetector(
model_name=name,
model_path=path,
input_width=size[0],
input_height=size[1],
color=color
)
self.models[name] = detector
print(f"Model loaded: {name} ({path})")
except Exception as e:
print(f"Failed to load model {name}: {str(e)}")
def get(self, name: str) -> Optional[PythonModelDetector]:
"""Get detector by name"""
return self.models.get(name)
def all(self) -> List[PythonModelDetector]:
"""Get all detectors"""
return list(self.models.values())
def close(self):
"""Close all detectors"""
for detector in self.models.values():
try:
detector.close()
except:
pass
self.models.clear()