refactor(detector):重构HTTP YOLO检测器实现

- 使用ByteArrayResource替代自定义资源类
- 将model_name参数移至URL查询参数
-优化响应解析逻辑,增强类型检查
- 改进错误处理和空值判断
- 清理无用的导入和代码格式化- 修复潜在的编码异常处理问题
This commit is contained in:
2025-10-07 16:57:03 +08:00
parent e3701991ef
commit 5f6058c024

View File

@@ -2,10 +2,12 @@ package com.ruoyi.video.thread.detector;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.ruoyi.video.domain.Detection;
import org.bytedeco.opencv.opencv_core.*;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Rect;
import org.bytedeco.javacpp.BytePointer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
@@ -13,11 +15,12 @@ import org.springframework.http.ResponseEntity;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.bytedeco.opencv.global.opencv_imgcodecs.imencode;
@@ -27,14 +30,14 @@ import static org.bytedeco.opencv.global.opencv_imgcodecs.imencode;
*/
public class HttpYoloDetector implements YoloDetector {
private static final Logger log = LoggerFactory.getLogger(HttpYoloDetector.class);
private final String name;
private final String apiUrl;
private final String modelName;
private final int colorBGR;
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;
/**
* 创建HTTP检测器
* @param name 检测器名称
@@ -49,127 +52,81 @@ public class HttpYoloDetector implements YoloDetector {
this.colorBGR = colorBGR;
this.restTemplate = new RestTemplate();
this.objectMapper = new ObjectMapper();
log.info("创建HTTP YOLOv8检测器: {}, 服务地址: {}, 模型: {}", name, apiUrl, modelName);
}
@Override
public String name() {
return name;
}
@Override
public List<Detection> detect(Mat bgr) {
if (bgr == null || bgr.empty()) {
return Collections.emptyList();
}
try {
// 将OpenCV的Mat转换为JPEG字节数组
BytePointer buffer = new BytePointer();
imencode(".jpg", bgr, buffer);
byte[] jpgBytes = new byte[(int)(buffer.capacity())];
byte[] jpgBytes = new byte[(int) (buffer.capacity())];
buffer.get(jpgBytes);
buffer.deallocate();
// 准备HTTP请求参数
// 准备HTTP请求multipart
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
body.add("model_name", modelName);
body.add("file", new CustomByteArrayResource(jpgBytes, "image.jpg"));
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
// 发送请求到Python服务
ResponseEntity<String> response = restTemplate.postForEntity(apiUrl, requestEntity, String.class);
String responseBody = response.getBody();
if (responseBody != null) {
// 解析响应JSON
Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
List<Map<String, Object>> detectionsJson = (List<Map<String, Object>>) result.get("detections");
List<Detection> detections = new ArrayList<>();
for (Map<String, Object> det : detectionsJson) {
String label = (String) det.get("label");
double confidence = ((Number) det.get("confidence")).doubleValue();
int x = ((Number) det.get("x")).intValue();
int y = ((Number) det.get("y")).intValue();
int width = ((Number) det.get("width")).intValue();
int height = ((Number) det.get("height")).intValue();
detections.add(new Detection(label, confidence, new Rect(x, y, width, height), colorBGR));
ByteArrayResource fileRes = new ByteArrayResource(jpgBytes) {
@Override
public String getFilename() {
return "image.jpg";
}
return detections;
};
body.add("file", fileRes);
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
// 将 model_name 作为查询参数
String urlWithQuery;
try {
String encoded = java.net.URLEncoder.encode(modelName, java.nio.charset.StandardCharsets.UTF_8.toString());
urlWithQuery = apiUrl + (apiUrl.contains("?") ? "&" : "?") + "model_name=" + encoded;
} catch (Exception ex) {
urlWithQuery = apiUrl + (apiUrl.contains("?") ? "&" : "?") + "model_name=" + modelName;
}
// 执行请求
ResponseEntity<String> response = restTemplate.postForEntity(urlWithQuery, requestEntity, String.class);
String responseBody = response.getBody();
if (responseBody == null || responseBody.isEmpty()) {
return Collections.emptyList();
}
// 解析响应
Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
Object detsObj = result.get("detections");
if (!(detsObj instanceof List)) {
return Collections.emptyList();
}
List<Map<String, Object>> detectionsJson = (List<Map<String, Object>>) detsObj;
List<Detection> detections = new ArrayList<>();
for (Map<String, Object> det : detectionsJson) {
String label = (String) det.getOrDefault("label", "");
double confidence = det.get("confidence") == null ? 0.0 : ((Number) det.get("confidence")).doubleValue();
int x = det.get("x") == null ? 0 : ((Number) det.get("x")).intValue();
int y = det.get("y") == null ? 0 : ((Number) det.get("y")).intValue();
int width = det.get("width") == null ? 0 : ((Number) det.get("width")).intValue();
int height = det.get("height") == null ? 0 : ((Number) det.get("height")).intValue();
detections.add(new Detection(label, confidence, new Rect(x, y, width, height), colorBGR));
}
return detections;
} catch (Exception e) {
log.error("HTTP检测请求失败: {}", e.getMessage());
}
return Collections.emptyList();
}
// 用于RestTemplate的字节数组资源类
private static class CustomByteArrayResource implements org.springframework.core.io.Resource {
private final byte[] byteArray;
private final String filename;
public CustomByteArrayResource(byte[] byteArray, String filename) {
this.byteArray = byteArray;
this.filename = filename;
}
@Override
public String getFilename() {
return this.filename;
}
@Override
public java.io.InputStream getInputStream() throws IOException {
return new java.io.ByteArrayInputStream(this.byteArray);
}
@Override
public boolean exists() {
return true;
}
@Override
public java.net.URL getURL() throws IOException {
throw new IOException("Not supported");
}
@Override
public java.net.URI getURI() throws IOException {
throw new IOException("Not supported");
}
@Override
public java.io.File getFile() throws IOException {
throw new IOException("Not supported");
}
@Override
public long contentLength() {
return this.byteArray.length;
}
@Override
public long lastModified() {
return System.currentTimeMillis();
}
@Override
public org.springframework.core.io.Resource createRelative(String relativePath) throws IOException {
throw new IOException("Not supported");
}
@Override
public String getDescription() {
return "Byte array resource [" + this.filename + "]";
return Collections.emptyList();
}
}
}