将检测任务迁移python
This commit is contained in:
1
python-inference-service/app/__init__.py
Normal file
1
python-inference-service/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Python Inference Service package
|
||||
311
python-inference-service/app/detector.py
Normal file
311
python-inference-service/app/detector.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
from app.models import Detection
|
||||
|
||||
|
||||
class PythonModelDetector:
|
||||
"""Object detector using native Python models"""
|
||||
|
||||
def __init__(self, model_name: str, model_path: str, input_width: int, input_height: int, color: int = 0x00FF00):
|
||||
"""
|
||||
Initialize detector with Python model
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
model_path: Path to the Python model file (.py)
|
||||
input_width: Input width for the model
|
||||
input_height: Input height for the model
|
||||
color: RGB color for detection boxes (default: green)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.input_width = input_width
|
||||
self.input_height = input_height
|
||||
self.color = color
|
||||
|
||||
# Convert color from RGB to BGR (OpenCV uses BGR)
|
||||
self.color_bgr = ((color & 0xFF) << 16) | (color & 0xFF00) | ((color >> 16) & 0xFF)
|
||||
|
||||
# Default confidence thresholds
|
||||
self.conf_threshold = 0.25
|
||||
self.nms_threshold = 0.45
|
||||
|
||||
# Load the Python model dynamically
|
||||
self._load_python_model(model_path)
|
||||
|
||||
# Load class names if available
|
||||
self.classes = []
|
||||
model_dir = os.path.dirname(model_path)
|
||||
classes_path = os.path.join(model_dir, "classes.txt")
|
||||
if os.path.exists(classes_path):
|
||||
with open(classes_path, 'r') as f:
|
||||
self.classes = [line.strip() for line in f.readlines() if line.strip()]
|
||||
|
||||
def _load_python_model(self, model_path: str):
|
||||
"""Load Python model dynamically"""
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
# Get model directory and file name
|
||||
model_dir = os.path.dirname(model_path)
|
||||
model_file = os.path.basename(model_path)
|
||||
model_name = os.path.splitext(model_file)[0]
|
||||
|
||||
# Add model directory to system path
|
||||
if model_dir not in sys.path:
|
||||
sys.path.append(model_dir)
|
||||
|
||||
# Import the model module
|
||||
spec = importlib.util.spec_from_file_location(model_name, model_path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Failed to load model specification: {model_path}")
|
||||
|
||||
model_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(model_module)
|
||||
|
||||
# Check if the module has the required interface
|
||||
if not hasattr(model_module, "Model"):
|
||||
raise AttributeError(f"Model module must define a 'Model' class: {model_path}")
|
||||
|
||||
# Create model instance
|
||||
self.model = model_module.Model()
|
||||
|
||||
# Check if model has the required methods
|
||||
if not hasattr(self.model, "predict"):
|
||||
raise AttributeError(f"Model must implement 'predict' method: {model_path}")
|
||||
|
||||
def preprocess(self, img: np.ndarray) -> np.ndarray:
|
||||
"""Preprocess image for model input"""
|
||||
# Ensure BGR image
|
||||
if len(img.shape) == 2: # Grayscale
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif img.shape[2] == 4: # BGRA
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
||||
|
||||
# Resize to model input size
|
||||
resized = cv2.resize(img, (self.input_width, self.input_height))
|
||||
|
||||
# Use model's preprocess method if available
|
||||
if hasattr(self.model, "preprocess"):
|
||||
return self.model.preprocess(resized)
|
||||
|
||||
# Default preprocessing: normalize to [0, 1]
|
||||
return resized / 255.0
|
||||
|
||||
def detect(self, img: np.ndarray) -> Tuple[List[Detection], float]:
|
||||
"""
|
||||
Detect objects in an image
|
||||
|
||||
Args:
|
||||
img: Input image in BGR format (OpenCV)
|
||||
|
||||
Returns:
|
||||
List of Detection objects and inference time in milliseconds
|
||||
"""
|
||||
if img is None or img.size == 0:
|
||||
return [], 0.0
|
||||
|
||||
# Original image dimensions
|
||||
img_height, img_width = img.shape[:2]
|
||||
|
||||
# Preprocess image
|
||||
processed_img = self.preprocess(img)
|
||||
|
||||
# Measure inference time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Run inference using model's predict method
|
||||
# Expected return format from model's predict:
|
||||
# List of dicts with keys: 'bbox', 'class_id', 'confidence'
|
||||
# bbox: (x, y, w, h) normalized [0-1]
|
||||
model_results = self.model.predict(processed_img)
|
||||
|
||||
# Calculate inference time in milliseconds
|
||||
inference_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Convert model results to Detection objects
|
||||
detections = []
|
||||
|
||||
for result in model_results:
|
||||
# Skip low confidence detections
|
||||
confidence = result.get('confidence', 0)
|
||||
if confidence < self.conf_threshold:
|
||||
continue
|
||||
|
||||
# Get bounding box (normalized coordinates)
|
||||
bbox = result.get('bbox', [0, 0, 0, 0])
|
||||
|
||||
# Denormalize bbox to image coordinates
|
||||
x = int(bbox[0] * img_width)
|
||||
y = int(bbox[1] * img_height)
|
||||
w = int(bbox[2] * img_width)
|
||||
h = int(bbox[3] * img_height)
|
||||
|
||||
# Skip invalid boxes
|
||||
if w <= 0 or h <= 0:
|
||||
continue
|
||||
|
||||
# Get class ID and name
|
||||
class_id = result.get('class_id', 0)
|
||||
class_name = f"cls{class_id}"
|
||||
if 0 <= class_id < len(self.classes):
|
||||
class_name = self.classes[class_id]
|
||||
|
||||
# Create Detection object
|
||||
label = f"[{self.model_name}] {class_name}"
|
||||
detection = Detection(
|
||||
label=label,
|
||||
confidence=confidence,
|
||||
x=x,
|
||||
y=y,
|
||||
width=w,
|
||||
height=h,
|
||||
color=self.color
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
# Apply NMS if model doesn't do it internally
|
||||
if hasattr(self.model, "applies_nms") and self.model.applies_nms:
|
||||
return detections, inference_time
|
||||
else:
|
||||
# Convert detections to boxes and scores
|
||||
boxes = [(d.x, d.y, d.width, d.height) for d in detections]
|
||||
scores = [d.confidence for d in detections]
|
||||
|
||||
if boxes:
|
||||
# Apply NMS
|
||||
indices = self._non_max_suppression(boxes, scores, self.nms_threshold)
|
||||
detections = [detections[i] for i in indices]
|
||||
|
||||
return detections, inference_time
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during detection: {str(e)}")
|
||||
return [], (time.time() - start_time) * 1000
|
||||
|
||||
def _non_max_suppression(self, boxes: List[Tuple], scores: List[float], threshold: float) -> List[int]:
|
||||
"""Apply Non-Maximum Suppression to remove overlapping boxes"""
|
||||
# Sort by score in descending order
|
||||
indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
|
||||
|
||||
keep = []
|
||||
while indices:
|
||||
# Get index with highest score
|
||||
current = indices.pop(0)
|
||||
keep.append(current)
|
||||
|
||||
# No more indices to process
|
||||
if not indices:
|
||||
break
|
||||
|
||||
# Get current box
|
||||
x1, y1, w1, h1 = boxes[current]
|
||||
x2_1 = x1 + w1
|
||||
y2_1 = y1 + h1
|
||||
area1 = w1 * h1
|
||||
|
||||
# Check remaining boxes
|
||||
i = 0
|
||||
while i < len(indices):
|
||||
# Get box to compare
|
||||
idx = indices[i]
|
||||
x2, y2, w2, h2 = boxes[idx]
|
||||
x2_2 = x2 + w2
|
||||
y2_2 = y2 + h2
|
||||
area2 = w2 * h2
|
||||
|
||||
# Calculate intersection
|
||||
xx1 = max(x1, x2)
|
||||
yy1 = max(y1, y2)
|
||||
xx2 = min(x2_1, x2_2)
|
||||
yy2 = min(y2_1, y2_2)
|
||||
|
||||
# Calculate intersection area
|
||||
w = max(0, xx2 - xx1)
|
||||
h = max(0, yy2 - yy1)
|
||||
intersection = w * h
|
||||
|
||||
# Calculate IoU
|
||||
union = area1 + area2 - intersection + 1e-9 # Avoid division by zero
|
||||
iou = intersection / union
|
||||
|
||||
# Remove box if IoU is above threshold
|
||||
if iou > threshold:
|
||||
indices.pop(i)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return keep
|
||||
|
||||
def close(self):
|
||||
"""Close the model resources"""
|
||||
if hasattr(self.model, "close"):
|
||||
self.model.close()
|
||||
self.model = None
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Model manager for detectors"""
|
||||
|
||||
def __init__(self):
|
||||
self.models = {}
|
||||
|
||||
def load(self, models_config: List[Dict]):
|
||||
"""
|
||||
Load models from configuration
|
||||
|
||||
Args:
|
||||
models_config: List of model configurations
|
||||
"""
|
||||
# Basic color palette for different models
|
||||
palette = [0x00FF00, 0xFF8000, 0x00A0FF, 0xFF00FF, 0x00FFFF, 0xFF0000, 0x80FF00]
|
||||
|
||||
for i, model_config in enumerate(models_config):
|
||||
name = model_config.get("name")
|
||||
path = model_config.get("path")
|
||||
size = model_config.get("size", [640, 640])
|
||||
|
||||
if not name or not path or not os.path.exists(path):
|
||||
print(f"Skipping model: {name} - Invalid configuration")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Use color from palette
|
||||
color = palette[i % len(palette)]
|
||||
|
||||
# Create detector for Python model
|
||||
detector = PythonModelDetector(
|
||||
model_name=name,
|
||||
model_path=path,
|
||||
input_width=size[0],
|
||||
input_height=size[1],
|
||||
color=color
|
||||
)
|
||||
|
||||
self.models[name] = detector
|
||||
print(f"Model loaded: {name} ({path})")
|
||||
except Exception as e:
|
||||
print(f"Failed to load model {name}: {str(e)}")
|
||||
|
||||
def get(self, name: str) -> Optional[PythonModelDetector]:
|
||||
"""Get detector by name"""
|
||||
return self.models.get(name)
|
||||
|
||||
def all(self) -> List[PythonModelDetector]:
|
||||
"""Get all detectors"""
|
||||
return list(self.models.values())
|
||||
|
||||
def close(self):
|
||||
"""Close all detectors"""
|
||||
for detector in self.models.values():
|
||||
try:
|
||||
detector.close()
|
||||
except:
|
||||
pass
|
||||
self.models.clear()
|
||||
164
python-inference-service/app/main.py
Normal file
164
python-inference-service/app/main.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import os
|
||||
import base64
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
from typing import Dict, List
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
from app.models import Detection, DetectionRequest, DetectionResponse, ModelInfo, ModelsResponse
|
||||
from app.detector import ModelManager
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="Python Model Inference Service",
|
||||
description="API for object detection using Python models",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize model manager
|
||||
model_manager = None
|
||||
|
||||
# Load models from configuration
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global model_manager
|
||||
model_manager = ModelManager()
|
||||
|
||||
# Look for models.json configuration file
|
||||
models_json_path = os.getenv("MODELS_JSON", os.path.join(os.path.dirname(__file__), "..", "models", "models.json"))
|
||||
|
||||
if os.path.exists(models_json_path):
|
||||
try:
|
||||
with open(models_json_path, "r") as f:
|
||||
models_config = json.load(f)
|
||||
model_manager.load(models_config)
|
||||
print(f"Loaded model configuration from {models_json_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to load models from {models_json_path}: {str(e)}")
|
||||
else:
|
||||
print(f"Models configuration not found: {models_json_path}")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
global model_manager
|
||||
if model_manager:
|
||||
model_manager.close()
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/models", response_model=ModelsResponse)
|
||||
async def get_models():
|
||||
"""Get available models"""
|
||||
global model_manager
|
||||
|
||||
if not model_manager:
|
||||
raise HTTPException(status_code=500, detail="Model manager not initialized")
|
||||
|
||||
detectors = model_manager.all()
|
||||
models = []
|
||||
|
||||
for detector in detectors:
|
||||
model_info = ModelInfo(
|
||||
name=detector.model_name,
|
||||
path=getattr(detector, 'model_path', ''),
|
||||
size=[detector.input_width, detector.input_height],
|
||||
backend="Python",
|
||||
loaded=True
|
||||
)
|
||||
models.append(model_info)
|
||||
|
||||
return ModelsResponse(models=models)
|
||||
|
||||
@app.post("/api/detect", response_model=DetectionResponse)
|
||||
async def detect(request: DetectionRequest):
|
||||
"""Detect objects in an image"""
|
||||
global model_manager
|
||||
|
||||
if not model_manager:
|
||||
raise HTTPException(status_code=500, detail="Model manager not initialized")
|
||||
|
||||
# Get detector for requested model
|
||||
detector = model_manager.get(request.model_name)
|
||||
if not detector:
|
||||
raise HTTPException(status_code=404, detail=f"Model not found: {request.model_name}")
|
||||
|
||||
# Decode base64 image
|
||||
try:
|
||||
# Remove data URL prefix if present
|
||||
if "base64," in request.image_data:
|
||||
image_data = request.image_data.split("base64,")[1]
|
||||
else:
|
||||
image_data = request.image_data
|
||||
|
||||
# Decode base64 image
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
nparr = np.frombuffer(image_bytes, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to decode image: {str(e)}")
|
||||
|
||||
# Run detection
|
||||
detections, inference_time = detector.detect(image)
|
||||
|
||||
return DetectionResponse(
|
||||
model_name=request.model_name,
|
||||
detections=detections,
|
||||
inference_time=inference_time
|
||||
)
|
||||
|
||||
@app.post("/api/detect/file", response_model=DetectionResponse)
|
||||
async def detect_file(
|
||||
model_name: str,
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
"""Detect objects in an uploaded image file"""
|
||||
global model_manager
|
||||
|
||||
if not model_manager:
|
||||
raise HTTPException(status_code=500, detail="Model manager not initialized")
|
||||
|
||||
# Get detector for requested model
|
||||
detector = model_manager.get(model_name)
|
||||
if not detector:
|
||||
raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
|
||||
|
||||
# Read uploaded file
|
||||
try:
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid image data")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}")
|
||||
|
||||
# Run detection
|
||||
detections, inference_time = detector.detect(image)
|
||||
|
||||
return DetectionResponse(
|
||||
model_name=model_name,
|
||||
detections=detections,
|
||||
inference_time=inference_time
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
40
python-inference-service/app/models.py
Normal file
40
python-inference-service/app/models.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class Detection(BaseModel):
|
||||
"""Object detection result"""
|
||||
label: str
|
||||
confidence: float
|
||||
x: int
|
||||
y: int
|
||||
width: int
|
||||
height: int
|
||||
color: int = 0x00FF00 # Default green color
|
||||
|
||||
|
||||
class DetectionRequest(BaseModel):
|
||||
"""Request for model inference on image data"""
|
||||
model_name: str
|
||||
image_data: str # Base64 encoded image
|
||||
|
||||
|
||||
class DetectionResponse(BaseModel):
|
||||
"""Response with detection results"""
|
||||
model_name: str
|
||||
detections: List[Detection]
|
||||
inference_time: float # Time in milliseconds
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Model information"""
|
||||
name: str
|
||||
path: str
|
||||
size: List[int] # [width, height]
|
||||
backend: str = "ONNX"
|
||||
loaded: bool = False
|
||||
|
||||
|
||||
class ModelsResponse(BaseModel):
|
||||
"""Response with available models"""
|
||||
models: List[ModelInfo]
|
||||
Reference in New Issue
Block a user