Files
rtsp-video-analysis-system/python-inference-service/models/yolov8_model.py
2025-09-30 14:23:33 +08:00

135 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import cv2
from typing import List, Dict, Any
import torch
class Model:
    """YOLOv8 model wrapper built on the Ultralytics ``YOLO`` API.

    The wrapper loads ``best.pt`` from the directory that contains this
    file, resolves class names (from ``classes.txt`` next to the weights
    when present, otherwise from the names stored in the model), and
    exposes ``preprocess`` / ``predict`` / ``close``.

    Attributes set by ``__init__``:
        device: ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
        model: The loaded Ultralytics ``YOLO`` object.
        classes: Class names indexed by class id (may be empty).
        conf_threshold: Minimum confidence for a detection to be returned.
        img_size: Default input image size (stored only; not read by this
            class).
    """

    def __init__(self):
        """Load the YOLOv8 weights and the class-name table.

        Raises:
            ImportError: If the ``ultralytics`` package cannot be imported.
            RuntimeError: If the weights file cannot be loaded.
        """
        # Weights and the optional class list live next to this source file.
        model_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(model_dir, "best.pt")
        print(f"正在加载YOLOv8模型: {model_path}")

        # Pick the inference device once; it is passed to every predict call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"使用设备: {self.device}")

        # Import separately from loading so that a failure while reading the
        # weights is never misreported as "ultralytics is not installed".
        try:
            from ultralytics import YOLO
        except ImportError as e:
            raise ImportError(
                "请安装 ultralytics: pip install ultralytics>=8.0.0"
            ) from e

        try:
            self.model = YOLO(model_path)
        except Exception as e:
            raise RuntimeError(f"加载YOLOv8模型失败: {e}") from e
        print("使用 Ultralytics YOLO 加载模型成功")

        # Class names: an explicit classes.txt takes precedence over the
        # names embedded in the model.
        self.classes = []
        classes_path = os.path.join(model_dir, "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r', encoding='utf-8') as f:
                # Blank lines are skipped.
                self.classes = [line.strip() for line in f if line.strip()]
            print(f"已加载 {len(self.classes)} 个类别")
        else:
            names = getattr(self.model, 'names', None)
            if names:
                # ``names`` may be a dict (id -> name) or a plain sequence.
                if isinstance(names, dict):
                    self.classes = list(names.values())
                else:
                    self.classes = list(names)
                print(f"使用模型自带类别,共 {len(self.classes)} 个类别")
            else:
                print("未找到类别文件,将使用数字索引作为类别名")

        # Inference parameters.
        self.conf_threshold = 0.25  # confidence threshold
        self.img_size = 640  # default input image size
        print("YOLOv8模型加载完成")

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` unchanged.

        No preprocessing is done here; the image is handed to the model
        as-is by ``predict``.
        """
        return image

    def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Run inference on one image and return normalized detections.

        Args:
            image: Image array whose first two dimensions are
                (height, width).

        Returns:
            A list of dicts, one per detection with confidence at or above
            ``self.conf_threshold``:

            * ``'bbox'``: ``(x, y, w, h)`` as native Python floats, where
              ``(x, y)`` is the top-left corner and all values are divided
              by the image width/height.
            * ``'class_id'``: ``int`` class index.
            * ``'confidence'``: ``float`` score.

            An empty list is returned for a ``None``/empty image or when
            inference raises (the error is printed with a traceback).
        """
        # Guard before touching ``image.shape`` so a missing frame yields an
        # empty result instead of an AttributeError or a division by zero.
        if image is None or image.size == 0:
            print("输入图像为空,跳过推理")
            return []

        original_height, original_width = image.shape[:2]
        try:
            results = self.model(
                image,
                conf=self.conf_threshold,
                device=self.device,
                verbose=False,
            )
            detections = []
            for result in results:
                boxes = result.boxes
                if boxes is None or len(boxes) == 0:
                    continue

                # Move each field off the device once per result rather than
                # once per box; ``tolist()`` also converts NumPy scalars into
                # native Python numbers so the output is JSON-serializable.
                corners = boxes.xyxy.cpu().numpy().tolist()
                scores = boxes.conf.cpu().numpy().tolist()
                class_ids = boxes.cls.cpu().numpy().tolist()

                for (x1, y1, x2, y2), conf, cls in zip(corners, scores, class_ids):
                    if conf >= self.conf_threshold:
                        detections.append({
                            # xyxy pixels -> normalized (x, y, w, h)
                            'bbox': (
                                x1 / original_width,
                                y1 / original_height,
                                (x2 - x1) / original_width,
                                (y2 - y1) / original_height,
                            ),
                            'class_id': int(cls),
                            'confidence': conf,
                        })
            return detections
        except Exception as e:
            # Best-effort boundary: report the failure and return no
            # detections rather than propagating the error to the caller.
            print(f"推理过程中出错: {str(e)}")
            import traceback
            traceback.print_exc()
            return []

    @property
    def applies_nms(self) -> bool:
        """Whether the model applies NMS internally (always ``True``)."""
        return True

    def close(self):
        """Release the model and, when CUDA is available, empty its cache.

        Safe to call more than once: the model attribute is only deleted
        if it is still present.
        """
        if hasattr(self, 'model'):
            # Drop the reference so GPU memory held by the model can be freed.
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("YOLOv8模型已关闭")