56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
def attempt_load(weights, device=''):
|
|||
|
|
"""尝试加载YOLOv5模型"""
|
|||
|
|
# 加载模型
|
|||
|
|
model = torch.load(weights, map_location=device)
|
|||
|
|
|
|||
|
|
# 确定模型格式
|
|||
|
|
if isinstance(model, dict):
|
|||
|
|
if 'model' in model: # state_dict格式
|
|||
|
|
model = model['model']
|
|||
|
|
elif 'state_dict' in model: # state_dict格式
|
|||
|
|
model = model['state_dict']
|
|||
|
|
|
|||
|
|
# 如果是state_dict,则需要创建模型架构
|
|||
|
|
if isinstance(model, dict):
|
|||
|
|
print("警告:加载的是权重字典,尝试创建默认模型结构")
|
|||
|
|
from models.yolov5_model import YOLOv5
|
|||
|
|
model_arch = YOLOv5()
|
|||
|
|
model_arch.load_state_dict(model)
|
|||
|
|
model = model_arch
|
|||
|
|
|
|||
|
|
# 设置为评估模式
|
|||
|
|
if isinstance(model, nn.Module):
|
|||
|
|
model.eval()
|
|||
|
|
|
|||
|
|
# 检查是否有类别信息
|
|||
|
|
if not hasattr(model, 'names') or not model.names:
|
|||
|
|
print("模型没有类别信息,尝试加载默认类别")
|
|||
|
|
# 设置通用类别
|
|||
|
|
model.names = ['object']
|
|||
|
|
|
|||
|
|
return model
|
|||
|
|
|
|||
|
|
class YOLOv5:
|
|||
|
|
"""简化版YOLOv5模型结构,用于加载权重"""
|
|||
|
|
def __init__(self):
|
|||
|
|
super(YOLOv5, self).__init__()
|
|||
|
|
self.names = [] # 类别名称
|
|||
|
|
# 这里应该添加真实的网络结构
|
|||
|
|
# 但为了简单起见,我们只提供一个占位符
|
|||
|
|
# 在实际使用中,您应该实现完整的网络架构
|
|||
|
|
|
|||
|
|
def forward(self, x):
|
|||
|
|
# 这里应该是实际的前向传播逻辑
|
|||
|
|
# 这只是一个占位符
|
|||
|
|
raise NotImplementedError("这是一个占位符模型,请使用完整的YOLOv5模型实现")
|
|||
|
|
|
|||
|
|
def load_state_dict(self, state_dict):
|
|||
|
|
print("尝试加载模型权重")
|
|||
|
|
# 实际的权重加载逻辑
|
|||
|
|
# 这只是一个占位符
|
|||
|
|
return self
|