Migrate detection tasks to Python
47
python-inference-service/.dockerignore
Normal file
@@ -0,0 +1,47 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
venv/
ENV/
*.egg-info/
dist/
build/

# IDE
.vscode
.idea
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Git
.git
.gitignore

# CI/CD
.github
.gitlab-ci.yml

# Documentation
*.md
README*

# Scripts
*.bat
*.sh

# Large model files (will be mounted as volume)
models/*.pt
models/*.onnx
models/*.pth

# Logs
*.log
logs/
45
python-inference-service/Dockerfile
Normal file
@@ -0,0 +1,45 @@
# Use a PyTorch base image with CUDA support
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

# Set the working directory
WORKDIR /app

# Configure the pip mirror
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

# Install system dependencies
RUN apt-get update && apt-get install -y \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements.txt
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY app/ /app/app/

# Create the models directory
RUN mkdir -p /app/models

# Set environment variables
ENV PYTHONPATH=/app
ENV MODEL_DIR=/app/models
ENV PYTHONUNBUFFERED=1

# Expose the port (internal use only)
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Start the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]
252
python-inference-service/README.md
Normal file
@@ -0,0 +1,252 @@
# Python Inference Service

A YOLOv8 object-detection inference service built on FastAPI.

## Features

- YOLOv8 model inference
- RESTful API
- Accepts Base64-encoded images and file uploads
- Optional GPU acceleration
- Docker deployment support

## Model Requirements

This service uses **YOLOv8** (Ultralytics) for object detection.

### Preparing Model Files

1. **Model file**: name your trained YOLOv8 model `best.pt` and place it in the `models/` directory
2. **Class file**: (optional) create a `classes.txt` file with one class name per line
3. **Configuration file**: configure model parameters in `models.json`

### Directory Layout

```
python-inference-service/
├── app/
│   ├── __init__.py
│   ├── main.py          # FastAPI application
│   ├── detector.py      # Detector wrapper
│   └── models.py        # Data models
├── models/
│   ├── best.pt          # YOLOv8 model file (required)
│   ├── classes.txt      # Class names (optional)
│   ├── yolov8_model.py  # YOLOv8 model wrapper class
│   └── models.json      # Model configuration
├── requirements.txt
└── Dockerfile
```

## Installing Dependencies

```bash
pip install -r requirements.txt
```

Main dependencies:
- `ultralytics>=8.0.0` - YOLOv8 framework
- `fastapi` - Web framework
- `uvicorn` - ASGI server
- `opencv-python` - Image processing
- `torch` - PyTorch

## Configuring Models

Edit `models/models.json`:

```json
[
  {
    "name": "yolov8_detector",
    "path": "models/yolov8_model.py",
    "size": [640, 640],
    "comment": "YOLOv8 detection model"
  }
]
```

Parameters:
- `name`: model name (used in API calls)
- `path`: path to the model wrapper class
- `size`: input image size as [width, height]
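The entries above can be parsed and sanity-checked with plain standard-library code; a minimal sketch (`parse_models_config` is a hypothetical helper, not part of the service, and mirrors the defaults documented above):

```python
import json

def parse_models_config(raw: str) -> list:
    """Parse a models.json document and fill in the documented default size."""
    entries = []
    for item in json.loads(raw):
        entries.append({
            "name": item["name"],                  # required: used in API calls
            "path": item["path"],                  # required: model wrapper path
            "size": item.get("size", [640, 640]),  # default input size
        })
    return entries

raw = '[{"name": "yolov8_detector", "path": "models/yolov8_model.py", "size": [640, 640]}]'
config = parse_models_config(raw)
```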
## Starting the Service

### Local

```bash
# Start the service (default port 8000)
uvicorn app.main:app --host 0.0.0.0 --port 8000

# Or via the module form
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000
```

### Docker

```bash
# Build the image
docker build -t python-inference-service .

# Run the container
docker run -p 8000:8000 \
  -v $(pwd)/models:/app/models \
  python-inference-service
```

### Using a GPU

```bash
# Requires the NVIDIA Docker runtime
docker run --gpus all -p 8000:8000 \
  -v $(pwd)/models:/app/models \
  python-inference-service
```
## API Endpoints

Once the service is running, open http://localhost:8000/docs for the API documentation.

### 1. Health Check

```bash
GET /health
```

### 2. List Available Models

```bash
GET /api/models
```

### 3. Base64 Image Detection

```bash
POST /api/detect
Content-Type: application/json

{
  "model_name": "yolov8_detector",
  "image_data": "base64_encoded_image_string"
}
```

### 4. File Upload Detection

```bash
POST /api/detect/file
Content-Type: multipart/form-data

model_name: yolov8_detector
file: <image_file>
```
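A minimal Python client sketch for the Base64 endpoint (the payload fields match the request body shown above; `build_detect_payload` is a hypothetical helper, and the `requests.post` call is left commented out so the snippet stays self-contained):

```python
import base64

def build_detect_payload(image_bytes: bytes, model_name: str = "yolov8_detector") -> dict:
    """Build the JSON body expected by POST /api/detect."""
    return {
        "model_name": model_name,
        "image_data": base64.b64encode(image_bytes).decode("ascii"),
    }

payload = build_detect_payload(b"\x89PNG\r\n\x1a\n")  # placeholder bytes, not a real image
# import requests
# resp = requests.post("http://localhost:8000/api/detect", json=payload)
```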
## Response Format

```json
{
  "model_name": "yolov8_detector",
  "detections": [
    {
      "label": "[yolov8_detector] class_name",
      "confidence": 0.95,
      "x": 100,
      "y": 150,
      "width": 200,
      "height": 180,
      "color": 65280
    }
  ],
  "inference_time": 45.6
}
```
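Consuming the response is straightforward; for example, keeping only high-confidence boxes from the `detections` array (a sketch using the field names above; `filter_detections` is a hypothetical helper):

```python
def filter_detections(detections: list, min_confidence: float = 0.5) -> list:
    """Keep only detections at or above the given confidence."""
    return [d for d in detections if d["confidence"] >= min_confidence]

sample = [
    {"label": "[yolov8_detector] class_a", "confidence": 0.95},
    {"label": "[yolov8_detector] class_b", "confidence": 0.30},
]
kept = filter_detections(sample)
```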
## Custom Models

To use your own trained YOLOv8 model:

1. **Train the model** with Ultralytics YOLOv8:
   ```python
   from ultralytics import YOLO

   model = YOLO('yolov8n.yaml')
   model.train(data='your_data.yaml', epochs=100)
   ```

2. **Export the model**: training produces a `best.pt` file

3. **Prepare the class file**: create `classes.txt`
   ```
   class1
   class2
   class3
   ```

4. **Place the files**: put `best.pt` and `classes.txt` in the `models/` directory

5. **Update the configuration**: make sure `models.json` is correct

6. **Restart the service**
## Environment Variables

- `MODEL_DIR`: model directory path (default: `/app/models`)
- `MODELS_JSON`: model configuration file path (default: `models/models.json`)

## Performance Tuning

### GPU Acceleration

The service detects and uses a GPU automatically. With multiple GPUs, pin one explicitly:

```bash
CUDA_VISIBLE_DEVICES=0 uvicorn app.main:app --host 0.0.0.0 --port 8000
```

### Confidence Threshold

Adjust in `yolov8_model.py`:

```python
self.conf_threshold = 0.25  # Lower the threshold to detect more objects
```
## Troubleshooting

### Model Fails to Load

```
Error: best.pt not found
Fix: make sure the model file is in the models/ directory
```

### GPU Unavailable

```
Error: CUDA not available
Fix:
1. Check the NVIDIA driver
2. Check that the GPU build of PyTorch is installed
3. Check CUDA version compatibility
```

### Slow Inference

```
Fixes:
1. Use GPU acceleration
2. Use a smaller model (e.g. yolov8n.pt)
3. Reduce the input image size
```
## Developers

To modify or extend the service, see:
- `app/main.py` - API route definitions
- `app/detector.py` - Detector base class
- `models/yolov8_model.py` - YOLOv8 model wrapper class

## License

[Fill in according to the project's actual license]
1
python-inference-service/app/__init__.py
Normal file
@@ -0,0 +1 @@
# Python Inference Service package
311
python-inference-service/app/detector.py
Normal file
@@ -0,0 +1,311 @@
import os
import cv2
import numpy as np
import time
from typing import List, Dict, Tuple, Optional
import importlib.util
import sys

from app.models import Detection


class PythonModelDetector:
    """Object detector using native Python models"""

    def __init__(self, model_name: str, model_path: str, input_width: int, input_height: int, color: int = 0x00FF00):
        """
        Initialize detector with Python model

        Args:
            model_name: Name of the model
            model_path: Path to the Python model file (.py)
            input_width: Input width for the model
            input_height: Input height for the model
            color: RGB color for detection boxes (default: green)
        """
        self.model_name = model_name
        self.input_width = input_width
        self.input_height = input_height
        self.color = color

        # Convert color from RGB to BGR (OpenCV uses BGR)
        self.color_bgr = ((color & 0xFF) << 16) | (color & 0xFF00) | ((color >> 16) & 0xFF)

        # Default confidence thresholds
        self.conf_threshold = 0.25
        self.nms_threshold = 0.45

        # Load the Python model dynamically
        self._load_python_model(model_path)

        # Load class names if available
        self.classes = []
        model_dir = os.path.dirname(model_path)
        classes_path = os.path.join(model_dir, "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r') as f:
                self.classes = [line.strip() for line in f.readlines() if line.strip()]
    def _load_python_model(self, model_path: str):
        """Load Python model dynamically"""
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Get model directory and file name
        model_dir = os.path.dirname(model_path)
        model_file = os.path.basename(model_path)
        model_name = os.path.splitext(model_file)[0]

        # Add model directory to system path
        if model_dir not in sys.path:
            sys.path.append(model_dir)

        # Import the model module
        spec = importlib.util.spec_from_file_location(model_name, model_path)
        if spec is None:
            raise ImportError(f"Failed to load model specification: {model_path}")

        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)

        # Check if the module has the required interface
        if not hasattr(model_module, "Model"):
            raise AttributeError(f"Model module must define a 'Model' class: {model_path}")

        # Create model instance
        self.model = model_module.Model()

        # Check if model has the required methods
        if not hasattr(self.model, "predict"):
            raise AttributeError(f"Model must implement 'predict' method: {model_path}")
    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Preprocess image for model input"""
        # Ensure BGR image
        if len(img.shape) == 2:  # Grayscale
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # Resize to model input size
        resized = cv2.resize(img, (self.input_width, self.input_height))

        # Use model's preprocess method if available
        if hasattr(self.model, "preprocess"):
            return self.model.preprocess(resized)

        # Default preprocessing: normalize to [0, 1]
        return resized / 255.0
    def detect(self, img: np.ndarray) -> Tuple[List[Detection], float]:
        """
        Detect objects in an image

        Args:
            img: Input image in BGR format (OpenCV)

        Returns:
            List of Detection objects and inference time in milliseconds
        """
        if img is None or img.size == 0:
            return [], 0.0

        # Original image dimensions
        img_height, img_width = img.shape[:2]

        # Preprocess image
        processed_img = self.preprocess(img)

        # Measure inference time
        start_time = time.time()

        try:
            # Run inference using model's predict method.
            # Expected return format: a list of dicts with keys
            # 'bbox', 'class_id', 'confidence', where bbox is
            # (x, y, w, h) normalized to [0, 1].
            model_results = self.model.predict(processed_img)

            # Calculate inference time in milliseconds
            inference_time = (time.time() - start_time) * 1000

            # Convert model results to Detection objects
            detections = []

            for result in model_results:
                # Skip low confidence detections
                confidence = result.get('confidence', 0)
                if confidence < self.conf_threshold:
                    continue

                # Get bounding box (normalized coordinates)
                bbox = result.get('bbox', [0, 0, 0, 0])

                # Denormalize bbox to image coordinates
                x = int(bbox[0] * img_width)
                y = int(bbox[1] * img_height)
                w = int(bbox[2] * img_width)
                h = int(bbox[3] * img_height)

                # Skip invalid boxes
                if w <= 0 or h <= 0:
                    continue

                # Get class ID and name
                class_id = result.get('class_id', 0)
                class_name = f"cls{class_id}"
                if 0 <= class_id < len(self.classes):
                    class_name = self.classes[class_id]

                # Create Detection object
                label = f"[{self.model_name}] {class_name}"
                detection = Detection(
                    label=label,
                    confidence=confidence,
                    x=x,
                    y=y,
                    width=w,
                    height=h,
                    color=self.color
                )
                detections.append(detection)

            # Skip NMS when the model already applies it internally
            if hasattr(self.model, "applies_nms") and self.model.applies_nms:
                return detections, inference_time
            else:
                # Convert detections to boxes and scores
                boxes = [(d.x, d.y, d.width, d.height) for d in detections]
                scores = [d.confidence for d in detections]

                if boxes:
                    # Apply NMS
                    indices = self._non_max_suppression(boxes, scores, self.nms_threshold)
                    detections = [detections[i] for i in indices]

                return detections, inference_time

        except Exception as e:
            print(f"Error during detection: {str(e)}")
            return [], (time.time() - start_time) * 1000
    def _non_max_suppression(self, boxes: List[Tuple], scores: List[float], threshold: float) -> List[int]:
        """Apply Non-Maximum Suppression to remove overlapping boxes"""
        # Sort by score in descending order
        indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

        keep = []
        while indices:
            # Get index with highest score
            current = indices.pop(0)
            keep.append(current)

            # No more indices to process
            if not indices:
                break

            # Get current box
            x1, y1, w1, h1 = boxes[current]
            x2_1 = x1 + w1
            y2_1 = y1 + h1
            area1 = w1 * h1

            # Check remaining boxes
            i = 0
            while i < len(indices):
                # Get box to compare
                idx = indices[i]
                x2, y2, w2, h2 = boxes[idx]
                x2_2 = x2 + w2
                y2_2 = y2 + h2
                area2 = w2 * h2

                # Calculate intersection
                xx1 = max(x1, x2)
                yy1 = max(y1, y2)
                xx2 = min(x2_1, x2_2)
                yy2 = min(y2_1, y2_2)

                # Calculate intersection area
                w = max(0, xx2 - xx1)
                h = max(0, yy2 - yy1)
                intersection = w * h

                # Calculate IoU
                union = area1 + area2 - intersection + 1e-9  # Avoid division by zero
                iou = intersection / union

                # Remove box if IoU is above threshold
                if iou > threshold:
                    indices.pop(i)
                else:
                    i += 1

        return keep

    def close(self):
        """Close the model resources"""
        if hasattr(self.model, "close"):
            self.model.close()
        self.model = None
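# Standalone sketch of the IoU arithmetic used by _non_max_suppression
# (iou_xywh is a hypothetical helper, not part of the original file).
# Worked example: two 10x10 boxes offset by 5 px intersect in a 5x5 = 25
# region; the union is 100 + 100 - 25 = 175, so IoU ~ 0.143, below the
# default nms_threshold of 0.45, and both boxes survive NMS.
def iou_xywh(box_a, box_b):
    """IoU of two (x, y, w, h) boxes, mirroring _non_max_suppression."""
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    xx1, yy1 = max(ax, bx), max(ay, by)
    xx2, yy2 = min(ax + aw, bx + bw), min(ay + ah, by + bh)
    inter = max(0, xx2 - xx1) * max(0, yy2 - yy1)
    union = aw * ah + bw * bh - inter + 1e-9
    return inter / union

overlap = iou_xywh((0, 0, 10, 10), (5, 5, 10, 10))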
class ModelManager:
    """Model manager for detectors"""

    def __init__(self):
        self.models = {}

    def load(self, models_config: List[Dict]):
        """
        Load models from configuration

        Args:
            models_config: List of model configurations
        """
        # Basic color palette for different models
        palette = [0x00FF00, 0xFF8000, 0x00A0FF, 0xFF00FF, 0x00FFFF, 0xFF0000, 0x80FF00]

        for i, model_config in enumerate(models_config):
            name = model_config.get("name")
            path = model_config.get("path")
            size = model_config.get("size", [640, 640])

            if not name or not path or not os.path.exists(path):
                print(f"Skipping model: {name} - Invalid configuration")
                continue

            try:
                # Use color from palette
                color = palette[i % len(palette)]

                # Create detector for Python model
                detector = PythonModelDetector(
                    model_name=name,
                    model_path=path,
                    input_width=size[0],
                    input_height=size[1],
                    color=color
                )

                self.models[name] = detector
                print(f"Model loaded: {name} ({path})")
            except Exception as e:
                print(f"Failed to load model {name}: {str(e)}")

    def get(self, name: str) -> Optional[PythonModelDetector]:
        """Get detector by name"""
        return self.models.get(name)

    def all(self) -> List[PythonModelDetector]:
        """Get all detectors"""
        return list(self.models.values())

    def close(self):
        """Close all detectors"""
        for detector in self.models.values():
            try:
                detector.close()
            except Exception:
                pass
        self.models.clear()
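# Standalone sketch of the dynamic-loading dance in _load_python_model:
# write a dummy plugin to a temporary file and load it the same way.
# (dummy_model and its trivial Model class are made up for illustration
# and are not part of the original commit.)
import importlib.util
import os
import tempfile
import textwrap

plugin_src = textwrap.dedent("""
    class Model:
        def predict(self, img):
            return []
""")

with tempfile.TemporaryDirectory() as d:
    plugin_path = os.path.join(d, "dummy_model.py")
    with open(plugin_path, "w") as f:
        f.write(plugin_src)

    # Same spec/module sequence used by _load_python_model
    spec = importlib.util.spec_from_file_location("dummy_model", plugin_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    model = module.Model()
    loaded_ok = hasattr(model, "predict") and model.predict(None) == []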
164
python-inference-service/app/main.py
Normal file
@@ -0,0 +1,164 @@
import os
import base64
import cv2
import json
import numpy as np
from typing import Dict, List
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from app.models import Detection, DetectionRequest, DetectionResponse, ModelInfo, ModelsResponse
from app.detector import ModelManager

# Initialize FastAPI app
app = FastAPI(
    title="Python Model Inference Service",
    description="API for object detection using Python models",
    version="1.0.0"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model manager
model_manager = None

# Load models from configuration
@app.on_event("startup")
async def startup_event():
    global model_manager
    model_manager = ModelManager()

    # Look for models.json configuration file
    models_json_path = os.getenv("MODELS_JSON", os.path.join(os.path.dirname(__file__), "..", "models", "models.json"))

    if os.path.exists(models_json_path):
        try:
            with open(models_json_path, "r") as f:
                models_config = json.load(f)
            model_manager.load(models_config)
            print(f"Loaded model configuration from {models_json_path}")
        except Exception as e:
            print(f"Failed to load models from {models_json_path}: {str(e)}")
    else:
        print(f"Models configuration not found: {models_json_path}")

@app.on_event("shutdown")
async def shutdown_event():
    global model_manager
    if model_manager:
        model_manager.close()
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "ok"}

@app.get("/api/models", response_model=ModelsResponse)
async def get_models():
    """Get available models"""
    global model_manager

    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    detectors = model_manager.all()
    models = []

    for detector in detectors:
        model_info = ModelInfo(
            name=detector.model_name,
            path=getattr(detector, 'model_path', ''),
            size=[detector.input_width, detector.input_height],
            backend="Python",
            loaded=True
        )
        models.append(model_info)

    return ModelsResponse(models=models)
@app.post("/api/detect", response_model=DetectionResponse)
async def detect(request: DetectionRequest):
    """Detect objects in an image"""
    global model_manager

    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    # Get detector for requested model
    detector = model_manager.get(request.model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {request.model_name}")

    # Decode base64 image
    try:
        # Remove data URL prefix if present
        if "base64," in request.image_data:
            image_data = request.image_data.split("base64,")[1]
        else:
            image_data = request.image_data

        # Decode base64 image
        image_bytes = base64.b64decode(image_data)
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image data")
    except HTTPException:
        # Re-raise instead of letting the generic handler rewrap the 400 above
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to decode image: {str(e)}")

    # Run detection
    detections, inference_time = detector.detect(image)

    return DetectionResponse(
        model_name=request.model_name,
        detections=detections,
        inference_time=inference_time
    )
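# Standalone sketch of the data-URL prefix handling in /api/detect
# (strip_data_url is a hypothetical helper mirroring the inline logic;
# it is not part of the original commit).
def strip_data_url(image_data: str) -> str:
    """Drop a leading data-URL prefix, matching the check in /api/detect."""
    if "base64," in image_data:
        return image_data.split("base64,")[1]
    return image_data

with_prefix = strip_data_url("data:image/png;base64,QUJD")
without_prefix = strip_data_url("QUJD")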
@app.post("/api/detect/file", response_model=DetectionResponse)
async def detect_file(
    model_name: str,
    file: UploadFile = File(...)
):
    """Detect objects in an uploaded image file"""
    global model_manager

    if not model_manager:
        raise HTTPException(status_code=500, detail="Model manager not initialized")

    # Get detector for requested model
    detector = model_manager.get(model_name)
    if not detector:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    # Read uploaded file
    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image data")
    except HTTPException:
        # Re-raise instead of letting the generic handler rewrap the 400 above
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to process image: {str(e)}")

    # Run detection
    detections, inference_time = detector.detect(image)

    return DetectionResponse(
        model_name=model_name,
        detections=detections,
        inference_time=inference_time
    )

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
40
python-inference-service/app/models.py
Normal file
@@ -0,0 +1,40 @@
from pydantic import BaseModel
from typing import List, Optional


class Detection(BaseModel):
    """Object detection result"""
    label: str
    confidence: float
    x: int
    y: int
    width: int
    height: int
    color: int = 0x00FF00  # Default green color


class DetectionRequest(BaseModel):
    """Request for model inference on image data"""
    model_name: str
    image_data: str  # Base64 encoded image


class DetectionResponse(BaseModel):
    """Response with detection results"""
    model_name: str
    detections: List[Detection]
    inference_time: float  # Time in milliseconds


class ModelInfo(BaseModel):
    """Model information"""
    name: str
    path: str
    size: List[int]  # [width, height]
    backend: str = "ONNX"
    loaded: bool = False


class ModelsResponse(BaseModel):
    """Response with available models"""
    models: List[ModelInfo]
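# Standalone sketch of how the Detection.color field packs an RGB triple
# into one 0xRRGGBB integer: 0x00FF00 == 65280 is the green seen in the
# README's example response. (rgb_to_int / int_to_rgb are hypothetical
# helpers, not part of the original commit.)
def rgb_to_int(r: int, g: int, b: int) -> int:
    """Pack an RGB triple into a single 0xRRGGBB integer."""
    return (r << 16) | (g << 8) | b

def int_to_rgb(color: int):
    """Unpack a 0xRRGGBB integer back into an (r, g, b) triple."""
    return ((color >> 16) & 0xFF, (color >> 8) & 0xFF, color & 0xFF)

green = rgb_to_int(0, 255, 0)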
BIN
python-inference-service/models/best.pt
Normal file
Binary file not shown.
1
python-inference-service/models/classes.txt
Normal file
@@ -0,0 +1 @@
垃圾
207
python-inference-service/models/garbage_model.py
Normal file
@@ -0,0 +1,207 @@
import os
import numpy as np
import cv2
from typing import List, Dict, Any
import torch

class Model:
    """
    Garbage-detection model - loads a PyTorch model file directly
    """

    def __init__(self):
        """Initialize the model"""
        # Directory containing this file
        model_dir = os.path.dirname(os.path.abspath(__file__))
        # Path to the model file
        model_path = os.path.join(model_dir, "best.pt")

        print(f"Loading garbage-detection model: {model_path}")

        # Load the PyTorch model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load via YOLOv5 or a generic fallback
        try:
            # Try loading through YOLOv5
            import sys
            sys.path.append(os.path.dirname(model_dir))  # Add the parent directory to the path

            try:
                # Option 1: the yolov5 package, if installed
                import yolov5
                self.model = yolov5.load(model_path, device=self.device)
                self.yolov5_api = True
                print("Loaded model via the yolov5 package")
            except (ImportError, ModuleNotFoundError):
                # Option 2: bundled YOLOv5 code
                from models.yolov5_utils import attempt_load
                self.model = attempt_load(model_path, device=self.device)
                self.yolov5_api = False
                print("Loaded model via bundled YOLOv5 utilities")

        except Exception as e:
            # Option 3: generic PyTorch loading
            print(f"YOLOv5 loading failed: {e}")
            print("Falling back to generic PyTorch loading")
            self.model = torch.load(model_path, map_location=self.device)
            if isinstance(self.model, dict) and 'model' in self.model:
                self.model = self.model['model']
            self.yolov5_api = False

        # Put the model into evaluation mode
        if isinstance(self.model, torch.jit.ScriptModule):
            self.model.eval()
        elif hasattr(self.model, 'eval'):
            self.model.eval()

        # Load class names
        self.classes = []
        classes_path = os.path.join(model_dir, "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r', encoding='utf-8') as f:
                self.classes = [line.strip() for line in f.readlines() if line.strip()]
            print(f"Loaded {len(self.classes)} classes")
        else:
            # Fall back to class names bundled with the model
            if hasattr(self.model, 'names') and self.model.names:
                self.classes = self.model.names
                print(f"Using the model's own classes ({len(self.classes)} total)")
            else:
                print("No class file found; numeric indices will be used as class names")

        # Detection parameters
        self.conf_threshold = 0.25  # Confidence threshold
        self.img_size = 640  # Default input image size

        print("Garbage-detection model loaded")
    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """Preprocess the image"""
        # The YOLOv5 API handles preprocessing itself
        if hasattr(self, 'yolov5_api') and self.yolov5_api:
            return image

        # Default preprocessing: resize and normalize
        img = cv2.resize(image, (self.img_size, self.img_size))

        # BGR to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Normalize [0, 255] -> [0, 1]
        img = img / 255.0

        # HWC -> CHW (height, width, channels -> channels, height, width)
        img = img.transpose(2, 0, 1)

        # Convert to a torch tensor
        img = torch.from_numpy(img).float()

        # Add a batch dimension
        img = img.unsqueeze(0)

        # Move to the target device
        img = img.to(self.device)

        return img
    def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Run inference"""
        original_height, original_width = image.shape[:2]

        try:
            # Path 1: the YOLOv5 API
            if hasattr(self, 'yolov5_api') and self.yolov5_api:
                # The YOLOv5 API consumes raw images directly
                results = self.model(image)

                # Extract detection results
                predictions = results.pred[0]  # Predictions for the first batch item

                detections = []
                for *xyxy, conf, cls_id in predictions.cpu().numpy():
                    x1, y1, x2, y2 = xyxy

                    # Convert to normalized (x, y, w, h)
                    x = x1 / original_width
                    y = y1 / original_height
                    w = (x2 - x1) / original_width
                    h = (y2 - y1) / original_height

                    # Integer class ID
                    cls_id = int(cls_id)

                    # Look up the class name
                    class_name = f"cls{cls_id}"
                    if 0 <= cls_id < len(self.classes):
                        class_name = self.classes[cls_id]

                    # Keep detections above the confidence threshold
                    if conf >= self.conf_threshold:
                        detections.append({
                            'bbox': (x, y, w, h),
                            'class_id': cls_id,
                            'confidence': float(conf)
                        })

                return detections

            else:
                # Path 2: generic PyTorch model
                # Preprocess the image
                img = self.preprocess(image)

                # Inference
                with torch.no_grad():
                    outputs = self.model(img)

                # Postprocess (adjust to the model's actual output format)
                detections = []

                # Assume a YOLO-style output: [batch_idx, x1, y1, x2, y2, conf, cls_id]
                if isinstance(outputs, torch.Tensor) and outputs.dim() == 2 and outputs.size(1) >= 6:
                    for *xyxy, conf, cls_id in outputs.cpu().numpy():
                        if conf >= self.conf_threshold:
                            x1, y1, x2, y2 = xyxy

                            # Convert to normalized (x, y, w, h)
                            x = x1 / original_width
                            y = y1 / original_height
                            w = (x2 - x1) / original_width
                            h = (y2 - y1) / original_height

                            # Integer class ID
                            cls_id = int(cls_id)

                            detections.append({
                                'bbox': (x, y, w, h),
                                'class_id': cls_id,
                                'confidence': float(conf)
                            })
                # Other possible output formats
                else:
                    # Adapt this branch to the model's actual output format
                    print("Warning: unrecognized model output format; check the model type")

                return detections

        except Exception as e:
            print(f"Error during inference: {str(e)}")
            # Return an empty result on failure
            return []
    @property
    def applies_nms(self) -> bool:
        """Whether the model applies NMS internally"""
        # YOLOv5 applies NMS automatically
        return True

    def close(self):
        """Release resources"""
        if hasattr(self, 'model'):
            # Delete the model to free GPU memory
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Garbage-detection model closed")
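# Standalone sketch of the corner-to-normalized conversion used in both
# branches of predict() above (xyxy_to_norm_xywh is a hypothetical helper,
# not part of the original commit).
def xyxy_to_norm_xywh(x1, y1, x2, y2, img_w, img_h):
    """Convert pixel corner coordinates to a normalized (x, y, w, h) box."""
    return (x1 / img_w, y1 / img_h, (x2 - x1) / img_w, (y2 - y1) / img_h)

norm = xyxy_to_norm_xywh(0, 0, 320, 240, 640, 480)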
8
python-inference-service/models/models.json
Normal file
@@ -0,0 +1,8 @@
[
  {
    "name": "yolov8_detector",
    "path": "models/yolov8_model.py",
    "size": [640, 640],
    "comment": "YOLOv8 detection model; make sure the trained best.pt file is in the models directory"
  }
]
126
python-inference-service/models/smoke_detector.py
Normal file
@@ -0,0 +1,126 @@
import numpy as np
import cv2
from typing import List, Dict, Any, Tuple

class Model:
    """
    Smoke detection model implementation

    This is a simple example that could be replaced with an actual
    TensorFlow, PyTorch, or other ML framework implementation.
    """

    def __init__(self):
        """Initialize smoke detection model"""
        # In a real implementation, you would load your model here
        print("Smoke detection model initialized")

        # Define smoke class IDs
        self.smoke_classes = {
            0: "smoke",
            1: "fire"
        }

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """Preprocess image for model input"""
        # Convert BGR to grayscale for smoke detection
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Convert back to 3 channels to match model expected input shape
        gray_3ch = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

        # In a real implementation, you would do normalization, etc.
        return gray_3ch
def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run smoke detection on the image
|
||||
|
||||
This is a simplified example that uses basic image processing
|
||||
In a real implementation, you would use your ML model
|
||||
"""
|
||||
# Convert to grayscale for processing
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Apply Gaussian blur to reduce noise
|
||||
blurred = cv2.GaussianBlur(gray, (15, 15), 0)
|
||||
|
||||
# Simple thresholding to find potential smoke regions
|
||||
# In a real implementation, you'd use a trained model
|
||||
_, thresh = cv2.threshold(blurred, 100, 255, cv2.THRESH_BINARY)
|
||||
|
||||
# Find contours in the thresholded image
|
||||
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Process contours to find potential smoke regions
|
||||
detections = []
|
||||
height, width = image.shape[:2]
|
||||
|
||||
for contour in contours:
|
||||
# Get bounding box
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
# Filter small regions
|
||||
if w > width * 0.05 and h > height * 0.05:
|
||||
# Calculate area ratio
|
||||
area = cv2.contourArea(contour)
|
||||
rect_area = w * h
|
||||
fill_ratio = area / rect_area if rect_area > 0 else 0
|
||||
|
||||
# Smoke tends to have irregular shapes
|
||||
# This is just for demonstration purposes
|
||||
if fill_ratio > 0.2 and fill_ratio < 0.8:
|
||||
# Normalize coordinates
|
||||
x_norm = x / width
|
||||
y_norm = y / height
|
||||
w_norm = w / width
|
||||
h_norm = h / height
|
||||
|
||||
# Determine if it's smoke or fire (just a simple heuristic for demo)
|
||||
# In a real model, this would be determined by the model prediction
|
||||
class_id = 0 # Default to smoke
|
||||
|
||||
# Check if the region has high red values (fire)
|
||||
roi = image[y:y+h, x:x+w]
|
||||
if roi.size > 0: # Make sure ROI is not empty
|
||||
avg_color = np.mean(roi, axis=(0, 1))
|
||||
if avg_color[2] > 150 and avg_color[2] > avg_color[0] * 1.5: # High red, indicating fire
|
||||
class_id = 1 # Fire
|
||||
|
||||
# Calculate confidence based on fill ratio
|
||||
# This is just for demonstration
|
||||
confidence = 0.5 + fill_ratio * 0.3
|
||||
|
||||
# Add to detections
|
||||
detections.append({
|
||||
'bbox': (x_norm, y_norm, w_norm, h_norm),
|
||||
'class_id': class_id,
|
||||
'confidence': confidence
|
||||
})
|
||||
|
||||
# For demo purposes, if no smoke detected by algorithm,
|
||||
# add a small chance of random detection
|
||||
if not detections and np.random.random() < 0.1: # 10% chance
|
||||
# Random smoke detection
|
||||
x = np.random.random() * 0.7
|
||||
y = np.random.random() * 0.7
|
||||
w = 0.1 + np.random.random() * 0.2
|
||||
h = 0.1 + np.random.random() * 0.2
|
||||
confidence = 0.5 + np.random.random() * 0.3
|
||||
|
||||
detections.append({
|
||||
'bbox': (x, y, w, h),
|
||||
'class_id': 0, # Smoke
|
||||
'confidence': confidence
|
||||
})
|
||||
|
||||
return detections
|
||||
|
||||
@property
|
||||
def applies_nms(self) -> bool:
|
||||
"""Model does not apply NMS internally"""
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""Release resources"""
|
||||
# In a real implementation, you would release model resources here
|
||||
pass
|
||||
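The fill-ratio scoring and coordinate normalization inside `predict()` are pure arithmetic, so they can be isolated and checked without OpenCV. A sketch of the same logic (the function names are mine, not from the file above):

```python
def normalize_bbox(x, y, w, h, width, height):
    """Convert a pixel-space bounding box to normalized (x, y, w, h)."""
    return (x / width, y / height, w / width, h / height)

def heuristic_confidence(contour_area, rect_area):
    """Mirror the demo scoring: accept irregular shapes, score by fill ratio.

    Returns None when the region is rejected (too solid or too sparse
    to look like smoke under the 0.2 < fill_ratio < 0.8 rule).
    """
    fill_ratio = contour_area / rect_area if rect_area > 0 else 0
    if 0.2 < fill_ratio < 0.8:
        return 0.5 + fill_ratio * 0.3
    return None

print(normalize_bbox(64, 48, 128, 96, 640, 480))  # (0.1, 0.1, 0.2, 0.2)
print(heuristic_confidence(50, 100))              # ≈ 0.65
```

Note that the confidence is bounded in roughly (0.56, 0.74): the accepted fill ratios are (0.2, 0.8), so this demo detector can never report near-certain detections.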
56
python-inference-service/models/yolov5_utils.py
Normal file
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn


def attempt_load(weights, device=''):
    """Attempt to load a YOLOv5 model"""
    # Load the checkpoint (fall back to CPU when no device is given)
    model = torch.load(weights, map_location=device if device else 'cpu')

    # Determine the checkpoint format
    if isinstance(model, dict):
        if 'model' in model:  # checkpoint wraps the model object
            model = model['model']
        elif 'state_dict' in model:  # state_dict format
            model = model['state_dict']

    # If it is still a state_dict, a model architecture must be created
    if isinstance(model, dict):
        print("Warning: loaded a weights dict; trying to create a default model structure")
        model_arch = YOLOv5()  # defined below in this module
        model_arch.load_state_dict(model)
        model = model_arch

    # Switch to evaluation mode
    if isinstance(model, nn.Module):
        model.eval()

    # Check for class-name information
    if not hasattr(model, 'names') or not model.names:
        print("Model has no class information; falling back to a default class")
        # Set a generic class
        model.names = ['object']

    return model


class YOLOv5(nn.Module):
    """Simplified YOLOv5 model structure, used for loading weights"""
    def __init__(self):
        super(YOLOv5, self).__init__()
        self.names = []  # class names
        # The real network structure should be added here.
        # For simplicity this is only a placeholder; in actual use
        # you should implement the full network architecture.

    def forward(self, x):
        # The actual forward-pass logic would go here.
        # This is only a placeholder.
        raise NotImplementedError("This is a placeholder model; use a full YOLOv5 implementation")

    def load_state_dict(self, state_dict):
        print("Attempting to load model weights")
        # The actual weight-loading logic would go here.
        # This is only a placeholder.
        return self
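The checkpoint-format dispatch at the top of `attempt_load` can be exercised without torch, since it is just dict inspection. A sketch of the same unwrapping rules on plain dicts (`unwrap_checkpoint` is an illustrative helper, not a function from the file above):

```python
def unwrap_checkpoint(ckpt):
    """Mirror attempt_load's unwrapping: a checkpoint may wrap the model
    object under 'model', ship only a 'state_dict', or be the model itself."""
    if isinstance(ckpt, dict):
        if 'model' in ckpt:
            return ckpt['model']
        if 'state_dict' in ckpt:
            return ckpt['state_dict']
    return ckpt

print(unwrap_checkpoint({'model': 'net'}))          # net
print(unwrap_checkpoint({'state_dict': {'w': 1}}))  # {'w': 1}
print(unwrap_checkpoint('bare-model'))              # bare-model
```

After unwrapping, a result that is still a `dict` is necessarily a state_dict of raw tensors, which is why `attempt_load` only builds the placeholder `YOLOv5` architecture in that branch.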
135
python-inference-service/models/yolov8_model.py
Normal file
@@ -0,0 +1,135 @@
import os
import numpy as np
from typing import List, Dict, Any
import torch


class Model:
    """
    YOLOv8 model wrapper class - uses Ultralytics YOLO
    """

    def __init__(self):
        """Initialize the YOLOv8 model"""
        # Directory containing this file
        model_dir = os.path.dirname(os.path.abspath(__file__))
        # Path to the model weights
        model_path = os.path.join(model_dir, "best.pt")

        print(f"Loading YOLOv8 model: {model_path}")

        # Check the device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load the model with Ultralytics YOLO
        try:
            from ultralytics import YOLO
            self.model = YOLO(model_path)
            print("Model loaded successfully with Ultralytics YOLO")
        except ImportError:
            raise ImportError("Please install ultralytics: pip install ultralytics>=8.0.0")
        except Exception as e:
            raise Exception(f"Failed to load YOLOv8 model: {str(e)}")

        # Load class names
        self.classes = []
        classes_path = os.path.join(model_dir, "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r', encoding='utf-8') as f:
                self.classes = [line.strip() for line in f.readlines() if line.strip()]
            print(f"Loaded {len(self.classes)} classes")
        else:
            # Use the class information bundled with the model
            if hasattr(self.model, 'names') and self.model.names:
                self.classes = list(self.model.names.values()) if isinstance(self.model.names, dict) else self.model.names
                print(f"Using the model's own classes, {len(self.classes)} in total")
            else:
                print("No class file found; numeric indices will be used as class names")

        # Detection parameters
        self.conf_threshold = 0.25  # confidence threshold
        self.img_size = 640  # default input image size

        print("YOLOv8 model loaded")

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """Preprocess the image - YOLOv8 handles this internally, so return as-is"""
        return image

    def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Run model inference"""
        original_height, original_width = image.shape[:2]

        try:
            # YOLOv8 inference
            results = self.model(
                image,
                conf=self.conf_threshold,
                device=self.device,
                verbose=False
            )

            detections = []

            # Parse the results
            for result in results:
                # Get the detection boxes
                boxes = result.boxes

                if boxes is None or len(boxes) == 0:
                    continue

                # Iterate over each detection box
                for box in boxes:
                    # Get coordinates (xyxy format)
                    xyxy = box.xyxy[0].cpu().numpy()
                    x1, y1, x2, y2 = xyxy

                    # Convert to normalized (x, y, w, h) coordinates
                    x = x1 / original_width
                    y = y1 / original_height
                    w = (x2 - x1) / original_width
                    h = (y2 - y1) / original_height

                    # Get the confidence
                    conf = float(box.conf[0].cpu().numpy())

                    # Get the class ID
                    cls_id = int(box.cls[0].cpu().numpy())

                    # Get the class name
                    class_name = f"cls{cls_id}"
                    if 0 <= cls_id < len(self.classes):
                        class_name = self.classes[cls_id]

                    # Add the detection
                    if conf >= self.conf_threshold:
                        detections.append({
                            'bbox': (x, y, w, h),
                            'class_id': cls_id,
                            'confidence': conf
                        })

            return detections

        except Exception as e:
            print(f"Error during inference: {str(e)}")
            import traceback
            traceback.print_exc()
            return []

    @property
    def applies_nms(self) -> bool:
        """Whether the model applies NMS internally"""
        # YOLOv8 applies NMS automatically
        return True

    def close(self):
        """Release resources"""
        if hasattr(self, 'model'):
            # Delete the model to free GPU memory
            del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        print("YOLOv8 model closed")
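The coordinate conversion inside `predict()` (absolute xyxy corners to the normalized xywh the service returns) is a common source of off-by-one-axis mistakes, so it is worth isolating. A sketch as a pure function (my naming, not from the file above):

```python
def xyxy_to_norm_xywh(x1, y1, x2, y2, width, height):
    """Convert absolute corner coordinates (x1, y1, x2, y2) to a
    normalized (x, y, w, h) box relative to the image dimensions."""
    return (
        x1 / width,          # left edge, as a fraction of image width
        y1 / height,         # top edge, as a fraction of image height
        (x2 - x1) / width,   # box width fraction
        (y2 - y1) / height,  # box height fraction
    )

# A 200x200 box at (100, 100) in a 1000x500 image:
print(xyxy_to_norm_xywh(100, 100, 300, 300, 1000, 500))  # (0.1, 0.2, 0.2, 0.4)
```

Note that x coordinates divide by width and y coordinates by height; on non-square images, swapping the two silently skews every box.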
10
python-inference-service/requirements.txt
Normal file
@@ -0,0 +1,10 @@
fastapi==0.103.1
uvicorn==0.23.2
opencv-python==4.8.0.76
numpy==1.25.2
pydantic==2.3.0
python-multipart==0.0.6
minio==7.1.15
torch>=1.7.0
torchvision>=0.8.1
ultralytics>=8.0.0
5
python-inference-service/start_service.bat
Normal file
@@ -0,0 +1,5 @@
@echo off
echo Starting Python Inference Service...
cd /d %~dp0
python -m app.main
pause
4
python-inference-service/start_service.sh
Normal file
@@ -0,0 +1,4 @@
#!/bin/bash
echo "Starting Python Inference Service..."
cd "$(dirname "$0")"
python -m app.main