Files
rtsp-video-analysis-system/python-inference-service/models/yolov5_utils.py
2025-09-30 14:23:33 +08:00

56 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import sys
import os
def attempt_load(weights, device=''):
"""尝试加载YOLOv5模型"""
# 加载模型
model = torch.load(weights, map_location=device)
# 确定模型格式
if isinstance(model, dict):
if 'model' in model: # state_dict格式
model = model['model']
elif 'state_dict' in model: # state_dict格式
model = model['state_dict']
# 如果是state_dict则需要创建模型架构
if isinstance(model, dict):
print("警告:加载的是权重字典,尝试创建默认模型结构")
from models.yolov5_model import YOLOv5
model_arch = YOLOv5()
model_arch.load_state_dict(model)
model = model_arch
# 设置为评估模式
if isinstance(model, nn.Module):
model.eval()
# 检查是否有类别信息
if not hasattr(model, 'names') or not model.names:
print("模型没有类别信息,尝试加载默认类别")
# 设置通用类别
model.names = ['object']
return model
class YOLOv5:
"""简化版YOLOv5模型结构用于加载权重"""
def __init__(self):
super(YOLOv5, self).__init__()
self.names = [] # 类别名称
# 这里应该添加真实的网络结构
# 但为了简单起见,我们只提供一个占位符
# 在实际使用中,您应该实现完整的网络架构
def forward(self, x):
# 这里应该是实际的前向传播逻辑
# 这只是一个占位符
raise NotImplementedError("这是一个占位符模型请使用完整的YOLOv5模型实现")
def load_state_dict(self, state_dict):
print("尝试加载模型权重")
# 实际的权重加载逻辑
# 这只是一个占位符
return self