2025-09-30 14:23:33 +08:00
|
|
|
|
import os
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import cv2
|
|
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class Model:
|
|
|
|
|
|
"""
|
2025-10-08 11:51:28 +08:00
|
|
|
|
通用 YOLO 模型 - 支持 YOLOv8/YOLOv11 等基于 Ultralytics 的模型
|
2025-09-30 14:23:33 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
2025-10-08 11:51:28 +08:00
|
|
|
|
def __init__(self, model_file: str = None, model_name: str = "YOLO"):
|
|
|
|
|
|
"""
|
|
|
|
|
|
初始化模型
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
model_file: 模型文件名(如 smoke.pt, best.pt)
|
|
|
|
|
|
model_name: 模型显示名称(用于日志)
|
|
|
|
|
|
"""
|
2025-09-30 14:23:33 +08:00
|
|
|
|
# 获取当前文件所在目录路径
|
|
|
|
|
|
model_dir = os.path.dirname(os.path.abspath(__file__))
|
2025-10-08 11:51:28 +08:00
|
|
|
|
|
|
|
|
|
|
# 如果没有指定模型文件,尝试常见的文件名
|
|
|
|
|
|
if model_file is None:
|
|
|
|
|
|
for possible_file in ['garbage.pt', 'smoke.pt', 'best.pt', 'yolov8.pt', 'model.pt']:
|
|
|
|
|
|
test_path = os.path.join(model_dir, possible_file)
|
|
|
|
|
|
if os.path.exists(test_path):
|
|
|
|
|
|
model_file = possible_file
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
if model_file is None:
|
|
|
|
|
|
raise FileNotFoundError(f"未找到模型文件,请在初始化时指定 model_file 参数")
|
|
|
|
|
|
|
2025-09-30 14:23:33 +08:00
|
|
|
|
# 模型文件路径
|
2025-10-08 11:51:28 +08:00
|
|
|
|
model_path = os.path.join(model_dir, model_file)
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(model_path):
|
|
|
|
|
|
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
2025-09-30 14:23:33 +08:00
|
|
|
|
|
2025-10-08 11:51:28 +08:00
|
|
|
|
self.model_name = model_name
|
|
|
|
|
|
print(f"正在加载{model_name}模型: {model_path}")
|
2025-09-30 14:23:33 +08:00
|
|
|
|
|
|
|
|
|
|
# 检查设备
|
|
|
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
print(f"使用设备: {self.device}")
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 Ultralytics YOLO 加载模型
|
|
|
|
|
|
try:
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
self.model = YOLO(model_path)
|
2025-10-08 11:51:28 +08:00
|
|
|
|
print(f"使用 Ultralytics YOLO 加载模型成功")
|
2025-09-30 14:23:33 +08:00
|
|
|
|
except ImportError:
|
|
|
|
|
|
raise ImportError("请安装 ultralytics: pip install ultralytics>=8.0.0")
|
|
|
|
|
|
except Exception as e:
|
2025-10-08 11:51:28 +08:00
|
|
|
|
raise Exception(f"加载{model_name}模型失败: {str(e)}")
|
2025-09-30 14:23:33 +08:00
|
|
|
|
|
|
|
|
|
|
# 加载类别名称
|
|
|
|
|
|
self.classes = []
|
2025-10-08 11:51:28 +08:00
|
|
|
|
|
|
|
|
|
|
# 1. 首先尝试加载与模型文件同名的类别文件(如 smoke.txt)
|
|
|
|
|
|
model_base_name = os.path.splitext(model_file)[0]
|
|
|
|
|
|
classes_path_specific = os.path.join(model_dir, f"{model_base_name}.txt")
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 然后尝试加载通用的 classes.txt
|
|
|
|
|
|
classes_path_generic = os.path.join(model_dir, "classes.txt")
|
|
|
|
|
|
|
|
|
|
|
|
if os.path.exists(classes_path_specific):
|
|
|
|
|
|
with open(classes_path_specific, 'r', encoding='utf-8') as f:
|
2025-09-30 14:23:33 +08:00
|
|
|
|
self.classes = [line.strip() for line in f.readlines() if line.strip()]
|
2025-10-08 11:51:28 +08:00
|
|
|
|
print(f"已加载类别文件: {model_base_name}.txt ({len(self.classes)} 个类别)")
|
|
|
|
|
|
elif os.path.exists(classes_path_generic):
|
|
|
|
|
|
with open(classes_path_generic, 'r', encoding='utf-8') as f:
|
|
|
|
|
|
self.classes = [line.strip() for line in f.readlines() if line.strip()]
|
|
|
|
|
|
print(f"已加载类别文件: classes.txt ({len(self.classes)} 个类别)")
|
2025-09-30 14:23:33 +08:00
|
|
|
|
else:
|
|
|
|
|
|
# 使用模型自带的类别信息
|
|
|
|
|
|
if hasattr(self.model, 'names') and self.model.names:
|
|
|
|
|
|
self.classes = list(self.model.names.values()) if isinstance(self.model.names, dict) else self.model.names
|
|
|
|
|
|
print(f"使用模型自带类别,共 {len(self.classes)} 个类别")
|
|
|
|
|
|
else:
|
|
|
|
|
|
print("未找到类别文件,将使用数字索引作为类别名")
|
|
|
|
|
|
|
|
|
|
|
|
# 设置识别参数
|
|
|
|
|
|
self.conf_threshold = 0.25 # 置信度阈值
|
|
|
|
|
|
self.img_size = 640 # 默认输入图像大小
|
|
|
|
|
|
|
2025-10-08 11:51:28 +08:00
|
|
|
|
print(f"{model_name}模型加载完成")
|
2025-09-30 14:23:33 +08:00
|
|
|
|
|
|
|
|
|
|
def preprocess(self, image: np.ndarray) -> np.ndarray:
|
2025-10-08 11:51:28 +08:00
|
|
|
|
"""预处理图像 - Ultralytics YOLO 会自动处理,这里直接返回"""
|
2025-09-30 14:23:33 +08:00
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
|
|
def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
|
|
|
|
"""模型推理"""
|
|
|
|
|
|
original_height, original_width = image.shape[:2]
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
2025-10-08 11:51:28 +08:00
|
|
|
|
# YOLO 推理
|
2025-09-30 14:23:33 +08:00
|
|
|
|
results = self.model(
|
|
|
|
|
|
image,
|
|
|
|
|
|
conf=self.conf_threshold,
|
|
|
|
|
|
device=self.device,
|
|
|
|
|
|
verbose=False
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
detections = []
|
|
|
|
|
|
|
|
|
|
|
|
# 解析结果
|
|
|
|
|
|
for result in results:
|
|
|
|
|
|
# 获取检测框
|
|
|
|
|
|
boxes = result.boxes
|
|
|
|
|
|
|
|
|
|
|
|
if boxes is None or len(boxes) == 0:
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
# 遍历每个检测框
|
|
|
|
|
|
for box in boxes:
|
|
|
|
|
|
# 获取坐标 (xyxy格式)
|
|
|
|
|
|
xyxy = box.xyxy[0].cpu().numpy()
|
|
|
|
|
|
x1, y1, x2, y2 = xyxy
|
|
|
|
|
|
|
|
|
|
|
|
# 转换为归一化坐标 (x, y, w, h)
|
|
|
|
|
|
x = x1 / original_width
|
|
|
|
|
|
y = y1 / original_height
|
|
|
|
|
|
w = (x2 - x1) / original_width
|
|
|
|
|
|
h = (y2 - y1) / original_height
|
|
|
|
|
|
|
|
|
|
|
|
# 获取置信度
|
|
|
|
|
|
conf = float(box.conf[0].cpu().numpy())
|
|
|
|
|
|
|
|
|
|
|
|
# 获取类别ID
|
|
|
|
|
|
cls_id = int(box.cls[0].cpu().numpy())
|
|
|
|
|
|
|
|
|
|
|
|
# 获取类别名称
|
|
|
|
|
|
class_name = f"cls{cls_id}"
|
|
|
|
|
|
if 0 <= cls_id < len(self.classes):
|
|
|
|
|
|
class_name = self.classes[cls_id]
|
|
|
|
|
|
|
|
|
|
|
|
# 添加检测结果
|
|
|
|
|
|
if conf >= self.conf_threshold:
|
|
|
|
|
|
detections.append({
|
|
|
|
|
|
'bbox': (x, y, w, h),
|
|
|
|
|
|
'class_id': cls_id,
|
|
|
|
|
|
'confidence': conf
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return detections
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
print(f"推理过程中出错: {str(e)}")
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
traceback.print_exc()
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
def applies_nms(self) -> bool:
|
|
|
|
|
|
"""模型是否内部应用了 NMS"""
|
2025-10-08 11:51:28 +08:00
|
|
|
|
# Ultralytics YOLO 会自动应用 NMS
|
2025-09-30 14:23:33 +08:00
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
def close(self):
|
|
|
|
|
|
"""释放资源"""
|
|
|
|
|
|
if hasattr(self, 'model'):
|
|
|
|
|
|
# 删除模型以释放 GPU 内存
|
|
|
|
|
|
del self.model
|
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
|
torch.cuda.empty_cache()
|
2025-10-08 11:51:28 +08:00
|
|
|
|
print(f"{self.model_name}模型已关闭")
|
|
|
|
|
|
|