diff --git a/ruoyi-video/src/main/java/com/ruoyi/video/common/ModelManager.java b/ruoyi-video/src/main/java/com/ruoyi/video/common/ModelManager.java
index a9cc59a..54cc4cd 100644
--- a/ruoyi-video/src/main/java/com/ruoyi/video/common/ModelManager.java
+++ b/ruoyi-video/src/main/java/com/ruoyi/video/common/ModelManager.java
@@ -32,8 +32,8 @@ public final class ModelManager implements AutoCloseable {
                 int rgb = palette[i % palette.length];
                 i++;
                 int bgr = ((rgb & 0xFF) << 16) | (rgb & 0xFF00) | ((rgb >> 16) & 0xFF);
-                // 使用OpenVinoYoloDetector,但强制使用OpenCV后端
-                YoloDetector det = new OpenVinoYoloDetector(name, dir, w, h, "opencv", bgr);
+                // 使用OnnxYoloDetector替代OpenVinoYoloDetector
+                YoloDetector det = new OnnxYoloDetector(name, dir, w, h, backend, bgr);
                 map.put(name, det);
             }
         }
@@ -45,4 +45,4 @@ public final class ModelManager implements AutoCloseable {
         map.values().forEach(d -> { try { d.close(); } catch(Exception ignored){} });
         map.clear();
     }
