Files
klp-oa/klp-wms/src/main/java/com/klp/service/impl/ImageRecognitionServiceImpl.java

375 lines
15 KiB
Java
Raw Normal View History

2025-08-02 14:46:02 +08:00
package com.klp.service.impl;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
2025-08-02 16:40:16 +08:00
import com.klp.common.config.ImageRecognitionConfig;
2025-08-02 14:46:02 +08:00
import com.klp.domain.bo.ImageRecognitionBo;
2025-08-02 15:49:57 +08:00
import com.klp.domain.vo.AttributeVo;
2025-08-02 14:46:02 +08:00
import com.klp.domain.vo.ImageRecognitionVo;
import com.klp.service.IImageRecognitionService;
import com.klp.utils.ImageProcessingUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.*;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Image recognition service implementation.
 *
 * @author klp
 * @date 2025-01-27
 */
@Slf4j
@RequiredArgsConstructor
@Service
public class ImageRecognitionServiceImpl implements IImageRecognitionService {
// Recognition settings read by this class: model name, API URL and key, retry count,
// temperature, max tokens, and the image dimension/quality limits.
private final ImageRecognitionConfig config;
// Helper used here to validate image URLs and to convert an image URL into a data URI.
private final ImageProcessingUtils imageProcessingUtils;
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
// HTTP client used to call the AI endpoint.
// NOTE(review): this class relies on Lombok's @RequiredArgsConstructor for injection; a
// field-level @Qualifier only reaches the generated constructor parameter when
// lombok.config lists it under lombok.copyableAnnotations - confirm the project does so.
@Qualifier("salesScriptRestTemplate")
private final RestTemplate restTemplate;
// JSON mapper used to parse the attribute array returned by the model.
private final ObjectMapper objectMapper = new ObjectMapper();
// Fixed pool of 5 threads that runs the parallel voting rounds.
// NOTE(review): nothing in this class shuts the pool down - confirm whether that is intended.
private final ExecutorService executorService = Executors.newFixedThreadPool(5);
@Override
public ImageRecognitionVo recognizeImage(ImageRecognitionBo bo) {
long startTime = System.currentTimeMillis();
ImageRecognitionVo result = new ImageRecognitionVo();
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
try {
// 验证图片URL
if (!imageProcessingUtils.isValidImageUrl(bo.getImageUrl())) {
throw new RuntimeException("无效的图片URL");
}
// 根据识别类型调用不同的识别方法
switch (bo.getRecognitionType()) {
case "bom":
result = recognizeBom(bo);
break;
case "text":
result = recognizeText(bo);
break;
default:
result = recognizeGeneral(bo);
break;
}
result.setStatus("success");
result.setProcessingTime(System.currentTimeMillis() - startTime);
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
} catch (Exception e) {
log.error("图片识别失败", e);
result.setStatus("failed");
result.setErrorMessage(e.getMessage());
result.setProcessingTime(System.currentTimeMillis() - startTime);
}
return result;
}
public ImageRecognitionVo recognizeBom(ImageRecognitionBo bo) {
String prompt = buildBomPrompt(bo);
String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
ImageRecognitionVo result = new ImageRecognitionVo();
result.setImageUrl(bo.getImageUrl());
result.setRecognitionType("bom");
2025-08-02 15:49:57 +08:00
2025-08-02 16:40:16 +08:00
// 直接解析属性数组
List<AttributeVo> attributes = parseAttributesResponse(aiResponse);
result.setAttributes(attributes);
2025-08-02 15:49:57 +08:00
2025-08-02 16:40:16 +08:00
// 构建结构化结果
Map<String, Object> structuredResult = new HashMap<>();
structuredResult.put("attributes", attributes);
structuredResult.put("summary", "材料质保单识别结果");
structuredResult.put("totalItems", attributes.size());
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
return result;
}
public ImageRecognitionVo recognizeText(ImageRecognitionBo bo) {
String prompt = buildTextPrompt(bo);
String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
ImageRecognitionVo result = new ImageRecognitionVo();
result.setImageUrl(bo.getImageUrl());
result.setRecognitionType("text");
return result;
}
/**
* 通用识别方法
*/
private ImageRecognitionVo recognizeGeneral(ImageRecognitionBo bo) {
String prompt = buildGeneralPrompt(bo);
String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
ImageRecognitionVo result = new ImageRecognitionVo();
result.setImageUrl(bo.getImageUrl());
result.setRecognitionType("general");
return result;
}
/**
* 调用AI API
*/
private String callAiApi(String imageUrl, String prompt, Boolean enableVoting, Integer votingRounds) {
// 转换图片为Data URI
String dataUri = imageProcessingUtils.imageUrlToDataUri(
imageUrl, config.getMaxImageDimension(), config.getImageQuality());
// 构建请求体
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", config.getModelName());
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
List<Map<String, Object>> contents = new ArrayList<>();
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
// 添加图片内容
Map<String, Object> imageContent = new HashMap<>();
imageContent.put("type", "image_url");
Map<String, Object> imageUrlObj = new HashMap<>();
imageUrlObj.put("url", dataUri);
imageUrlObj.put("detail", "low");
imageContent.put("image_url", imageUrlObj);
contents.add(imageContent);
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
// 添加文本内容
Map<String, Object> textContent = new HashMap<>();
textContent.put("type", "text");
textContent.put("text", prompt);
contents.add(textContent);
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
Map<String, Object> message = new HashMap<>();
message.put("role", "user");
message.put("content", contents);
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
requestBody.put("messages", Arrays.asList(message));
requestBody.put("enable_thinking", true);
requestBody.put("temperature", config.getTemperature());
requestBody.put("top_p", 0.7);
requestBody.put("min_p", 0.05);
requestBody.put("frequency_penalty", 0.2);
requestBody.put("max_token", config.getMaxTokens());
requestBody.put("stream", false);
requestBody.put("stop", new ArrayList<>());
Map<String, String> responseFormat = new HashMap<>();
responseFormat.put("type", "text");
requestBody.put("response_format", responseFormat);
// 多轮投票处理
if (Boolean.TRUE.equals(enableVoting) && votingRounds > 1) {
return callAiApiWithVoting(requestBody, votingRounds);
} else {
return callAiApiSingle(requestBody);
}
}
/**
 * Posts the request body to the AI API and returns the first choice's message content.
 *
 * <p>The call is attempted up to the configured maximum number of retries. An attempt
 * counts as failed when the HTTP call throws, or when the reply does not contain a
 * string at {@code choices[0].message.content}; every failed attempt is logged.
 *
 * @param requestBody JSON request body
 * @return the message content of the first choice, never {@code null}
 * @throws RuntimeException if the last attempt throws (the cause is preserved) or if no
 *                          attempt produced usable content
 */
private String callAiApiSingle(Map<String, Object> requestBody) {
    int maxRetries = config.getMaxRetries();
    for (int attempt = 1; attempt <= maxRetries; attempt++) {
        try {
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            headers.setBearerAuth(config.getApiKey());
            HttpEntity<Map<String, Object>> entity = new HttpEntity<>(requestBody, headers);
            ResponseEntity<Map> response = restTemplate.postForEntity(
                    config.getApiUrl(), entity, Map.class);
            String content = extractContent(response);
            if (content != null) {
                return content;
            }
            // Previously this case was retried without leaving any trace in the log.
            log.warn("AI API returned no usable content, attempt {}: status={}",
                    attempt, response.getStatusCode());
        } catch (Exception e) {
            log.error("AI API调用失败重试 {}: {}", attempt, e.getMessage());
            if (attempt == maxRetries) {
                throw new RuntimeException("AI API调用失败", e);
            }
        }
    }
    throw new RuntimeException("AI API调用失败已达到最大重试次数");
}

/**
 * Extracts {@code choices[0].message.content} from a chat-completion reply, checking the
 * type of every level instead of relying on casts.
 *
 * @return the content string, or {@code null} if the status is not 200 or the reply does
 *         not have the expected shape
 */
private static String extractContent(ResponseEntity<Map> response) {
    if (!(response.getStatusCode() == HttpStatus.OK)) {
        return null;
    }
    Map<?, ?> body = response.getBody();
    if (body == null) {
        return null;
    }
    Object choices = body.get("choices");
    if (!(choices instanceof List) || ((List<?>) choices).isEmpty()) {
        return null;
    }
    Object firstChoice = ((List<?>) choices).get(0);
    if (!(firstChoice instanceof Map)) {
        return null;
    }
    Object message = ((Map<?, ?>) firstChoice).get("message");
    if (!(message instanceof Map)) {
        return null;
    }
    Object content = ((Map<?, ?>) message).get("content");
    return content instanceof String ? (String) content : null;
}
/**
* 多轮投票调用AI API
*/
private String callAiApiWithVoting(Map<String, Object> requestBody, int rounds) {
List<CompletableFuture<String>> futures = new ArrayList<>();
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
for (int i = 0; i < rounds; i++) {
2025-08-02 15:49:57 +08:00
CompletableFuture<String> future = CompletableFuture.supplyAsync(() ->
2025-08-02 14:46:02 +08:00
callAiApiSingle(requestBody), executorService);
futures.add(future);
}
List<String> results = new ArrayList<>();
for (CompletableFuture<String> future : futures) {
try {
results.add(future.get());
} catch (Exception e) {
log.error("投票轮次失败: {}", e.getMessage());
}
}
if (results.isEmpty()) {
throw new RuntimeException("所有投票轮次都失败了");
}
// 简单投票:返回第一个成功的结果
return results.get(0);
}
/**
* 构建BOM识别提示词
*/
private String buildBomPrompt(ImageRecognitionBo bo) {
StringBuilder prompt = new StringBuilder();
2025-08-02 15:18:50 +08:00
prompt.append("这是一张材料质保单的图片。请从表格、文字排布中提取有用的信息,返回其中所有可以被识别为字段名 + 字段值的键值对。\n\n");
prompt.append("要求如下:\n\n");
prompt.append("1. 忽略标题、编号、日期等通用信息,不要作为键值对输出;\n");
prompt.append("2. 仅提取\"字段名 + 字段值\"形式的内容,且字段名可以是如\"钢卷号\"\"规格\"\"净重\"\"材质\"\"下工序\"\"生产班组\"\"生产日期\"等;\n");
prompt.append("3. 字段不固定,根据图像内容自行判断,但要尽量提取所有;\n");
prompt.append("4. 返回格式统一为 JSON 数组,格式如下:\n");
prompt.append("[\n");
prompt.append(" { \"attrKey\": \"字段名\", \"attrValue\": \"字段值\" },\n");
prompt.append(" ...\n");
prompt.append("]\n");
prompt.append("5. 如有值缺失或为空的字段仍保留字段value 留空字符串;\n");
prompt.append("6. 严格按照图像中文字布局顺序返回;\n");
prompt.append("7. 只输出 JSON 结果,不需要解释或说明;\n\n");
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
if (bo.getProductId() != null) {
prompt.append("【产品信息】\n");
prompt.append("产品ID: ").append(bo.getProductId()).append("\n\n");
}
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
if (bo.getCustomPrompt() != null && !bo.getCustomPrompt().isEmpty()) {
prompt.append("【自定义要求】\n");
prompt.append(bo.getCustomPrompt()).append("\n\n");
}
return prompt.toString();
}
/**
* 构建文字识别提示词
*/
private String buildTextPrompt(ImageRecognitionBo bo) {
StringBuilder prompt = new StringBuilder();
prompt.append("请识别图片中的所有文字内容,包括但不限于:\n\n");
prompt.append("【识别要求】\n");
prompt.append("1. 识别图片中的所有可见文字\n");
prompt.append("2. 保持文字的原始格式和顺序\n");
prompt.append("3. 识别表格、列表等结构化内容\n");
prompt.append("4. 识别数字、符号等特殊字符\n");
prompt.append("5. 保持段落和换行格式\n\n");
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
if (bo.getCustomPrompt() != null && !bo.getCustomPrompt().isEmpty()) {
prompt.append("【自定义要求】\n");
prompt.append(bo.getCustomPrompt()).append("\n\n");
}
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
prompt.append("【输出格式】\n");
prompt.append("请直接输出识别到的文字内容,保持原有格式。");
return prompt.toString();
}
/**
* 构建通用识别提示词
*/
private String buildGeneralPrompt(ImageRecognitionBo bo) {
StringBuilder prompt = new StringBuilder();
prompt.append("请分析这张图片的内容,并提供详细描述:\n\n");
prompt.append("【分析要求】\n");
prompt.append("1. 描述图片的主要内容和主题\n");
prompt.append("2. 识别图片中的文字信息\n");
prompt.append("3. 分析图片的结构和布局\n");
prompt.append("4. 提取关键信息和数据\n");
prompt.append("5. 识别图片中的表格、图表等结构化内容\n\n");
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
if (bo.getCustomPrompt() != null && !bo.getCustomPrompt().isEmpty()) {
prompt.append("【自定义要求】\n");
prompt.append(bo.getCustomPrompt()).append("\n\n");
}
2025-08-02 15:49:57 +08:00
2025-08-02 14:46:02 +08:00
prompt.append("【输出格式】\n");
prompt.append("请提供详细的分析结果,包括文字内容、结构分析等。");
return prompt.toString();
}
/**
2025-08-02 15:18:50 +08:00
* 解析属性数组响应
*/
2025-08-02 15:49:57 +08:00
private List<AttributeVo> parseAttributesResponse(String response) {
2025-08-02 15:18:50 +08:00
try {
// 尝试直接解析JSON数组
2025-08-02 15:49:57 +08:00
List<Map<String, Object>> attrList = objectMapper.readValue(response,
2025-08-02 15:18:50 +08:00
objectMapper.getTypeFactory().constructCollectionType(List.class, Map.class));
2025-08-02 15:49:57 +08:00
List<AttributeVo> attributes = new ArrayList<>();
2025-08-02 15:18:50 +08:00
for (Map<String, Object> attr : attrList) {
2025-08-02 15:49:57 +08:00
AttributeVo attribute = new AttributeVo();
2025-08-02 15:18:50 +08:00
attribute.setAttrKey((String) attr.get("attrKey"));
attribute.setAttrValue((String) attr.get("attrValue"));
attributes.add(attribute);
}
return attributes;
2025-08-02 15:49:57 +08:00
2025-08-02 15:18:50 +08:00
} catch (JsonProcessingException e) {
// 如果直接解析失败尝试提取JSON数组部分
Pattern jsonArrayPattern = Pattern.compile("\\[[\\s\\S]*\\]");
Matcher matcher = jsonArrayPattern.matcher(response);
if (matcher.find()) {
try {
2025-08-02 15:49:57 +08:00
List<Map<String, Object>> attrList = objectMapper.readValue(matcher.group(),
2025-08-02 15:18:50 +08:00
objectMapper.getTypeFactory().constructCollectionType(List.class, Map.class));
2025-08-02 15:49:57 +08:00
List<AttributeVo> attributes = new ArrayList<>();
2025-08-02 15:18:50 +08:00
for (Map<String, Object> attr : attrList) {
2025-08-02 15:49:57 +08:00
AttributeVo attribute = new AttributeVo();
2025-08-02 15:18:50 +08:00
attribute.setAttrKey((String) attr.get("attrKey"));
attribute.setAttrValue((String) attr.get("attrValue"));
attributes.add(attribute);
}
return attributes;
2025-08-02 15:49:57 +08:00
2025-08-02 15:18:50 +08:00
} catch (JsonProcessingException ex) {
log.warn("无法解析属性响应为JSON数组: {}", response);
return new ArrayList<>();
}
}
log.warn("无法解析属性响应: {}", response);
return new ArrayList<>();
}
}
2025-08-02 14:46:02 +08:00
2025-08-02 15:49:57 +08:00
}