fix(models): 解决 PyTorch 2.6+ 兼容性问题
- 在 garbage_model.py 和 smoke_model.py 中添加 weights_only=False 参数以允许加载模型类结构 - 修复 HTTP YOLO 检测器中的文件上传和响应解析逻辑- 移除不必要的导入并优化代码结构 - 添加自定义字节数组资源类以支持 RestTemplate 文件上传- 改进错误处理和日志记录机制
This commit is contained in:
@@ -45,7 +45,11 @@ class Model:
|
|||||||
# 方法3: 通用 PyTorch 加载
|
# 方法3: 通用 PyTorch 加载
|
||||||
print(f"YOLOv5 加载失败: {e}")
|
print(f"YOLOv5 加载失败: {e}")
|
||||||
print("使用通用 PyTorch 加载")
|
print("使用通用 PyTorch 加载")
|
||||||
self.model = torch.load(model_path, map_location=self.device)
|
self.model = torch.load(
|
||||||
|
model_path,
|
||||||
|
map_location=self.device,
|
||||||
|
weights_only=False # 允许加载模型类结构(解决 PyTorch 2.6+ 兼容性问题)
|
||||||
|
)
|
||||||
if isinstance(self.model, dict) and 'model' in self.model:
|
if isinstance(self.model, dict) and 'model' in self.model:
|
||||||
self.model = self.model['model']
|
self.model = self.model['model']
|
||||||
self.yolov5_api = False
|
self.yolov5_api = False
|
||||||
|
|||||||
@@ -45,7 +45,11 @@ class Model:
|
|||||||
# 方法3: 通用 PyTorch 加载
|
# 方法3: 通用 PyTorch 加载
|
||||||
print(f"YOLOv5 加载失败: {e}")
|
print(f"YOLOv5 加载失败: {e}")
|
||||||
print("使用通用 PyTorch 加载")
|
print("使用通用 PyTorch 加载")
|
||||||
self.model = torch.load(model_path, map_location=self.device)
|
self.model = torch.load(
|
||||||
|
model_path,
|
||||||
|
map_location=self.device,
|
||||||
|
weights_only=False # 允许加载模型类结构,解决 PyTorch 2.6+ 兼容性问题
|
||||||
|
)
|
||||||
if isinstance(self.model, dict) and 'model' in self.model:
|
if isinstance(self.model, dict) and 'model' in self.model:
|
||||||
self.model = self.model['model']
|
self.model = self.model['model']
|
||||||
self.yolov5_api = False
|
self.yolov5_api = False
|
||||||
|
|||||||
@@ -2,12 +2,10 @@ package com.ruoyi.video.thread.detector;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.ruoyi.video.domain.Detection;
|
import com.ruoyi.video.domain.Detection;
|
||||||
import org.bytedeco.opencv.opencv_core.Mat;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
import org.bytedeco.opencv.opencv_core.Rect;
|
|
||||||
import org.bytedeco.javacpp.BytePointer;
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.core.io.ByteArrayResource;
|
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
@@ -15,12 +13,11 @@ import org.springframework.http.ResponseEntity;
|
|||||||
import org.springframework.util.LinkedMultiValueMap;
|
import org.springframework.util.LinkedMultiValueMap;
|
||||||
import org.springframework.util.MultiValueMap;
|
import org.springframework.util.MultiValueMap;
|
||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
|
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.bytedeco.opencv.global.opencv_imgcodecs.imencode;
|
import static org.bytedeco.opencv.global.opencv_imgcodecs.imencode;
|
||||||
|
|
||||||
@@ -56,9 +53,11 @@ public class HttpYoloDetector implements YoloDetector {
|
|||||||
log.info("创建HTTP YOLOv8检测器: {}, 服务地址: {}, 模型: {}", name, apiUrl, modelName);
|
log.info("创建HTTP YOLOv8检测器: {}, 服务地址: {}, 模型: {}", name, apiUrl, modelName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public String name() {
|
public String name() {
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Detection> detect(Mat bgr) {
|
public List<Detection> detect(Mat bgr) {
|
||||||
if (bgr == null || bgr.empty()) {
|
if (bgr == null || bgr.empty()) {
|
||||||
@@ -73,18 +72,13 @@ public class HttpYoloDetector implements YoloDetector {
|
|||||||
buffer.get(jpgBytes);
|
buffer.get(jpgBytes);
|
||||||
buffer.deallocate();
|
buffer.deallocate();
|
||||||
|
|
||||||
// 准备HTTP请求(multipart)
|
// 准备HTTP请求参数
|
||||||
HttpHeaders headers = new HttpHeaders();
|
HttpHeaders headers = new HttpHeaders();
|
||||||
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
|
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
|
||||||
|
|
||||||
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
|
||||||
ByteArrayResource fileRes = new ByteArrayResource(jpgBytes) {
|
// 仅发送文件,model_name 放到查询参数
|
||||||
@Override
|
body.add("file", new CustomByteArrayResource(jpgBytes, "image.jpg"));
|
||||||
public String getFilename() {
|
|
||||||
return "image.jpg";
|
|
||||||
}
|
|
||||||
};
|
|
||||||
body.add("file", fileRes);
|
|
||||||
|
|
||||||
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
|
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
|
||||||
|
|
||||||
@@ -97,36 +91,98 @@ public class HttpYoloDetector implements YoloDetector {
|
|||||||
urlWithQuery = apiUrl + (apiUrl.contains("?") ? "&" : "?") + "model_name=" + modelName;
|
urlWithQuery = apiUrl + (apiUrl.contains("?") ? "&" : "?") + "model_name=" + modelName;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行请求
|
// 发送请求到Python服务
|
||||||
ResponseEntity<String> response = restTemplate.postForEntity(urlWithQuery, requestEntity, String.class);
|
ResponseEntity<String> response = restTemplate.postForEntity(urlWithQuery, requestEntity, String.class);
|
||||||
String responseBody = response.getBody();
|
String responseBody = response.getBody();
|
||||||
if (responseBody == null || responseBody.isEmpty()) {
|
if (!response.getStatusCode().is2xxSuccessful()) {
|
||||||
|
log.error("HTTP检测失败: status={}, body={}", response.getStatusCodeValue(), responseBody);
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析响应
|
if (responseBody != null) {
|
||||||
|
// 解析响应JSON
|
||||||
Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
|
Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
|
||||||
Object detsObj = result.get("detections");
|
List<Map<String, Object>> detectionsJson = (List<Map<String, Object>>) result.get("detections");
|
||||||
if (!(detsObj instanceof List)) {
|
|
||||||
return Collections.emptyList();
|
|
||||||
}
|
|
||||||
|
|
||||||
List<Map<String, Object>> detectionsJson = (List<Map<String, Object>>) detsObj;
|
|
||||||
List<Detection> detections = new ArrayList<>();
|
List<Detection> detections = new ArrayList<>();
|
||||||
for (Map<String, Object> det : detectionsJson) {
|
for (Map<String, Object> det : detectionsJson) {
|
||||||
String label = (String) det.getOrDefault("label", "");
|
String label = (String) det.get("label");
|
||||||
double confidence = det.get("confidence") == null ? 0.0 : ((Number) det.get("confidence")).doubleValue();
|
double confidence = ((Number) det.get("confidence")).doubleValue();
|
||||||
int x = det.get("x") == null ? 0 : ((Number) det.get("x")).intValue();
|
int x = ((Number) det.get("x")).intValue();
|
||||||
int y = det.get("y") == null ? 0 : ((Number) det.get("y")).intValue();
|
int y = ((Number) det.get("y")).intValue();
|
||||||
int width = det.get("width") == null ? 0 : ((Number) det.get("width")).intValue();
|
int width = ((Number) det.get("width")).intValue();
|
||||||
int height = det.get("height") == null ? 0 : ((Number) det.get("height")).intValue();
|
int height = ((Number) det.get("height")).intValue();
|
||||||
|
|
||||||
detections.add(new Detection(label, confidence, new Rect(x, y, width, height), colorBGR));
|
detections.add(new Detection(label, confidence, new Rect(x, y, width, height), colorBGR));
|
||||||
}
|
}
|
||||||
|
|
||||||
return detections;
|
return detections;
|
||||||
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("HTTP检测请求失败: {}", e.getMessage());
|
log.error("HTTP检测请求失败: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 用于RestTemplate的字节数组资源类
|
||||||
|
private static class CustomByteArrayResource implements org.springframework.core.io.Resource {
|
||||||
|
private final byte[] byteArray;
|
||||||
|
private final String filename;
|
||||||
|
|
||||||
|
public CustomByteArrayResource(byte[] byteArray, String filename) {
|
||||||
|
this.byteArray = byteArray;
|
||||||
|
this.filename = filename;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getFilename() {
|
||||||
|
return this.filename;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public java.io.InputStream getInputStream() throws IOException {
|
||||||
|
return new java.io.ByteArrayInputStream(this.byteArray);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean exists() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public java.net.URL getURL() throws IOException {
|
||||||
|
throw new IOException("Not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public java.net.URI getURI() throws IOException {
|
||||||
|
throw new IOException("Not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public java.io.File getFile() throws IOException {
|
||||||
|
throw new IOException("Not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long contentLength() {
|
||||||
|
return this.byteArray.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long lastModified() {
|
||||||
|
return System.currentTimeMillis();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public org.springframework.core.io.Resource createRelative(String relativePath) throws IOException {
|
||||||
|
throw new IOException("Not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getDescription() {
|
||||||
|
return "Byte array resource [" + this.filename + "]";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user