Files
rtsp-video-analysis-system/python-inference-service/models/garbage_model.py
2025-09-30 14:23:33 +08:00

207 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import cv2
from typing import List, Dict, Any
import torch
class Model:
    """
    Garbage-detection model that loads the PyTorch weight file ``best.pt``
    located next to this module.

    Loading is attempted in this order:
      1. the ``yolov5`` pip package (``yolov5.load``) -- its model is called
         on a raw image and returns an object exposing ``.pred``;
      2. the project-local ``models.yolov5_utils.attempt_load``;
      3. plain ``torch.load`` as a last resort.

    ``predict`` returns a list of dicts with keys ``'bbox'`` (normalised
    ``(x, y, w, h)`` where ``x, y`` is the top-left corner), ``'class_id'``
    and ``'confidence'``.
    """

    def __init__(self):
        """Load the weights, choose a device, and load class names."""
        # Weights (best.pt) and the optional classes.txt live beside this file.
        model_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(model_dir, "best.pt")
        print(f"正在加载垃圾识别模型: {model_path}")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        self.yolov5_api = False
        try:
            import sys
            # Make the parent directory importable so `models.yolov5_utils`
            # resolves; skip the append if it is already on sys.path so that
            # repeated instantiation does not keep growing the path list.
            parent_dir = os.path.dirname(model_dir)
            if parent_dir not in sys.path:
                sys.path.append(parent_dir)
            try:
                # Method 1: the `yolov5` pip package, if installed.
                import yolov5
                self.model = yolov5.load(model_path, device=self.device)
                self.yolov5_api = True
                print("使用 YOLOv5 包加载模型")
            except (ImportError, ModuleNotFoundError):
                # Method 2: the project-bundled YOLOv5 loader.
                from models.yolov5_utils import attempt_load
                self.model = attempt_load(model_path, device=self.device)
                self.yolov5_api = False
                print("使用内置 YOLOv5 工具加载模型")
        except Exception as e:
            # Method 3: generic torch.load fallback.
            print(f"YOLOv5 加载失败: {e}")
            print("使用通用 PyTorch 加载")
            # NOTE(security): torch.load unpickles the file and can execute
            # arbitrary code -- only ever point this at trusted weight files.
            self.model = torch.load(model_path, map_location=self.device)
            if isinstance(self.model, dict) and 'model' in self.model:
                self.model = self.model['model']
            # YOLOv5 training checkpoints store the model in half precision,
            # while preprocess() produces float32 tensors; cast to match.
            if isinstance(self.model, torch.nn.Module):
                self.model = self.model.float()
            self.yolov5_api = False

        # Switch to inference mode (ScriptModules also expose eval()).
        if hasattr(self.model, 'eval'):
            self.model.eval()

        # Class names: prefer classes.txt, then the model's own `names`.
        self.classes = []
        classes_path = os.path.join(model_dir, "classes.txt")
        if os.path.exists(classes_path):
            with open(classes_path, 'r', encoding='utf-8') as f:
                self.classes = [line.strip() for line in f if line.strip()]
            print(f"已加载 {len(self.classes)} 个类别")
        elif getattr(self.model, 'names', None):
            self.classes = self.model.names
            print(f"使用模型自带类别,共 {len(self.classes)} 个类别")
        else:
            print("未找到类别文件,将使用数字索引作为类别名")

        self.conf_threshold = 0.25  # detections below this confidence are dropped
        self.img_size = 640  # square input size used by preprocess()
        print("垃圾识别模型加载完成")

    @staticmethod
    def _make_detection(x1, y1, x2, y2, conf, cls_id, ref_width, ref_height):
        """Build one output dict from an xyxy box, normalising (x, y, w, h) by the reference size."""
        # Plain Python floats keep the result free of numpy scalar types.
        x1, y1, x2, y2 = float(x1), float(y1), float(x2), float(y2)
        return {
            'bbox': (
                x1 / ref_width,
                y1 / ref_height,
                (x2 - x1) / ref_width,
                (y2 - y1) / ref_height,
            ),
            'class_id': int(cls_id),
            'confidence': float(conf),
        }

    def preprocess(self, image: np.ndarray) -> "np.ndarray | torch.Tensor":
        """
        Prepare a frame for the generic PyTorch path.

        With the YOLOv5 package API the model preprocesses internally, so the
        image is returned unchanged. Otherwise the frame (treated as BGR) is
        resized to ``img_size x img_size`` (a plain stretch, no letterboxing),
        converted to RGB, scaled to [0, 1], reordered to a 1xCxHxW float
        tensor and moved to ``self.device``.
        """
        if getattr(self, 'yolov5_api', False):
            return image
        img = cv2.resize(image, (self.img_size, self.img_size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img / 255.0
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        return torch.from_numpy(img).float().unsqueeze(0).to(self.device)

    def _predict_yolov5(self, image, width, height):
        """Run the yolov5-package model; its boxes come back in original-image pixels."""
        # The yolov5 AutoShape wrapper expects RGB numpy input, whereas frames
        # here are BGR (the same assumption preprocess() makes).
        if image.ndim == 3 and image.shape[2] == 3:
            image = np.ascontiguousarray(image[..., ::-1])
        results = self.model(image)
        predictions = results.pred[0]  # predictions for the first (only) image
        detections = []
        for x1, y1, x2, y2, conf, cls_id in predictions.detach().cpu().numpy():
            if conf >= self.conf_threshold:
                detections.append(
                    self._make_detection(x1, y1, x2, y2, conf, cls_id, width, height)
                )
        return detections

    def _predict_generic(self, image):
        """Run a plain PyTorch model on the preprocessed frame and decode YOLO-style rows."""
        img = self.preprocess(image)
        with torch.no_grad():
            outputs = self.model(img)
        detections = []
        # Expected rows: [x1, y1, x2, y2, conf, cls_id], optionally prefixed by
        # batch_idx (7 columns). The boxes are in the resized img_size x img_size
        # input space, so they are normalised by img_size, not by the frame size.
        if isinstance(outputs, torch.Tensor) and outputs.dim() == 2 and outputs.size(1) in (6, 7):
            for row in outputs.detach().cpu().numpy():
                x1, y1, x2, y2, conf, cls_id = row[-6:]
                if conf >= self.conf_threshold:
                    detections.append(
                        self._make_detection(
                            x1, y1, x2, y2, conf, cls_id, self.img_size, self.img_size
                        )
                    )
        else:
            # Other output layouts must be adapted to the actual model.
            print("警告:无法识别的模型输出格式,请检查模型类型")
        return detections

    def predict(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect objects in one frame.

        Args:
            image: HxWxC frame, treated as BGR (OpenCV order) -- confirm with caller.

        Returns:
            Detections with confidence >= ``conf_threshold``. Each has
            ``'bbox'`` (normalised ``(x, y, w, h)``), ``'class_id'`` and
            ``'confidence'``. Any exception during inference is printed and an
            empty list is returned, so one bad frame does not stop the stream.
        """
        original_height, original_width = image.shape[:2]
        try:
            if getattr(self, 'yolov5_api', False):
                return self._predict_yolov5(image, original_width, original_height)
            return self._predict_generic(image)
        except Exception as e:
            print(f"推理过程中出错: {str(e)}")
            return []

    @property
    def applies_nms(self) -> bool:
        """Report that detections are already NMS-filtered.

        Always True: the YOLOv5 package applies NMS itself. For the generic
        path this assumes the loaded model's output is already post-NMS --
        NOTE(review): confirm for non-YOLOv5 weights.
        """
        return True

    def close(self):
        """Drop the model and release cached GPU memory; safe to call more than once."""
        if hasattr(self, 'model'):
            del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("垃圾识别模型已关闭")