将检测任务迁移至 Python
This commit is contained in:
311
python-inference-service/app/detector.py
Normal file
311
python-inference-service/app/detector.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
from app.models import Detection
|
||||
|
||||
|
||||
class PythonModelDetector:
    """Object detector backed by a dynamically loaded native Python model.

    The model file must define a ``Model`` class with a ``predict`` method.
    Optional hooks honoured when present: ``preprocess`` (custom input
    transform) and ``applies_nms`` (truthy when the model already performs
    non-maximum suppression internally).
    """

    def __init__(self, model_name: str, model_path: str, input_width: int, input_height: int, color: int = 0x00FF00):
        """
        Initialize detector with Python model.

        Args:
            model_name: Name of the model.
            model_path: Path to the Python model file (.py).
            input_width: Input width for the model.
            input_height: Input height for the model.
            color: RGB color for detection boxes (default: green).

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            ImportError: If a module spec cannot be created for the file.
            AttributeError: If the module lacks the required interface.
        """
        self.model_name = model_name
        self.input_width = input_width
        self.input_height = input_height
        self.color = color

        # Swap the R and B channels of the packed 0xRRGGBB value for
        # OpenCV's BGR ordering.
        # NOTE(review): this attribute is not used by detect() below; kept
        # for callers that draw boxes with OpenCV directly.
        self.color_bgr = ((color & 0xFF) << 16) | (color & 0xFF00) | ((color >> 16) & 0xFF)

        # Default confidence / NMS thresholds.
        self.conf_threshold = 0.25
        self.nms_threshold = 0.45

        # Load the Python model dynamically (may raise, see docstring).
        self._load_python_model(model_path)

        # Load class names from an optional classes.txt next to the model.
        self.classes: List[str] = []
        classes_path = os.path.join(os.path.dirname(model_path), "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r', encoding='utf-8') as f:
                self.classes = [line.strip() for line in f if line.strip()]

    def _load_python_model(self, model_path: str):
        """Import the model file as a module and instantiate its ``Model`` class.

        Raises FileNotFoundError / ImportError / AttributeError on an
        invalid or non-conforming model file.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Derive the directory and module name from the file path.
        model_dir = os.path.dirname(model_path)
        model_file = os.path.basename(model_path)
        model_name = os.path.splitext(model_file)[0]

        # Make sibling files importable by the model module.
        if model_dir not in sys.path:
            sys.path.append(model_dir)

        # Import the model module from its file location.
        spec = importlib.util.spec_from_file_location(model_name, model_path)
        if spec is None:
            raise ImportError(f"Failed to load model specification: {model_path}")

        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)

        # Enforce the required interface: a Model class ...
        if not hasattr(model_module, "Model"):
            raise AttributeError(f"Model module must define a 'Model' class: {model_path}")

        self.model = model_module.Model()

        # ... with a predict method.
        if not hasattr(self.model, "predict"):
            raise AttributeError(f"Model must implement 'predict' method: {model_path}")

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Convert *img* to BGR, resize to the model input size, and normalize.

        Delegates to the model's own ``preprocess`` hook when it defines one;
        otherwise scales pixel values to [0, 1].
        """
        # Normalize channel layout to 3-channel BGR.
        if len(img.shape) == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # Resize to model input size.
        resized = cv2.resize(img, (self.input_width, self.input_height))

        # Prefer the model's own preprocessing when provided.
        if hasattr(self.model, "preprocess"):
            return self.model.preprocess(resized)

        # Default preprocessing: normalize to [0, 1].
        return resized / 255.0

    def detect(self, img: np.ndarray) -> Tuple[List["Detection"], float]:
        """
        Detect objects in an image.

        Args:
            img: Input image in BGR format (OpenCV).

        Returns:
            List of Detection objects and inference time in milliseconds.
            On any model error, returns an empty list (best-effort boundary).
        """
        if img is None or img.size == 0:
            return [], 0.0

        # Original image dimensions (for denormalizing bboxes).
        img_height, img_width = img.shape[:2]

        processed_img = self.preprocess(img)

        # perf_counter is monotonic; time.time() can jump with wall-clock
        # adjustments and is unsuitable for measuring intervals.
        start_time = time.perf_counter()

        try:
            # Expected return format from the model's predict:
            # list of dicts with keys 'bbox', 'class_id', 'confidence',
            # where bbox is (x, y, w, h) normalized to [0, 1].
            model_results = self.model.predict(processed_img)

            inference_time = (time.perf_counter() - start_time) * 1000

            detections = []
            for result in model_results:
                # Skip low-confidence detections.
                confidence = result.get('confidence', 0)
                if confidence < self.conf_threshold:
                    continue

                # Denormalize bbox to pixel coordinates.
                bbox = result.get('bbox', [0, 0, 0, 0])
                x = int(bbox[0] * img_width)
                y = int(bbox[1] * img_height)
                w = int(bbox[2] * img_width)
                h = int(bbox[3] * img_height)

                # Skip degenerate boxes.
                if w <= 0 or h <= 0:
                    continue

                # Resolve class name, falling back to a synthetic label.
                class_id = result.get('class_id', 0)
                class_name = f"cls{class_id}"
                if 0 <= class_id < len(self.classes):
                    class_name = self.classes[class_id]

                detections.append(Detection(
                    label=f"[{self.model_name}] {class_name}",
                    confidence=confidence,
                    x=x,
                    y=y,
                    width=w,
                    height=h,
                    color=self.color,
                ))

            # Apply NMS only when the model does not do it internally.
            if not (hasattr(self.model, "applies_nms") and self.model.applies_nms):
                boxes = [(d.x, d.y, d.width, d.height) for d in detections]
                scores = [d.confidence for d in detections]
                if boxes:
                    indices = self._non_max_suppression(boxes, scores, self.nms_threshold)
                    detections = [detections[i] for i in indices]

            return detections, inference_time

        except Exception as e:
            # Best-effort boundary: report and return an empty result rather
            # than crashing the inference service.
            print(f"Error during detection: {str(e)}")
            return [], (time.perf_counter() - start_time) * 1000

    def _non_max_suppression(self, boxes: List[Tuple], scores: List[float], threshold: float) -> List[int]:
        """Greedy IoU-based Non-Maximum Suppression.

        Args:
            boxes: (x, y, w, h) boxes in pixel coordinates.
            scores: Confidence score per box.
            threshold: IoU above which the lower-scored box is suppressed.

        Returns:
            Indices of kept boxes in descending-score order.
        """
        # Visit boxes in descending score order; flag suppressed ones instead
        # of repeatedly popping from a list (avoids O(n) shifts per removal).
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        suppressed = [False] * len(boxes)
        keep: List[int] = []

        for pos, current in enumerate(order):
            if suppressed[current]:
                continue
            keep.append(current)

            x1, y1, w1, h1 = boxes[current]
            right1 = x1 + w1
            bottom1 = y1 + h1
            area1 = w1 * h1

            # Suppress any lower-scored box overlapping the kept one too much.
            for other in order[pos + 1:]:
                if suppressed[other]:
                    continue
                x2, y2, w2, h2 = boxes[other]

                # Intersection rectangle (clamped to zero when disjoint).
                iw = max(0, min(right1, x2 + w2) - max(x1, x2))
                ih = max(0, min(bottom1, y2 + h2) - max(y1, y2))
                intersection = iw * ih

                # Epsilon avoids division by zero for degenerate boxes.
                union = area1 + w2 * h2 - intersection + 1e-9
                if intersection / union > threshold:
                    suppressed[other] = True

        return keep

    def close(self):
        """Release the underlying model's resources, if it supports closing."""
        if hasattr(self.model, "close"):
            self.model.close()
        self.model = None
|
||||
|
||||
|
||||
class ModelManager:
    """Registry that loads, looks up, and closes PythonModelDetector instances."""

    def __init__(self):
        # Maps model name -> PythonModelDetector.
        self.models = {}

    def load(self, models_config: List[Dict]):
        """
        Load models from configuration.

        Entries missing a name/path or pointing at a nonexistent file are
        skipped with a message; detector construction errors are reported
        per-model and do not abort the remaining loads.

        Args:
            models_config: List of model configurations, each a dict with
                keys "name", "path" and optional "size" ([width, height],
                default [640, 640]).
        """
        # Basic color palette cycled across models so each gets a distinct box color.
        palette = [0x00FF00, 0xFF8000, 0x00A0FF, 0xFF00FF, 0x00FFFF, 0xFF0000, 0x80FF00]

        for i, model_config in enumerate(models_config):
            name = model_config.get("name")
            path = model_config.get("path")
            size = model_config.get("size", [640, 640])

            if not name or not path or not os.path.exists(path):
                print(f"Skipping model: {name} - Invalid configuration")
                continue

            try:
                # Use color from palette (wraps around when exhausted).
                color = palette[i % len(palette)]

                detector = PythonModelDetector(
                    model_name=name,
                    model_path=path,
                    input_width=size[0],
                    input_height=size[1],
                    color=color
                )

                self.models[name] = detector
                print(f"Model loaded: {name} ({path})")
            except Exception as e:
                print(f"Failed to load model {name}: {str(e)}")

    def get(self, name: str) -> Optional["PythonModelDetector"]:
        """Return the detector registered under *name*, or None if absent."""
        return self.models.get(name)

    def all(self) -> List["PythonModelDetector"]:
        """Return all registered detectors."""
        return list(self.models.values())

    def close(self):
        """Close all detectors, best-effort, then clear the registry."""
        for detector in self.models.values():
            try:
                detector.close()
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt and
                # SystemExit still propagate during shutdown.
                pass
        self.models.clear()
|
||||
Reference in New Issue
Block a user