Files
rtsp-video-analysis-system/python-inference-service/models/universal_yolo_model.py
2025-10-08 11:51:28 +08:00

170 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import cv2
from typing import List, Dict, Any
import torch
class Model:
    """Generic YOLO detector for Ultralytics-based weights (YOLOv8 / YOLOv11, ...).

    The weights file and the optional class-name text files are looked up in
    the directory that contains this source file.
    """

    # Weight file names probed, in this order, when ``model_file`` is omitted.
    _DEFAULT_MODEL_FILES = ('garbage.pt', 'smoke.pt', 'best.pt', 'yolov8.pt', 'model.pt')

    def __init__(self, model_file: str = None, model_name: str = "YOLO"):
        """Load the weights and the class names.

        Args:
            model_file: Weights file name (e.g. ``smoke.pt``, ``best.pt``),
                resolved relative to this file's directory. When ``None``,
                the first existing entry of ``_DEFAULT_MODEL_FILES`` is used.
            model_name: Display name, used only in log output.

        Raises:
            FileNotFoundError: No weights file could be found.
            ImportError: Importing ``ultralytics`` failed.
            Exception: Loading the weights failed for any other reason.
        """
        # Everything is resolved relative to the directory of this file.
        model_dir = os.path.dirname(os.path.abspath(__file__))

        # No file given: fall back to the first well-known name that exists.
        if model_file is None:
            model_file = next(
                (
                    candidate
                    for candidate in self._DEFAULT_MODEL_FILES
                    if os.path.exists(os.path.join(model_dir, candidate))
                ),
                None,
            )
            if model_file is None:
                raise FileNotFoundError("未找到模型文件,请在初始化时指定 model_file 参数")

        model_path = os.path.join(model_dir, model_file)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        self.model_name = model_name
        print(f"正在加载{model_name}模型: {model_path}")

        # Prefer the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"使用设备: {self.device}")

        # Imported lazily so a missing package yields an actionable message.
        try:
            from ultralytics import YOLO
            self.model = YOLO(model_path)
            print("使用 Ultralytics YOLO 加载模型成功")
        except ImportError as e:
            # Chain the cause so the original traceback is not lost.
            raise ImportError("请安装 ultralytics: pip install ultralytics>=8.0.0") from e
        except Exception as e:
            raise Exception(f"加载{model_name}模型失败: {str(e)}") from e

        # Class names: sidecar file first, model metadata as a fallback.
        self.classes = self._load_class_names(
            model_dir, os.path.splitext(model_file)[0]
        )

        # Inference settings.
        self.conf_threshold = 0.25  # minimum confidence for a detection
        self.img_size = 640  # default input image size
        print(f"{model_name}模型加载完成")

    def _load_class_names(self, model_dir: str, model_base_name: str) -> List[str]:
        """Return the class names for the loaded model.

        Lookup order: ``<model_base_name>.txt``, then ``classes.txt`` (both in
        ``model_dir``, one name per non-blank line), then the ``names``
        attribute of the loaded model. An empty list is returned when none of
        these is available; ``predict`` then falls back to ``cls<N>`` labels.
        """
        for file_name in (f"{model_base_name}.txt", "classes.txt"):
            classes_path = os.path.join(model_dir, file_name)
            if os.path.exists(classes_path):
                with open(classes_path, 'r', encoding='utf-8') as f:
                    classes = [line.strip() for line in f if line.strip()]
                print(f"已加载类别文件: {file_name} ({len(classes)} 个类别)")
                return classes

        # No sidecar file: use the names embedded in the model, if any.
        names = getattr(self.model, 'names', None)
        if names:
            classes = list(names.values()) if isinstance(names, dict) else names
            print(f"使用模型自带类别,共 {len(classes)} 个类别")
            return classes

        print("未找到类别文件,将使用数字索引作为类别名")
        return []

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` unchanged; Ultralytics YOLO preprocesses internally."""
        return image

    def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Run inference on one image and return its detections.

        Args:
            image: Image array whose first two dimensions are height and width.

        Returns:
            One dict per detection with the keys ``bbox`` (``(x, y, w, h)``:
            top-left corner and size, each divided by the image width or
            height), ``class_id`` (int), ``class_name`` (str, ``cls<N>`` when
            the id has no known name) and ``confidence`` (float). An empty
            list is returned for a missing/empty image or when inference
            fails; the error is printed, not raised.
        """
        # Reject unusable frames up front instead of raising outside the
        # try-block below (and avoid dividing by a zero width/height).
        if image is None or getattr(image, "ndim", 0) < 2 or image.size == 0:
            print("推理过程中出错: 输入图像为空或无效")
            return []

        original_height, original_width = image.shape[:2]
        try:
            results = self.model(
                image,
                conf=self.conf_threshold,
                device=self.device,
                verbose=False
            )
            detections = []
            for result in results:
                boxes = result.boxes
                if boxes is None or len(boxes) == 0:
                    continue

                # Convert once per result rather than three times per box;
                # tolist() also yields plain Python floats, which keeps the
                # output serialisable.
                corners = boxes.xyxy.tolist()
                scores = boxes.conf.tolist()
                class_values = boxes.cls.tolist()

                for (x1, y1, x2, y2), conf, class_value in zip(corners, scores, class_values):
                    if conf < self.conf_threshold:
                        continue

                    cls_id = int(class_value)
                    # Fall back to a synthetic label for ids without a name.
                    if 0 <= cls_id < len(self.classes):
                        class_name = self.classes[cls_id]
                    else:
                        class_name = f"cls{cls_id}"

                    detections.append({
                        # Pixel xyxy corners -> normalised top-left + size.
                        'bbox': (
                            x1 / original_width,
                            y1 / original_height,
                            (x2 - x1) / original_width,
                            (y2 - y1) / original_height,
                        ),
                        'class_id': cls_id,
                        'class_name': class_name,
                        'confidence': conf
                    })
            return detections
        except Exception as e:
            # Best effort: a failed frame must not take the caller down.
            print(f"推理过程中出错: {str(e)}")
            import traceback
            traceback.print_exc()
            return []

    @property
    def applies_nms(self) -> bool:
        """Whether the model applies NMS internally (Ultralytics YOLO does)."""
        return True

    def close(self):
        """Drop the model reference and release cached GPU memory."""
        if hasattr(self, 'model'):
            # Deleting the model lets its GPU memory be reclaimed.
            del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        print(f"{self.model_name}模型已关闭")