From aa32f9e9ac9c4dd44e1eb9f5bb67854f4e546677 Mon Sep 17 00:00:00 2001 From: Joshi <3040996759@qq.com> Date: Tue, 7 Oct 2025 17:53:34 +0800 Subject: [PATCH] =?UTF-8?q?fix(models):=20=E8=A7=A3=E5=86=B3=20PyTorch=202?= =?UTF-8?q?.6+=20=E5=85=BC=E5=AE=B9=E6=80=A7=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 garbage_model.py 和 smoke_model.py 中添加 weights_only=False 参数以允许加载模型类结构 - 修复 HTTP YOLO 检测器中的文件上传和响应解析逻辑- 移除不必要的导入并优化代码结构 - 添加自定义字节数组资源类以支持 RestTemplate 文件上传- 改进错误处理和日志记录机制 --- .../models/garbage_model.py | 6 +- .../models/smoke_model.py | 6 +- .../thread/detector/HttpYoloDetector.java | 130 +++++++++++++----- 3 files changed, 103 insertions(+), 39 deletions(-) diff --git a/python-inference-service/models/garbage_model.py b/python-inference-service/models/garbage_model.py index 1c7f6ea..b53dc81 100644 --- a/python-inference-service/models/garbage_model.py +++ b/python-inference-service/models/garbage_model.py @@ -45,7 +45,11 @@ class Model: # 方法3: 通用 PyTorch 加载 print(f"YOLOv5 加载失败: {e}") print("使用通用 PyTorch 加载") - self.model = torch.load(model_path, map_location=self.device) + self.model = torch.load( + model_path, + map_location=self.device, + weights_only=False # 允许加载模型类结构(解决 PyTorch 2.6+ 兼容性问题) + ) if isinstance(self.model, dict) and 'model' in self.model: self.model = self.model['model'] self.yolov5_api = False diff --git a/python-inference-service/models/smoke_model.py b/python-inference-service/models/smoke_model.py index 9ab6639..a61db2a 100644 --- a/python-inference-service/models/smoke_model.py +++ b/python-inference-service/models/smoke_model.py @@ -45,7 +45,11 @@ class Model: # 方法3: 通用 PyTorch 加载 print(f"YOLOv5 加载失败: {e}") print("使用通用 PyTorch 加载") - self.model = torch.load(model_path, map_location=self.device) + self.model = torch.load( + model_path, + map_location=self.device, + weights_only=False # 允许加载模型类结构,解决 PyTorch 2.6+ 兼容性问题 + ) if isinstance(self.model, dict) and 'model' in self.model: self.model = self.model['model'] self.yolov5_api = False diff --git a/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/HttpYoloDetector.java b/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/HttpYoloDetector.java index 4faa059..c3d49bf 100644 --- a/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/HttpYoloDetector.java +++ b/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/HttpYoloDetector.java @@ -2,12 +2,10 @@ package com.ruoyi.video.thread.detector; import com.fasterxml.jackson.databind.ObjectMapper; import com.ruoyi.video.domain.Detection; -import org.bytedeco.opencv.opencv_core.Mat; -import org.bytedeco.opencv.opencv_core.Rect; +import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.javacpp.BytePointer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.core.io.ByteArrayResource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -15,12 +13,11 @@ import org.springframework.http.ResponseEntity; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.RestTemplate; +import org.springframework.web.multipart.MultipartFile; +import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; import static org.bytedeco.opencv.global.opencv_imgcodecs.imencode; @@ -56,9 +53,11 @@ public class HttpYoloDetector implements YoloDetector { log.info("创建HTTP YOLOv8检测器: {}, 服务地址: {}, 模型: {}", name, apiUrl, modelName); } + @Override public String name() { return name; } + @Override public List detect(Mat bgr) { if (bgr == null || bgr.empty()) { @@ -69,22 +68,17 @@ public class HttpYoloDetector implements YoloDetector { // 将OpenCV的Mat转换为JPEG字节数组 BytePointer buffer = new BytePointer(); imencode(".jpg", bgr, buffer); - byte[] jpgBytes = new byte[(int) (buffer.capacity())]; + byte[] jpgBytes = new byte[(int)(buffer.capacity())]; buffer.get(jpgBytes); buffer.deallocate(); - // 准备HTTP请求(multipart) + // 准备HTTP请求参数 HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.MULTIPART_FORM_DATA); MultiValueMap body = new LinkedMultiValueMap<>(); - ByteArrayResource fileRes = new ByteArrayResource(jpgBytes) { - @Override - public String getFilename() { - return "image.jpg"; - } - }; - body.add("file", fileRes); + // 仅发送文件,model_name 放到查询参数 + body.add("file", new CustomByteArrayResource(jpgBytes, "image.jpg")); HttpEntity> requestEntity = new HttpEntity<>(body, headers); @@ -97,36 +91,98 @@ public class HttpYoloDetector implements YoloDetector { urlWithQuery = apiUrl + (apiUrl.contains("?") ? "&" : "?") + "model_name=" + modelName; } - // 执行请求 + // 发送请求到Python服务 ResponseEntity response = restTemplate.postForEntity(urlWithQuery, requestEntity, String.class); String responseBody = response.getBody(); - if (responseBody == null || responseBody.isEmpty()) { + if (!response.getStatusCode().is2xxSuccessful()) { + log.error("HTTP检测失败: status={}, body={}", response.getStatusCodeValue(), responseBody); return Collections.emptyList(); } - // 解析响应 - Map result = objectMapper.readValue(responseBody, Map.class); - Object detsObj = result.get("detections"); - if (!(detsObj instanceof List)) { - return Collections.emptyList(); - } + if (responseBody != null) { + // 解析响应JSON + Map result = objectMapper.readValue(responseBody, Map.class); + List> detectionsJson = (List>) result.get("detections"); - List> detectionsJson = (List>) detsObj; - List detections = new ArrayList<>(); - for (Map det : detectionsJson) { - String label = (String) det.getOrDefault("label", ""); - double confidence = det.get("confidence") == null ? 0.0 : ((Number) det.get("confidence")).doubleValue(); - int x = det.get("x") == null ? 0 : ((Number) det.get("x")).intValue(); - int y = det.get("y") == null ? 0 : ((Number) det.get("y")).intValue(); - int width = det.get("width") == null ? 0 : ((Number) det.get("width")).intValue(); - int height = det.get("height") == null ? 0 : ((Number) det.get("height")).intValue(); - detections.add(new Detection(label, confidence, new Rect(x, y, width, height), colorBGR)); - } + List detections = new ArrayList<>(); + for (Map det : detectionsJson) { + String label = (String) det.get("label"); + double confidence = ((Number) det.get("confidence")).doubleValue(); + int x = ((Number) det.get("x")).intValue(); + int y = ((Number) det.get("y")).intValue(); + int width = ((Number) det.get("width")).intValue(); + int height = ((Number) det.get("height")).intValue(); - return detections; + detections.add(new Detection(label, confidence, new Rect(x, y, width, height), colorBGR)); + } + + return detections; + } } catch (Exception e) { log.error("HTTP检测请求失败: {}", e.getMessage()); - return Collections.emptyList(); + } + + return Collections.emptyList(); + } + + // 用于RestTemplate的字节数组资源类 + private static class CustomByteArrayResource implements org.springframework.core.io.Resource { + private final byte[] byteArray; + private final String filename; + + public CustomByteArrayResource(byte[] byteArray, String filename) { + this.byteArray = byteArray; + this.filename = filename; + } + + @Override + public String getFilename() { + return this.filename; + } + + @Override + public java.io.InputStream getInputStream() throws IOException { + return new java.io.ByteArrayInputStream(this.byteArray); + } + + @Override + public boolean exists() { + return true; + } + + @Override + public java.net.URL getURL() throws IOException { + throw new IOException("Not supported"); + } + + @Override + public java.net.URI getURI() throws IOException { + throw new IOException("Not supported"); + } + + @Override + public java.io.File getFile() throws IOException { + throw new IOException("Not supported"); + } + + @Override + public long contentLength() { + return this.byteArray.length; + } + + @Override + public long lastModified() { + return System.currentTimeMillis(); + } + + @Override + public org.springframework.core.io.Resource createRelative(String relativePath) throws IOException { + throw new IOException("Not supported"); + } + + @Override + public String getDescription() { + return "Byte array resource [" + this.filename + "]"; } } } \ No newline at end of file