311 lines
11 KiB
Python
311 lines
11 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
import time
|
|
from typing import List, Dict, Tuple, Optional
|
|
import importlib.util
|
|
import sys
|
|
|
|
from app.models import Detection
|
|
|
|
|
|
class PythonModelDetector:
    """Object detector that delegates inference to a native Python model.

    The model is a ``.py`` file that defines a ``Model`` class with a
    ``predict`` method.  Optional hooks honoured on the model instance:

    * ``preprocess(img)`` - replaces the default ``/ 255.0`` normalisation.
    * ``applies_nms``     - truthy when the model already performs NMS.
    * ``close()``         - called from :meth:`close`.

    If a ``classes.txt`` file sits next to the model file, its non-empty
    lines are used as class names (line index == class id).
    """

    def __init__(self, model_name: str, model_path: str, input_width: int, input_height: int, color: int = 0x00FF00):
        """
        Initialize detector with Python model

        Args:
            model_name: Name of the model
            model_path: Path to the Python model file (.py)
            input_width: Input width for the model
            input_height: Input height for the model
            color: RGB color for detection boxes (default: green)

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            ImportError: If no import spec/loader can be built for the file.
            AttributeError: If the module lacks ``Model`` or the instance
                lacks ``predict``.
        """
        self.model_name = model_name
        self.input_width = input_width
        self.input_height = input_height
        self.color = color

        # Swap the R and B bytes of the packed 0xRRGGBB value (OpenCV uses BGR).
        self.color_bgr = ((color & 0xFF) << 16) | (color & 0xFF00) | ((color >> 16) & 0xFF)

        # Default confidence thresholds
        self.conf_threshold = 0.25
        self.nms_threshold = 0.45

        # Set first so close() is safe even when loading fails below.
        self.model = None
        self._load_python_model(model_path)

        # Load class names if available
        self.classes = []
        classes_path = os.path.join(os.path.dirname(model_path), "classes.txt")
        if os.path.exists(classes_path):
            # Explicit encoding: the platform default is not UTF-8 everywhere.
            with open(classes_path, "r", encoding="utf-8") as f:
                self.classes = [line.strip() for line in f if line.strip()]

    def _load_python_model(self, model_path: str):
        """Import ``model_path`` as a module and instantiate its ``Model``.

        NOTE(security): this executes the file's top-level code; only point
        it at trusted model files.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            ImportError: If no import spec/loader can be built for the file.
            AttributeError: If the module lacks ``Model`` or the instance
                lacks ``predict``.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Absolute directory so the sys.path entry survives a later chdir and
        # is never the empty string for a bare file name.
        model_dir = os.path.dirname(os.path.abspath(model_path))
        module_name = os.path.splitext(os.path.basename(model_path))[0]

        # Let the model file import helper modules that live next to it.
        if model_dir not in sys.path:
            sys.path.append(model_dir)

        spec = importlib.util.spec_from_file_location(module_name, model_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Failed to load model specification: {model_path}")

        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)

        # Check if the module has the required interface
        if not hasattr(model_module, "Model"):
            raise AttributeError(f"Model module must define a 'Model' class: {model_path}")

        self.model = model_module.Model()

        if not hasattr(self.model, "predict"):
            raise AttributeError(f"Model must implement 'predict' method: {model_path}")

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Convert ``img`` to 3-channel BGR, resize it and normalise it.

        Grayscale (2-D) and 4-channel inputs are converted to BGR first.
        After resizing to ``(input_width, input_height)`` the model's own
        ``preprocess`` hook is used when present; otherwise pixel values are
        divided by 255.
        """
        if len(img.shape) == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        resized = cv2.resize(img, (self.input_width, self.input_height))

        if hasattr(self.model, "preprocess"):
            return self.model.preprocess(resized)

        # Default preprocessing: normalize to [0, 1]
        return resized / 255.0

    def detect(self, img: np.ndarray) -> Tuple[List["Detection"], float]:
        """
        Detect objects in an image

        Args:
            img: Input image in BGR format (OpenCV)

        Returns:
            List of Detection objects and inference time in milliseconds.
            ``([], 0.0)`` for a missing/empty image; ``([], elapsed_ms)`` if
            inference or result conversion raises (the error is printed).
        """
        if img is None or img.size == 0:
            return [], 0.0

        # Original image dimensions, used to denormalize the boxes.
        img_height, img_width = img.shape[:2]

        # Preprocessing is deliberately outside the timed section and outside
        # the try block: its errors propagate to the caller.
        processed_img = self.preprocess(img)

        # perf_counter is monotonic; time.time() can jump with clock changes.
        start_time = time.perf_counter()

        try:
            # Expected return format from model's predict:
            # List of dicts with keys: 'bbox', 'class_id', 'confidence'
            # bbox: (x, y, w, h) normalized [0-1]
            model_results = self.model.predict(processed_img)
            inference_time = (time.perf_counter() - start_time) * 1000

            detections = []
            for result in model_results:
                # Skip low confidence detections
                confidence = result.get('confidence', 0)
                if confidence < self.conf_threshold:
                    continue

                # Denormalize bbox to image coordinates
                bbox = result.get('bbox', [0, 0, 0, 0])
                x = int(bbox[0] * img_width)
                y = int(bbox[1] * img_height)
                w = int(bbox[2] * img_width)
                h = int(bbox[3] * img_height)

                # Skip invalid boxes
                if w <= 0 or h <= 0:
                    continue

                # int() so a float id such as 1.0 can still index the list.
                class_id = int(result.get('class_id', 0))
                if 0 <= class_id < len(self.classes):
                    class_name = self.classes[class_id]
                else:
                    class_name = f"cls{class_id}"

                detections.append(
                    Detection(
                        label=f"[{self.model_name}] {class_name}",
                        confidence=confidence,
                        x=x,
                        y=y,
                        width=w,
                        height=h,
                        color=self.color,
                    )
                )

            # Apply NMS only if the model doesn't do it internally.
            if detections and not getattr(self.model, "applies_nms", False):
                boxes = [(d.x, d.y, d.width, d.height) for d in detections]
                scores = [d.confidence for d in detections]
                kept = self._non_max_suppression(boxes, scores, self.nms_threshold)
                detections = [detections[i] for i in kept]

            return detections, inference_time

        except Exception as e:
            # Best-effort: one bad frame must not take the caller down.
            print(f"Error during detection: {str(e)}")
            return [], (time.perf_counter() - start_time) * 1000

    @staticmethod
    def _iou(box_a: Tuple, box_b: Tuple) -> float:
        """Return intersection-over-union of two ``(x, y, w, h)`` boxes."""
        ax, ay, aw, ah = box_a
        bx, by, bw, bh = box_b

        overlap_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
        overlap_h = max(0, min(ay + ah, by + bh) - max(ay, by))
        intersection = overlap_w * overlap_h

        union = aw * ah + bw * bh - intersection + 1e-9  # Avoid division by zero
        return intersection / union

    def _non_max_suppression(self, boxes: List[Tuple], scores: List[float], threshold: float) -> List[int]:
        """Apply Non-Maximum Suppression to remove overlapping boxes.

        Args:
            boxes: ``(x, y, w, h)`` tuples.
            scores: One score per box.
            threshold: A box is dropped when its IoU with an already kept,
                higher-scoring box is strictly greater than this value.

        Returns:
            Indices into ``boxes`` of the kept boxes, highest score first.
        """
        # Stable sort: equal scores keep their original relative order.
        remaining = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

        keep = []
        while remaining:
            best = remaining[0]
            keep.append(best)
            remaining = [
                idx for idx in remaining[1:]
                if self._iou(boxes[best], boxes[idx]) <= threshold
            ]
        return keep

    def close(self):
        """Close the model resources.

        Safe to call more than once, and on an instance whose model was
        never loaded.
        """
        model = getattr(self, "model", None)
        if model is not None and hasattr(model, "close"):
            model.close()
        self.model = None
