修复bug

This commit is contained in:
2025-09-26 17:29:30 +08:00
parent 456c7f4a01
commit 03d24749ea

View File

@@ -1,15 +1,16 @@
package com.ruoyi.video.thread.detector;
import com.ruoyi.video.domain.Detection;
import org.bytedeco.javacpp.indexer.FloatRawIndexer;
import org.bytedeco.opencv.opencv_core.*;
import org.bytedeco.opencv.opencv_dnn.Net;
import java.nio.file.*;
import java.util.*;
import static org.bytedeco.opencv.global.Dnn.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.bytedeco.opencv.global.opencv_dnn.*; // DNN APIreadNetFromModelOptimizer / blobFromImage 等)
import static org.bytedeco.opencv.global.opencv_core.*; // Mat/Size/Scalar/transpose 等
import static org.bytedeco.opencv.global.opencv_imgproc.*; // cvtColor 等
public final class OpenVinoYoloDetector implements YoloDetector {
private final String modelName;
@@ -36,10 +37,16 @@ public final class OpenVinoYoloDetector implements YoloDetector {
}
this.net = readNetFromModelOptimizer(xml, bin);
boolean set = false;
if ("openvino".equalsIgnoreCase(backend)) {
net.setPreferableBackend(DNN_BACKEND_INFERENCE_ENGINE);
net.setPreferableTarget(DNN_TARGET_CPU);
} else {
try {
net.setPreferableBackend(DNN_BACKEND_INFERENCE_ENGINE);
net.setPreferableTarget(DNN_TARGET_CPU);
set = true;
} catch (Throwable ignore) { /* 回退 */ }
}
if (!set) {
net.setPreferableBackend(DNN_BACKEND_OPENCV);
net.setPreferableTarget(DNN_TARGET_CPU);
}
@@ -49,60 +56,175 @@ public final class OpenVinoYoloDetector implements YoloDetector {
@Override
public List<Detection> detect(Mat bgr) {
if (bgr == null || bgr.empty()) return Collections.emptyList();
// 统一成 BGR 3 通道,避免 blobFromImage 断言失败
if (bgr.channels() != 3) {
Mat tmp = new Mat();
if (bgr.channels() == 1) cvtColor(bgr, tmp, COLOR_GRAY2BGR);
else if (bgr.channels() == 4) cvtColor(bgr, tmp, COLOR_BGRA2BGR);
else bgr.copyTo(tmp);
bgr = tmp;
}
try (Mat blob = blobFromImage(bgr, 1.0/255.0, input, new Scalar(0.0), true, false, CV_32F)) {
net.setInput(blob);
Mat out = new Mat();
net.forward(out); // 常见: [1,N,C] or [N,C]
Mat m = out.reshape(1, (int)out.total() / out.size(2));
// ===== 多输出兼容Bytedeco 正确写法)=====
org.bytedeco.opencv.opencv_core.StringVector outNames = net.getUnconnectedOutLayersNames();
List<Mat> outs = new ArrayList<>();
if (outNames == null || outNames.size() == 0) {
// 只有一个默认输出
Mat out = net.forward(); // ← 直接返回 Mat
outs.add(out);
} else {
// 多输出:用 MatVector 承接
org.bytedeco.opencv.opencv_core.MatVector outBlobs =
new org.bytedeco.opencv.opencv_core.MatVector(outNames.size());
net.forward(outBlobs, outNames); // ← 正确的重载
for (long i = 0; i < outBlobs.size(); i++) {
outs.add(outBlobs.get(i));
}
}
int fw = bgr.cols(), fh = bgr.rows();
FloatRawIndexer idx = m.createIndexer();
int N = m.rows(), C = m.cols();
List<Rect2d> boxes = new ArrayList<>();
List<Float> scores = new ArrayList<>();
List<Integer> classIds = new ArrayList<>();
for (int i = 0; i < N; i++) {
float cx = idx.get(i,0), cy=idx.get(i,1), w=idx.get(i,2), h=idx.get(i,3);
float obj = idx.get(i,4);
int best=-1; float pmax=0f;
for (int c=5;c<C;c++) { float p=idx.get(i,c); if (p>pmax){pmax=p; best=c-5;} }
float conf = obj * pmax;
if (conf < confTh) continue;
int bx = Math.max(0, Math.round(cx*fw - (w*fw)/2f));
int by = Math.max(0, Math.round(cy*fh - (h*fh)/2f));
int bw = Math.min(fw-bx, Math.round(w*fw));
int bh = Math.min(fh-by, Math.round(h*fh));
if (bw<=0 || bh<=0) continue;
boxes.add(new Rect2d(bx,by,bw,bh));
scores.add(conf);
classIds.add(best);
for (Mat out : outs) {
parseYoloOutput(out, fw, fh, boxes, scores, classIds);
}
if (boxes.isEmpty()) return Collections.emptyList();
// NMS
MatOfRect2d b = new MatOfRect2d(boxes.toArray(new Rect2d[0]));
MatOfFloat s = new MatOfFloat(toArray(scores));
MatOfInt keep = new MatOfInt();
NMSBoxes(b, s, confTh, nmsTh, keep);
// 纯 Java NMS避免 MatOf* / Vector API 兼容问题
List<Integer> keep = nmsIndices(boxes, scores, nmsTh);
List<Detection> outList = new ArrayList<>();
IntRawIndexer kidx = keep.createIndexer();
for (int i=0;i<keep.rows();i++){
int k = kidx.get(i);
List<Detection> result = new ArrayList<>(keep.size());
for (int k : keep) {
Rect2d r = boxes.get(k);
Rect rect = new Rect((int)r.x(), (int)r.y(), (int)r.width(), (int)r.height());
String cname = (classIds.get(k)>=0 && classIds.get(k)<classes.length)
? classes[classIds.get(k)] : "cls"+classIds.get(k);
outList.add(new Detection("["+modelName+"] "+cname, scores.get(k), rect, colorBGR));
int cid = classIds.get(k);
String cname = (cid >= 0 && cid < classes.length) ? classes[cid] : ("cls"+cid);
result.add(new Detection("["+modelName+"] "+cname, scores.get(k), rect, colorBGR));
}
return outList;
return result;
} catch (Throwable e) {
// 单帧失败不影响整体
return Collections.emptyList();
}
}
private static float[] toArray(List<Float> ls){ float[] a=new float[ls.size()]; for(int i=0;i<ls.size();i++) a[i]=ls.get(i); return a; }
/** 解析 YOLO-IR 输出为 N×CC>=6并填充 boxes/scores/classIds。 */
private void parseYoloOutput(Mat out, int fw, int fh,
List<Rect2d> boxes, List<Float> scores, List<Integer> classIds) {
int dims = out.dims();
Mat m;
if (dims == 2) {
// NxC 或 CxN
if (out.cols() >= 6) {
m = out;
} else {
Mat tmp = new Mat();
transpose(out, tmp); // CxN -> NxC
m = tmp;
}
} else if (dims == 3) {
// [1,N,C] 或 [1,C,N]
if (out.size(2) >= 6) {
m = out.reshape(1, out.size(1)); // -> N×C
} else {
Mat squeezed = out.reshape(1, out.size(1)); // C×N
Mat tmp = new Mat();
transpose(squeezed, tmp); // -> N×C
m = tmp;
}
} else if (dims == 4) {
// [1,1,N,C] 或 [1,1,C,N]
int a = out.size(2), b = out.size(3);
if (b >= 6) {
m = out.reshape(1, a).clone(); // -> N×C
} else {
Mat cxn = out.reshape(1, b); // C×N
Mat tmp = new Mat();
transpose(cxn, tmp); // -> N×C
m = tmp.clone();
}
} else {
return; // 不支持的形状
}
int N = m.rows(), C = m.cols();
if (C < 6 || N <= 0) return;
FloatRawIndexer idx = m.createIndexer();
for (int i = 0; i < N; i++) {
float cx = idx.get(i,0), cy = idx.get(i,1), w = idx.get(i,2), h = idx.get(i,3);
float obj = idx.get(i,4);
int bestCls = -1; float bestScore = 0f;
for (int c = 5; c < C; c++) {
float p = idx.get(i,c);
if (p > bestScore) { bestScore = p; bestCls = c - 5; }
}
float conf = obj * bestScore;
if (conf < confTh) continue;
// 默认假设归一化中心点格式 (cx,cy,w,h);若你的 IR 是 x1,y1,x2,y2请把这里换算改掉
int bx = Math.max(0, Math.round(cx * fw - (w * fw) / 2f));
int by = Math.max(0, Math.round(cy * fh - (h * fh) / 2f));
int bw = Math.min(fw - bx, Math.round(w * fw));
int bh = Math.min(fh - by, Math.round(h * fh));
if (bw <= 0 || bh <= 0) continue;
boxes.add(new Rect2d(bx, by, bw, bh));
scores.add(conf);
classIds.add(bestCls);
}
}
/** 纯 Java NMSIoU 抑制),返回保留的下标列表。 */
private List<Integer> nmsIndices(List<Rect2d> boxes, List<Float> scores, float nmsThreshold) {
List<Integer> order = new ArrayList<>(boxes.size());
for (int i = 0; i < boxes.size(); i++) order.add(i);
// 按分数降序
order.sort((i, j) -> Float.compare(scores.get(j), scores.get(i)));
List<Integer> keep = new ArrayList<>();
boolean[] removed = new boolean[boxes.size()];
for (int a = 0; a < order.size(); a++) {
int i = order.get(a);
if (removed[i]) continue;
keep.add(i);
Rect2d bi = boxes.get(i);
double areaI = bi.width() * bi.height();
for (int b = a + 1; b < order.size(); b++) {
int j = order.get(b);
if (removed[j]) continue;
Rect2d bj = boxes.get(j);
double areaJ = bj.width() * bj.height();
double xx1 = Math.max(bi.x(), bj.x());
double yy1 = Math.max(bi.y(), bj.y());
double xx2 = Math.min(bi.x() + bi.width(), bj.x() + bj.width());
double yy2 = Math.min(bi.y() + bi.height(), bj.y() + bj.height());
double w = Math.max(0, xx2 - xx1);
double h = Math.max(0, yy2 - yy1);
double inter = w * h;
double iou = inter / (areaI + areaJ - inter + 1e-9);
if (iou > nmsThreshold) removed[j] = true;
}
}
return keep;
}
@Override public void close(){ net.close(); }
}