- 在 garbage_model.py 和 smoke_model.py 中添加 weights_only=False 参数以允许加载模型类结构 - 修复 HTTP YOLO 检测器中的文件上传和响应解析逻辑 - 移除不必要的导入并优化代码结构 - 添加自定义字节数组资源类以支持 RestTemplate 文件上传 - 改进错误处理和日志记录机制
211 lines
8.1 KiB
Python
211 lines
8.1 KiB
Python
import os
|
||
import numpy as np
|
||
import cv2
|
||
from typing import List, Dict, Any
|
||
import torch
|
||
|
||
class Model:
    """Garbage-recognition detector backed by a PyTorch / YOLOv5 weight file.

    Weights are read from ``best.pt`` in the same directory as this module.
    Class names come from ``classes.txt`` in that directory when it exists,
    otherwise from the loaded model's ``names`` attribute (if it has one).

    ``predict`` expects an HxWx3 image in BGR channel order (the order that
    ``preprocess`` converts from) and returns a list of
    ``{'bbox': (x, y, w, h), 'class_id': int, 'confidence': float}`` dicts,
    where ``(x, y)`` is the box's top-left corner and all four bbox values are
    fractions of the image width / height.
    """

    def __init__(self):
        """Load the model and class names and set default inference parameters."""
        # Directory containing this module; best.pt and classes.txt live here.
        model_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(model_dir, "best.pt")

        print(f"正在加载垃圾识别模型: {model_path}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        # Sets self.model and self.yolov5_api.
        self._load_model(model_dir, model_path)

        # Inference only: put dropout / batch-norm layers in evaluation mode.
        # torch.jit.ScriptModule also has eval(), so one check covers both cases.
        if hasattr(self.model, "eval"):
            self.model.eval()

        self.classes = self._load_classes(model_dir)

        # Inference parameters.
        self.conf_threshold = 0.25  # detections below this confidence are dropped
        self.img_size = 640  # square input side used by the generic preprocess path

        print("垃圾识别模型加载完成")

    def _load_model(self, model_dir: str, model_path: str) -> None:
        """Load ``model_path`` into ``self.model`` and set ``self.yolov5_api``.

        Strategies, tried in order:
        1. the ``yolov5`` pip package (``self.yolov5_api = True``);
        2. ``models.yolov5_utils.attempt_load`` from the parent directory;
        3. plain ``torch.load`` via ``_load_checkpoint``.
        Any exception from strategies 1-2 falls through to strategy 3.
        """
        try:
            import sys

            # Make the parent directory importable so `models.yolov5_utils`
            # resolves; avoid growing sys.path on every instantiation.
            parent_dir = os.path.dirname(model_dir)
            if parent_dir not in sys.path:
                sys.path.append(parent_dir)

            try:
                import yolov5

                self.model = yolov5.load(model_path, device=self.device)
                self.yolov5_api = True
                print("使用 YOLOv5 包加载模型")
            except ImportError:  # ModuleNotFoundError is a subclass of ImportError
                from models.yolov5_utils import attempt_load

                self.model = attempt_load(model_path, device=self.device)
                self.yolov5_api = False
                print("使用内置 YOLOv5 工具加载模型")

        except Exception as e:
            print(f"YOLOv5 加载失败: {e}")
            print("使用通用 PyTorch 加载")
            self.model = self._load_checkpoint(model_path)
            self.yolov5_api = False

    def _load_checkpoint(self, model_path: str) -> Any:
        """Load ``model_path`` with plain ``torch.load`` and unwrap it to a model.

        A checkpoint dict containing a ``'model'`` entry is unwrapped to that
        entry, and an ``nn.Module`` result is cast to float32 so it accepts the
        float32 tensors produced by ``preprocess``.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects and can
        # execute code embedded in the file. It is required to restore full model
        # classes (PyTorch 2.6+ defaults to weights_only=True). Only use it on the
        # trusted best.pt shipped next to this module, never on external files.
        ckpt = torch.load(
            model_path,
            map_location=self.device,
            weights_only=False,
        )
        if isinstance(ckpt, dict) and "model" in ckpt:
            ckpt = ckpt["model"]
        if isinstance(ckpt, torch.nn.Module):
            # Training checkpoints (e.g. YOLOv5's) may store fp16 weights, which
            # would reject preprocess()'s float32 input with a dtype error.
            # No-op when the weights are already float32.
            ckpt = ckpt.float()
        if not callable(ckpt):
            print("警告:加载的对象不是可调用的模型(可能只是 state_dict),推理将无法进行")
        return ckpt

    def _load_classes(self, model_dir: str) -> Any:
        """Return class names from classes.txt, else ``self.model.names``, else ``[]``.

        The result is indexed by integer class id; ``model.names`` may be a list
        or a dict depending on the model implementation.
        """
        classes_path = os.path.join(model_dir, "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, "r", encoding="utf-8") as f:
                classes = [line.strip() for line in f if line.strip()]
            print(f"已加载 {len(classes)} 个类别")
            return classes

        # Fall back to class metadata bundled with the model, if any.
        names = getattr(self.model, "names", None)
        if names:
            print(f"使用模型自带类别,共 {len(names)} 个类别")
            return names

        print("未找到类别文件,将使用数字索引作为类别名")
        return []

    def preprocess(self, image: np.ndarray) -> Any:
        """Convert a BGR HxWx3 image into a model input.

        Returns the image unchanged when the YOLOv5 package API is in use, which
        does its own preprocessing. Otherwise returns a float32 tensor of shape
        ``(1, 3, img_size, img_size)`` with RGB values in [0, 1], on ``self.device``.
        """
        if getattr(self, "yolov5_api", False):
            return image

        # Squash-resize (no letterbox) to the square network input size.
        img = cv2.resize(image, (self.img_size, self.img_size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img / 255.0  # [0, 255] -> [0, 1]
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        tensor = torch.from_numpy(img).float()
        return tensor.unsqueeze(0).to(self.device)  # add batch dimension

    def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Run detection on a BGR image and return detections above the threshold.

        Returns ``[]`` for a ``None``/empty image or if inference raises; the
        error is printed rather than propagated.
        """
        if image is None or image.size == 0:
            print("推理过程中出错: 输入图像为空")
            return []

        original_height, original_width = image.shape[:2]

        try:
            if getattr(self, "yolov5_api", False):
                return self._predict_yolov5(image, original_width, original_height)
            return self._predict_generic(image)
        except Exception as e:
            print(f"推理过程中出错: {str(e)}")
            # Best-effort: an inference failure yields no detections.
            return []

    def _predict_yolov5(self, image: np.ndarray, width: int, height: int) -> List[Dict[str, Any]]:
        """Run the YOLOv5 package model; its boxes are in original-image pixels."""
        # YOLOv5's AutoShape expects RGB numpy arrays, while frames here are BGR
        # (the same assumption preprocess() makes), so swap the channel order.
        # ascontiguousarray avoids handing a negative-stride view to OpenCV.
        if image.ndim == 3 and image.shape[2] == 3:
            image = np.ascontiguousarray(image[..., ::-1])

        results = self.model(image)
        predictions = results.pred[0]  # rows: x1, y1, x2, y2, conf, cls for image 0

        detections = []
        for *xyxy, conf, cls_id in predictions.cpu().numpy():
            if conf >= self.conf_threshold:
                detections.append(self._to_detection(xyxy, conf, cls_id, width, height))
        return detections

    def _predict_generic(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """Run a plain PyTorch model on the preprocessed image.

        Supports a 2-D output tensor whose rows are
        ``[x1, y1, x2, y2, conf, cls_id]``, optionally prefixed by a batch-index
        column. Any other output format logs a warning and yields ``[]``.
        """
        img = self.preprocess(image)

        with torch.no_grad():
            outputs = self.model(img)

        detections = []
        if isinstance(outputs, torch.Tensor) and outputs.dim() == 2 and outputs.size(1) in (6, 7):
            for row in outputs.cpu().numpy():
                # The last six columns are the box, score and class; a leading
                # batch_idx column (7-column layout) is skipped.
                x1, y1, x2, y2, conf, cls_id = row[-6:]
                if conf >= self.conf_threshold:
                    # Boxes are assumed to be pixel coordinates in the resized
                    # img_size x img_size input. Dividing by img_size yields
                    # fractions that also apply to the original image, because
                    # preprocess() scales each axis independently.
                    detections.append(
                        self._to_detection(
                            (x1, y1, x2, y2), conf, cls_id, self.img_size, self.img_size
                        )
                    )
        else:
            # Other output formats need model-specific post-processing.
            print("警告:无法识别的模型输出格式,请检查模型类型")

        return detections

    @staticmethod
    def _to_detection(xyxy, conf, cls_id, ref_width, ref_height) -> Dict[str, Any]:
        """Build a detection dict from an absolute ``(x1, y1, x2, y2)`` box.

        The bbox becomes ``(x, y, w, h)`` with ``(x, y)`` the top-left corner,
        each value divided by ``ref_width`` / ``ref_height``. Values are plain
        Python floats/ints so the result is JSON-serializable.
        """
        x1, y1, x2, y2 = (float(v) for v in xyxy)
        return {
            'bbox': (
                x1 / ref_width,
                y1 / ref_height,
                (x2 - x1) / ref_width,
                (y2 - y1) / ref_height,
            ),
            'class_id': int(cls_id),
            'confidence': float(conf),
        }

    @property
    def applies_nms(self) -> bool:
        """Whether detections from ``predict`` are already NMS-filtered.

        Always ``True``. YOLOv5's package API applies NMS itself.
        NOTE(review): the generic path only accepts post-NMS style 2-D outputs,
        but it does not run NMS itself. Confirm the model applies it internally.
        """
        return True

    def close(self):
        """Drop the model reference and release cached GPU memory if available."""
        if hasattr(self, "model"):
            # Deleting the last reference lets the model's (GPU) tensors be freed.
            del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        print("垃圾识别模型已关闭")