|
|
|
|
|
|
class ModelManager:
    """Registry of loaded detectors, keyed by model name."""

    def __init__(self):
        # name -> detector instance
        self.models = {}

    def load(self, models_config: List[Dict]):
        """
        Load models from configuration

        Each entry is a dict with keys ``name``, ``path`` and optional
        ``size`` (``[width, height]``, default ``[640, 640]``).  Entries with
        a missing name/path or a non-existent path are skipped; entries whose
        detector fails to construct are reported and skipped.  Loading a name
        that is already registered closes the previous detector first.

        Args:
            models_config: List of model configurations
        """
        # Basic color palette for different models
        palette = [0x00FF00, 0xFF8000, 0x00A0FF, 0xFF00FF, 0x00FFFF, 0xFF0000, 0x80FF00]

        for i, model_config in enumerate(models_config):
            name = model_config.get("name")
            path = model_config.get("path")
            size = model_config.get("size", [640, 640])

            if not name or not path or not os.path.exists(path):
                print(f"Skipping model: {name} - Invalid configuration")
                continue

            try:
                # Colors cycle by position in the config list.
                detector = PythonModelDetector(
                    model_name=name,
                    model_path=path,
                    input_width=size[0],
                    input_height=size[1],
                    color=palette[i % len(palette)],
                )
            except Exception as e:
                print(f"Failed to load model {name}: {str(e)}")
                continue

            # Release the detector being replaced so it is not leaked.
            previous = self.models.get(name)
            if previous is not None:
                try:
                    previous.close()
                except Exception as e:
                    print(f"Failed to close previous model {name}: {e}")

            self.models[name] = detector
            print(f"Model loaded: {name} ({path})")

    def get(self, name: str) -> Optional["PythonModelDetector"]:
        """Get detector by name, or ``None`` when it is not loaded."""
        return self.models.get(name)

    def all(self) -> List["PythonModelDetector"]:
        """Get all detectors"""
        return list(self.models.values())

    def close(self):
        """Close all detectors and empty the registry.

        Best-effort: a detector whose ``close`` raises is reported and the
        remaining detectors are still closed.
        """
        for name, detector in self.models.items():
            try:
                detector.close()
            except Exception as e:
                # Not a bare except: KeyboardInterrupt/SystemExit must propagate.
                print(f"Failed to close model {name}: {e}")
        self.models.clear()