-}
+}
\ No newline at end of file
diff --git a/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/OnnxYoloDetector.java b/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/OnnxYoloDetector.java
new file mode 100644
index 0000000..44156ed
--- /dev/null
+++ b/ruoyi-video/src/main/java/com/ruoyi/video/thread/detector/OnnxYoloDetector.java
@@ -0,0 +1,324 @@
+package com.ruoyi.video.thread.detector;
+
+import com.ruoyi.video.domain.Detection;
+import org.bytedeco.javacpp.indexer.FloatRawIndexer;
+import org.bytedeco.opencv.opencv_core.*;
+import org.bytedeco.opencv.opencv_dnn.Net;
+
+import java.io.FileNotFoundException;
+import java.nio.file.*;
+import java.util.*;
+import java.util.stream.Stream;
+
+import static org.bytedeco.opencv.global.opencv_dnn.*;
+import static org.bytedeco.opencv.global.opencv_core.*;
+import static org.bytedeco.opencv.global.opencv_imgproc.*;
+
+/**
+ * YOLO detector that loads an ONNX model and runs it through the OpenCV DNN module on the CPU.
+ *
+ * <p>Every output row is read as {@code cx, cy, w, h, objectness, class scores...}.
+ * NOTE(review): {@link #detect(Mat)} changes the state of the shared {@link Net} and this class
+ * does no locking; presumably callers use one thread per instance - confirm.
+ */
+public final class OnnxYoloDetector implements YoloDetector {
+    /** The only inference backend this detector implements. */
+    private static final String SUPPORTED_BACKEND = "opencv";
+    /** Fewest columns a usable row can have: cx, cy, w, h, objectness and one class score. */
+    private static final int MIN_ATTRS = 6;
+    /** Box values above this cannot be fractions of the frame; they are read as input pixels. */
+    private static final float NORMALIZED_LIMIT = 2f;
+
+    private final String modelName;
+    private final Net net;
+    private final Size input;
+    private final float confTh = 0.25f, nmsTh = 0.45f;
+    private final String[] classes;
+    private final int colorBGR;
+    /** Set after the first failed frame so that a broken model does not flood the log. */
+    private volatile boolean failureLogged;
+
+    /**
+     * Loads the first {@code .onnx} file found in {@code dir} and the optional {@code classes.txt}.
+     *
+     * @param name     model name, used as the label prefix of every detection
+     * @param dir      directory holding the model file and, optionally, {@code classes.txt}
+     * @param inW      network input width in pixels
+     * @param inH      network input height in pixels
+     * @param backend  requested backend; only {@code "opencv"} (any case) is implemented, any
+     *                 other value is reported and the OpenCV CPU backend is used anyway
+     * @param colorBGR packed BGR color attached to every detection
+     * @throws Exception if no model file is found, {@code classes.txt} is unreadable or loading fails
+     */
+    public OnnxYoloDetector(String name, Path dir, int inW, int inH, String backend, int colorBGR) throws Exception {
+        this.modelName = name;
+        this.colorBGR = colorBGR;
+
+        String onnx = findModelFile(dir, ".onnx");
+        if (onnx == null) {
+            throw new FileNotFoundException("找不到ONNX模型文件,请确保目录中存在 .onnx 文件: " + dir);
+        }
+
+        Path clsPath = dir.resolve("classes.txt");
+        if (Files.exists(clsPath)) {
+            this.classes = Files.readAllLines(clsPath).stream().map(String::trim)
+                    .filter(s -> !s.isEmpty()).toArray(String[]::new);
+        } else {
+            this.classes = new String[0];
+        }
+
+        // The backend argument was ignored without notice; report the fallback instead.
+        if (backend != null && !SUPPORTED_BACKEND.equalsIgnoreCase(backend.trim())) {
+            System.err.println("不支持的推理后端 \"" + backend + "\",模型 " + name + " 将使用 OpenCV CPU 后端");
+        }
+
+        Net loaded = null;
+        try {
+            loaded = readNetFromONNX(onnx);
+            loaded.setPreferableBackend(DNN_BACKEND_OPENCV);
+            loaded.setPreferableTarget(DNN_TARGET_CPU);
+        } catch (Exception e) {
+            // Free the half-configured native network before reporting the failure.
+            if (loaded != null) loaded.close();
+            throw new IllegalStateException("模型加载失败: " + e.getMessage()
+                    + "\n请确保ONNX模型文件格式正确", e);
+        }
+        this.net = loaded;
+        // Allocated last so that no native object is left behind when loading fails.
+        this.input = new Size(inW, inH);
+        System.out.println("ONNX模型加载成功: " + name + " (" + onnx + ")");
+    }
+
+    /**
+     * Returns the path of the first entry in {@code dir} whose name ends with {@code extension}
+     * (ignoring case), or {@code null} when there is none or the directory cannot be listed.
+     */
+    private String findModelFile(Path dir, String extension) {
+        String suffix = extension.toLowerCase();
+        // Files.list holds the directory open until the stream is closed.
+        try (Stream<Path> entries = Files.list(dir)) {
+            return entries.map(Path::toString)
+                    .filter(p -> p.toLowerCase().endsWith(suffix))
+                    .findFirst()
+                    .orElse(null);
+        } catch (Exception e) {
+            // The caller reports an unreadable directory as "model not found".
+            return null;
+        }
+    }
+
+    @Override public String name() { return modelName; }
+
+    /**
+     * Runs the network on one frame and returns the boxes that survive NMS.
+     *
+     * @param bgr frame to analyse; {@code null} or an empty frame yields an empty list
+     * @return detections in frame pixel coordinates; empty when nothing is found or inference fails
+     */
+    @Override
+    public List<Detection> detect(Mat bgr) {
+        if (bgr == null || bgr.empty()) return Collections.emptyList();
+
+        // Convert to 3-channel BGR first so that blobFromImage does not hit an assertion.
+        Mat converted = null;
+        if (bgr.channels() != 3) {
+            converted = new Mat();
+            if (bgr.channels() == 1) cvtColor(bgr, converted, COLOR_GRAY2BGR);
+            else if (bgr.channels() == 4) cvtColor(bgr, converted, COLOR_BGRA2BGR);
+            else bgr.copyTo(converted);
+        }
+
+        try {
+            return infer(converted != null ? converted : bgr);
+        } catch (RuntimeException e) {
+            // One bad frame must not stop the stream, but the first failure is reported.
+            if (!failureLogged) {
+                failureLogged = true;
+                System.err.println("[" + modelName + "] 检测失败,已跳过该帧: " + e);
+            }
+            return Collections.emptyList();
+        } finally {
+            // Release the per-frame native copy now instead of waiting for garbage collection.
+            if (converted != null) converted.close();
+        }
+    }
+
+    /** Feeds one 3-channel frame through the network and turns the raw rows into detections. */
+    private List<Detection> infer(Mat frame) {
+        int fw = frame.cols(), fh = frame.rows();
+        List<int[]> boxes = new ArrayList<>();
+        List<Float> scores = new ArrayList<>();
+        List<Integer> classIds = new ArrayList<>();
+
+        try (Scalar mean = new Scalar(0.0);
+             Mat blob = blobFromImage(frame, 1.0 / 255.0, input, mean, true, false, CV_32F);
+             StringVector outNames = net.getUnconnectedOutLayersNames()) {
+            net.setInput(blob);
+            if (outNames == null || outNames.size() == 0) {
+                // Only the default output exists.
+                try (Mat out = net.forward()) {
+                    parseYoloOutput(out, fw, fh, boxes, scores, classIds);
+                }
+            } else {
+                // Several output layers: receive them through a MatVector.
+                try (MatVector outBlobs = new MatVector(outNames.size())) {
+                    net.forward(outBlobs, outNames);
+                    for (long i = 0; i < outBlobs.size(); i++) {
+                        parseYoloOutput(outBlobs.get(i), fw, fh, boxes, scores, classIds);
+                    }
+                }
+            }
+        }
+        if (boxes.isEmpty()) return Collections.emptyList();
+
+        // NMS is done in plain Java to avoid the MatOf* / Vector API compatibility problems.
+        List<Integer> keep = nmsIndices(boxes, scores, nmsTh);
+        List<Detection> result = new ArrayList<>(keep.size());
+        for (int k : keep) {
+            int[] b = boxes.get(k);
+            Rect rect = new Rect(b[0], b[1], b[2], b[3]);
+            int cid = classIds.get(k);
+            String cname = (cid >= 0 && cid < classes.length) ? classes[cid] : ("cls" + cid);
+            result.add(new Detection("[" + modelName + "] " + cname, scores.get(k), rect, colorBGR));
+        }
+        return result;
+    }
+
+    /**
+     * Brings one raw output blob into one-row-per-candidate form and appends every candidate
+     * that reaches the confidence threshold to {@code boxes}, {@code scores} and {@code classIds}.
+     *
+     * <p>Handled shapes are {@code [N,C]}, {@code [1,N,C]}, {@code [1,1,N,C]} and their
+     * transposed forms; any other shape is ignored.
+     */
+    private void parseYoloOutput(Mat out, int fw, int fh,
+                                 List<int[]> boxes, List<Float> scores, List<Integer> classIds) {
+        int dims = out.dims();
+        if (dims < 2 || dims > 4) return; // unsupported shape
+        for (int d = 0; d < dims - 2; d++) {
+            if (out.size(d) != 1) return; // batched output is not handled
+        }
+        int rows = out.size(dims - 2), cols = out.size(dims - 1);
+        if (rows <= 0 || cols <= 0) return;
+
+        // Testing only "last dimension >= 6" also matches [1,C,N] blobs, because N is far above 6.
+        // A blob counts as attribute-major when its last dimension is too short to be a row, or
+        // when it has fewer rows than columns. Blobs with fewer than MIN_ATTRS rows are left as
+        // they are, since that is just a short candidate list.
+        // NOTE(review): assumes a raw head emits more candidates than attributes - confirm.
+        boolean attributeMajor = cols < MIN_ATTRS || (rows >= MIN_ATTRS && rows < cols);
+
+        Mat flat = null;
+        Mat transposed = null;
+        try {
+            Mat m = out;
+            if (dims > 2) {
+                flat = out.reshape(1, rows); // drop the leading dimensions of size 1
+                m = flat;
+            }
+            if (attributeMajor) {
+                transposed = new Mat();
+                transpose(m, transposed);
+                m = transposed;
+            }
+            collectCandidates(m, fw, fh, boxes, scores, classIds);
+        } finally {
+            if (transposed != null) transposed.close();
+            if (flat != null) flat.close();
+        }
+    }
+
+    /** Reads an N x C float matrix row by row and appends the rows that reach the threshold. */
+    private void collectCandidates(Mat m, int fw, int fh,
+                                   List<int[]> boxes, List<Float> scores, List<Integer> classIds) {
+        int n = m.rows(), c = m.cols();
+        if (c < MIN_ATTRS || n <= 0) return;
+
+        FloatRawIndexer idx = m.createIndexer();
+        try {
+            for (int i = 0; i < n; i++) {
+                float obj = idx.get(i, 4);
+                int bestCls = -1;
+                float bestScore = 0f;
+                for (int k = 5; k < c; k++) {
+                    float p = idx.get(i, k);
+                    if (p > bestScore) { bestScore = p; bestCls = k - 5; }
+                }
+                float conf = obj * bestScore;
+                if (conf < confTh) continue;
+
+                int[] box = toFrameBox(idx.get(i, 0), idx.get(i, 1), idx.get(i, 2), idx.get(i, 3), fw, fh);
+                if (box == null) continue;
+                boxes.add(box);
+                scores.add(conf);
+                classIds.add(bestCls);
+            }
+        } finally {
+            idx.release();
+        }
+    }
+
+    /**
+     * Converts a center-format box into a clamped {@code {x, y, width, height}} in frame pixels,
+     * or returns {@code null} when no part of it lies inside the frame.
+     */
+    private int[] toFrameBox(float cx, float cy, float w, float h, int fw, int fh) {
+        // Normalized boxes scale by the frame size. Larger values are taken as pixels of the
+        // network input and scale by frame/input, since the frame is resized to that input size.
+        // NOTE(review): assumes exporters emit either of these two units - confirm per model.
+        boolean inputPixels = Math.max(Math.max(cx, cy), Math.max(w, h)) > NORMALIZED_LIMIT;
+        float sx = inputPixels ? (float) fw / input.width() : fw;
+        float sy = inputPixels ? (float) fh / input.height() : fh;
+
+        // Clamp both corners; clamping only the origin keeps the full width and shifts the box.
+        int x1 = clamp(Math.round((cx - w / 2f) * sx), fw);
+        int y1 = clamp(Math.round((cy - h / 2f) * sy), fh);
+        int x2 = clamp(Math.round((cx + w / 2f) * sx), fw);
+        int y2 = clamp(Math.round((cy + h / 2f) * sy), fh);
+        if (x2 <= x1 || y2 <= y1) return null;
+        return new int[]{x1, y1, x2 - x1, y2 - y1};
+    }
+
+    /** Limits {@code v} to the range {@code [0, max]}. */
+    private static int clamp(int v, int max) {
+        return Math.max(0, Math.min(max, v));
+    }
+
+    /** Greedy IoU suppression in plain Java; returns the kept indices, highest score first. */
+    private List<Integer> nmsIndices(List<int[]> boxes, List<Float> scores, float nmsThreshold) {
+        List<Integer> order = new ArrayList<>(boxes.size());
+        for (int i = 0; i < boxes.size(); i++) order.add(i);
+        order.sort((i, j) -> Float.compare(scores.get(j), scores.get(i))); // descending score
+
+        List<Integer> keep = new ArrayList<>();
+        boolean[] removed = new boolean[boxes.size()];
+        for (int a = 0; a < order.size(); a++) {
+            int i = order.get(a);
+            if (removed[i]) continue;
+            keep.add(i);
+
+            int[] bi = boxes.get(i);
+            double areaI = (double) bi[2] * bi[3];
+            for (int b = a + 1; b < order.size(); b++) {
+                int j = order.get(b);
+                if (removed[j]) continue;
+
+                int[] bj = boxes.get(j);
+                double areaJ = (double) bj[2] * bj[3];
+                double w = Math.max(0, Math.min(bi[0] + bi[2], bj[0] + bj[2]) - Math.max(bi[0], bj[0]));
+                double h = Math.max(0, Math.min(bi[1] + bi[3], bj[1] + bj[3]) - Math.max(bi[1], bj[1]));
+                double inter = w * h;
+                double iou = inter / (areaI + areaJ - inter + 1e-9);
+                if (iou > nmsThreshold) removed[j] = true;
+            }
+        }
+        return keep;
+    }
+
+    /** Releases the native network and the input-size object. */
+    @Override
+    public void close() {
+        net.close();
+        input.close();
+    }
+}
diff --git a/ruoyi-video/src/main/resources/libs/models/models.json b/ruoyi-video/src/main/resources/libs/models/models.json
index 14a7b7c..38e5aa4 100644
--- a/ruoyi-video/src/main/resources/libs/models/models.json
+++ b/ruoyi-video/src/main/resources/libs/models/models.json
@@ -1,4 +1,4 @@
 [
-  {"name":"smoke","path":"libs/models/smoke","size":[640,640],"backend":"OpenCV"},
-  {"name":"garbage","path":"libs/models/garbage","size":[640,640],"backend":"OpenCV"}
+  {"name":"smoke","path":"libs/models/smoke","size":[640,640],"backend":"opencv"},
+  {"name":"garbage","path":"libs/models/garbage","size":[640,640],"backend":"opencv"}
 ]