import os import numpy as np import cv2 from typing import List, Dict, Any import torch class Model: """ 垃圾识别模型 - 直接加载 PyTorch 模型文件 """ def __init__(self): """初始化模型""" # 获取当前文件所在目录路径 model_dir = os.path.dirname(os.path.abspath(__file__)) # 模型文件路径 model_path = os.path.join(model_dir, "best.pt") print(f"正在加载垃圾识别模型: {model_path}") # 加载 PyTorch 模型 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {self.device}") # 使用 YOLOv5 或通用方式加载模型 try: # 尝试使用 YOLOv5 加载 import sys sys.path.append(os.path.dirname(model_dir)) # 添加父目录到路径 try: # 方法1: 如果安装了 YOLOv5 import yolov5 self.model = yolov5.load(model_path, device=self.device) self.yolov5_api = True print("使用 YOLOv5 包加载模型") except (ImportError, ModuleNotFoundError): # 方法2: 直接加载 YOLO 代码 from models.yolov5_utils import attempt_load self.model = attempt_load(model_path, device=self.device) self.yolov5_api = False print("使用内置 YOLOv5 工具加载模型") except Exception as e: # 方法3: 通用 PyTorch 加载 print(f"YOLOv5 加载失败: {e}") print("使用通用 PyTorch 加载") self.model = torch.load(model_path, map_location=self.device) if isinstance(self.model, dict) and 'model' in self.model: self.model = self.model['model'] self.yolov5_api = False # 如果是 ScriptModule,设置为评估模式 if isinstance(self.model, torch.jit.ScriptModule): self.model.eval() elif hasattr(self.model, 'eval'): self.model.eval() # 加载类别名称 self.classes = [] classes_path = os.path.join(model_dir, "classes.txt") if os.path.exists(classes_path): with open(classes_path, 'r', encoding='utf-8') as f: self.classes = [line.strip() for line in f.readlines() if line.strip()] print(f"已加载 {len(self.classes)} 个类别") else: # 如果模型自带类别信息 if hasattr(self.model, 'names') and self.model.names: self.classes = self.model.names print(f"使用模型自带类别,共 {len(self.classes)} 个类别") else: print("未找到类别文件,将使用数字索引作为类别名") # 设置识别参数 self.conf_threshold = 0.25 # 置信度阈值 self.img_size = 640 # 默认输入图像大小 print("垃圾识别模型加载完成") def preprocess(self, image: np.ndarray) -> np.ndarray: """预处理图像""" # 如果是使用 YOLOv5 API,不需要预处理 if hasattr(self, 'yolov5_api') and self.yolov5_api: return image # 默认预处理:调整大小并归一化 img = cv2.resize(image, (self.img_size, self.img_size)) # BGR 转 RGB img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 归一化 [0, 255] -> [0, 1] img = img / 255.0 # HWC -> CHW (高度,宽度,通道 -> 通道,高度,宽度) img = img.transpose(2, 0, 1) # 转为 torch tensor img = torch.from_numpy(img).float() # 添加批次维度 img = img.unsqueeze(0) # 移至设备 img = img.to(self.device) return img def predict(self, image: np.ndarray) -> List[Dict[str, Any]]: """模型推理""" original_height, original_width = image.shape[:2] try: # 如果使用 YOLOv5 API if hasattr(self, 'yolov5_api') and self.yolov5_api: # YOLOv5 API 直接处理图像 results = self.model(image) # 提取检测结果 predictions = results.pred[0] # 第一批次的预测 detections = [] for *xyxy, conf, cls_id in predictions.cpu().numpy(): x1, y1, x2, y2 = xyxy # 转换为归一化坐标 (x, y, w, h) x = x1 / original_width y = y1 / original_height w = (x2 - x1) / original_width h = (y2 - y1) / original_height # 整数类别 ID cls_id = int(cls_id) # 获取类别名称 class_name = f"cls{cls_id}" if 0 <= cls_id < len(self.classes): class_name = self.classes[cls_id] # 添加检测结果 if conf >= self.conf_threshold: detections.append({ 'bbox': (x, y, w, h), 'class_id': cls_id, 'confidence': float(conf) }) return detections else: # 通用 PyTorch 模型处理 # 预处理图像 img = self.preprocess(image) # 推理 with torch.no_grad(): outputs = self.model(img) # 后处理结果(这里需要根据模型输出格式调整) detections = [] # 假设输出格式是 YOLO 风格:[batch_idx, x1, y1, x2, y2, conf, cls_id] if isinstance(outputs, torch.Tensor) and outputs.dim() == 2 and outputs.size(1) >= 6: for *xyxy, conf, cls_id in outputs.cpu().numpy(): if conf >= self.conf_threshold: x1, y1, x2, y2 = xyxy # 转换为归一化坐标 (x, y, w, h) x = x1 / original_width y = y1 / original_height w = (x2 - x1) / original_width h = (y2 - y1) / original_height # 整数类别 ID cls_id = int(cls_id) detections.append({ 'bbox': (x, y, w, h), 'class_id': cls_id, 'confidence': float(conf) }) # 处理其他可能的输出格式 else: # 这里需要根据模型的实际输出格式进行适配 print("警告:无法识别的模型输出格式,请检查模型类型") return detections except Exception as e: print(f"推理过程中出错: {str(e)}") # 出错时返回空结果 return [] @property def applies_nms(self) -> bool: """模型是否内部应用了 NMS""" # YOLOv5 会自动应用 NMS return True def close(self): """释放资源""" if hasattr(self, 'model'): # 删除模型以释放 GPU 内存 del self.model if torch.cuda.is_available(): torch.cuda.empty_cache() print("垃圾识别模型已关闭")