Files
klp-oa/klp-wms/src/main/java/com/klp/service/impl/ImageRecognitionServiceImpl.java
2025-08-02 16:40:16 +08:00

375 lines
15 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.klp.service.impl;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.klp.common.config.ImageRecognitionConfig;
import com.klp.domain.bo.ImageRecognitionBo;
import com.klp.domain.vo.AttributeVo;
import com.klp.domain.vo.ImageRecognitionVo;
import com.klp.service.IImageRecognitionService;
import com.klp.utils.ImageProcessingUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.*;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Image recognition service implementation.
 *
 * <p>Converts the requested image to a data URI, sends it together with a
 * task-specific prompt to the configured AI chat-completion endpoint, and maps
 * the answer onto an {@link ImageRecognitionVo}. Optionally the same request is
 * issued several times in parallel ("voting").
 *
 * @author klp
 * @date 2025-01-27
 */
@Slf4j
@RequiredArgsConstructor
@Service
public class ImageRecognitionServiceImpl implements IImageRecognitionService, DisposableBean {

    /**
     * Greedy match from the first '[' to the last ']' in a text; used to pull a
     * JSON array out of an answer that wraps it in prose or a markdown fence.
     */
    private static final Pattern JSON_ARRAY_PATTERN = Pattern.compile("\\[[\\s\\S]*\\]");

    private final ImageRecognitionConfig config;
    private final ImageProcessingUtils imageProcessingUtils;
    // NOTE(review): Lombok only copies @Qualifier onto the generated constructor
    // parameter when lombok.config lists it under lombok.copyableAnnotations —
    // confirm, otherwise this qualifier has no effect on injection.
    @Qualifier("salesScriptRestTemplate")
    private final RestTemplate restTemplate;
    private final ObjectMapper objectMapper = new ObjectMapper();
    /** Pool that runs the parallel voting rounds; released in {@link #destroy()}. */
    private final ExecutorService executorService = Executors.newFixedThreadPool(5);

    /**
     * Recognizes the image described by {@code bo}.
     *
     * <p>Never throws: any failure is logged and reported through the returned
     * object's status ("failed") and error message.
     *
     * @param bo recognition request; the recognition type selects "bom", "text",
     *           or (for any other value, including {@code null}) general recognition
     * @return the recognition result with status and processing time filled in
     */
    @Override
    public ImageRecognitionVo recognizeImage(ImageRecognitionBo bo) {
        long startTime = System.currentTimeMillis();
        ImageRecognitionVo result = new ImageRecognitionVo();
        try {
            if (bo == null) {
                throw new IllegalArgumentException("识别请求不能为空");
            }
            // Validate the image URL before doing any remote work.
            if (!imageProcessingUtils.isValidImageUrl(bo.getImageUrl())) {
                throw new RuntimeException("无效的图片URL");
            }
            // Switching on a null String throws NPE, so map null to "" which
            // lands in the default (general) branch.
            String type = bo.getRecognitionType();
            switch (type == null ? "" : type) {
                case "bom":
                    result = recognizeBom(bo);
                    break;
                case "text":
                    result = recognizeText(bo);
                    break;
                default:
                    result = recognizeGeneral(bo);
                    break;
            }
            result.setStatus("success");
        } catch (Exception e) {
            log.error("图片识别失败", e);
            result.setStatus("failed");
            result.setErrorMessage(e.getMessage());
        } finally {
            result.setProcessingTime(System.currentTimeMillis() - startTime);
        }
        return result;
    }

    /**
     * Extracts key/value attributes from a material quality-certificate image.
     *
     * @param bo recognition request
     * @return result carrying the image URL, type "bom" and the parsed attributes
     * @throws RuntimeException if the AI endpoint cannot be reached
     */
    public ImageRecognitionVo recognizeBom(ImageRecognitionBo bo) {
        String prompt = buildBomPrompt(bo);
        String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
        ImageRecognitionVo result = new ImageRecognitionVo();
        result.setImageUrl(bo.getImageUrl());
        result.setRecognitionType("bom");
        // The answer is expected to be a JSON array of attrKey/attrValue objects.
        result.setAttributes(parseAttributesResponse(aiResponse));
        // NOTE(review): earlier code assembled a map with "attributes", "summary"
        // and "totalItems" here but never attached it to the result; if
        // ImageRecognitionVo has a field for it, set it here.
        return result;
    }

    /**
     * Runs plain text recognition on the image.
     *
     * @param bo recognition request
     * @return result carrying the image URL and type "text"
     * @throws RuntimeException if the AI endpoint cannot be reached
     */
    public ImageRecognitionVo recognizeText(ImageRecognitionBo bo) {
        String prompt = buildTextPrompt(bo);
        // NOTE(review): the AI answer is not stored on the result; presumably it
        // should be written to a text field of ImageRecognitionVo — confirm.
        String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
        ImageRecognitionVo result = new ImageRecognitionVo();
        result.setImageUrl(bo.getImageUrl());
        result.setRecognitionType("text");
        return result;
    }

    /**
     * General-purpose recognition, used for every type other than "bom"/"text".
     */
    private ImageRecognitionVo recognizeGeneral(ImageRecognitionBo bo) {
        String prompt = buildGeneralPrompt(bo);
        // NOTE(review): the AI answer is not stored on the result; presumably it
        // should be written to a text field of ImageRecognitionVo — confirm.
        String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
        ImageRecognitionVo result = new ImageRecognitionVo();
        result.setImageUrl(bo.getImageUrl());
        result.setRecognitionType("general");
        return result;
    }

    /**
     * Sends the image and prompt to the AI endpoint.
     *
     * @param imageUrl     source image; converted to a data URI first
     * @param prompt       text instruction sent alongside the image
     * @param enableVoting whether to issue several parallel rounds
     * @param votingRounds number of rounds; voting is used only when this is
     *                     non-null and greater than 1
     * @return the textual content of the AI answer
     */
    private String callAiApi(String imageUrl, String prompt, Boolean enableVoting, Integer votingRounds) {
        String dataUri = imageProcessingUtils.imageUrlToDataUri(
                imageUrl, config.getMaxImageDimension(), config.getImageQuality());
        Map<String, Object> requestBody = buildRequestBody(dataUri, prompt);
        // Null-check before unboxing: a request may enable voting without rounds.
        boolean useVoting = Boolean.TRUE.equals(enableVoting) && votingRounds != null && votingRounds > 1;
        return useVoting ? callAiApiWithVoting(requestBody, votingRounds) : callAiApiSingle(requestBody);
    }

    /** Builds the chat-completion request body: one user message holding the image and the prompt. */
    private Map<String, Object> buildRequestBody(String dataUri, String prompt) {
        Map<String, Object> imageUrlObj = new HashMap<>();
        imageUrlObj.put("url", dataUri);
        imageUrlObj.put("detail", "low");
        Map<String, Object> imageContent = new HashMap<>();
        imageContent.put("type", "image_url");
        imageContent.put("image_url", imageUrlObj);

        Map<String, Object> textContent = new HashMap<>();
        textContent.put("type", "text");
        textContent.put("text", prompt);

        List<Map<String, Object>> contents = new ArrayList<>();
        contents.add(imageContent);
        contents.add(textContent);

        Map<String, Object> message = new HashMap<>();
        message.put("role", "user");
        message.put("content", contents);

        Map<String, String> responseFormat = new HashMap<>();
        responseFormat.put("type", "text");

        Map<String, Object> requestBody = new HashMap<>();
        requestBody.put("model", config.getModelName());
        requestBody.put("messages", Arrays.asList(message));
        requestBody.put("enable_thinking", true);
        requestBody.put("temperature", config.getTemperature());
        requestBody.put("top_p", 0.7);
        requestBody.put("min_p", 0.05);
        requestBody.put("frequency_penalty", 0.2);
        // OpenAI-compatible chat APIs name this parameter "max_tokens"; the
        // previous key "max_token" is presumably ignored by the server.
        requestBody.put("max_tokens", config.getMaxTokens());
        requestBody.put("stream", false);
        requestBody.put("stop", new ArrayList<>());
        requestBody.put("response_format", responseFormat);
        return requestBody;
    }

    /**
     * Calls the AI endpoint once, retrying on failure.
     *
     * <p>At least one attempt is always made, even if the configured retry
     * count is zero or negative. A response without usable text content counts
     * as a failed attempt.
     *
     * @param requestBody chat-completion request body
     * @return the content of the first choice's message, never {@code null}
     * @throws RuntimeException if every attempt fails
     */
    private String callAiApiSingle(Map<String, Object> requestBody) {
        int maxAttempts = Math.max(1, config.getMaxRetries());
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.setBearerAuth(config.getApiKey());
        HttpEntity<Map<String, Object>> entity = new HttpEntity<>(requestBody, headers);

        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            try {
                ResponseEntity<Map> response = restTemplate.postForEntity(
                        config.getApiUrl(), entity, Map.class);
                String content = extractContent(response);
                if (content != null) {
                    return content;
                }
                log.warn("AI API响应无有效内容,重试 {}: status={}", attempt, response.getStatusCode());
            } catch (Exception e) {
                log.error("AI API调用失败重试 {}: {}", attempt, e.getMessage());
                if (attempt == maxAttempts) {
                    throw new RuntimeException("AI API调用失败", e);
                }
            }
        }
        throw new RuntimeException("AI API调用失败已达到最大重试次数");
    }

    /**
     * Reads {@code choices[0].message.content} from a chat-completion response.
     *
     * @return the content string, or {@code null} when the status is not 200 or
     *         any level of the expected structure is missing or of another type
     */
    private String extractContent(ResponseEntity<?> response) {
        if (response.getStatusCode() != HttpStatus.OK || !(response.getBody() instanceof Map)) {
            return null;
        }
        Object choices = ((Map<?, ?>) response.getBody()).get("choices");
        if (!(choices instanceof List) || ((List<?>) choices).isEmpty()) {
            return null;
        }
        Object firstChoice = ((List<?>) choices).get(0);
        if (!(firstChoice instanceof Map)) {
            return null;
        }
        Object message = ((Map<?, ?>) firstChoice).get("message");
        if (!(message instanceof Map)) {
            return null;
        }
        Object content = ((Map<?, ?>) message).get("content");
        return content instanceof String ? (String) content : null;
    }

    /**
     * Issues {@code rounds} identical requests in parallel and picks one answer.
     *
     * <p>The answer returned most often wins; ties (including the common case
     * where every answer differs) go to the earliest round, which matches the
     * previous "first successful result" behavior.
     *
     * @param requestBody chat-completion request body, shared read-only by all rounds
     * @param rounds      number of parallel rounds
     * @return the selected answer
     * @throws RuntimeException if every round fails or the wait is interrupted
     */
    private String callAiApiWithVoting(Map<String, Object> requestBody, int rounds) {
        List<CompletableFuture<String>> futures = new ArrayList<>();
        for (int i = 0; i < rounds; i++) {
            futures.add(CompletableFuture.supplyAsync(() -> callAiApiSingle(requestBody), executorService));
        }

        // Insertion-ordered so ties resolve to the earliest round.
        Map<String, Integer> tally = new LinkedHashMap<>();
        for (CompletableFuture<String> future : futures) {
            try {
                String answer = future.get();
                if (answer != null) {
                    tally.merge(answer, 1, Integer::sum);
                }
            } catch (InterruptedException e) {
                // Restore the flag and stop waiting; remaining rounds are abandoned.
                Thread.currentThread().interrupt();
                for (CompletableFuture<String> pending : futures) {
                    pending.cancel(true);
                }
                throw new RuntimeException("投票过程被中断", e);
            } catch (Exception e) {
                log.error("投票轮次失败: {}", e.getMessage());
            }
        }
        if (tally.isEmpty()) {
            throw new RuntimeException("所有投票轮次都失败了");
        }

        String winner = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : tally.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                winner = entry.getKey();
            }
        }
        return winner;
    }

    /**
     * Builds the prompt for quality-certificate (BOM) attribute extraction.
     */
    private String buildBomPrompt(ImageRecognitionBo bo) {
        StringBuilder prompt = new StringBuilder();
        prompt.append("这是一张材料质保单的图片。请从表格、文字排布中提取有用的信息,返回其中所有可以被识别为字段名 + 字段值的键值对。\n\n");
        prompt.append("要求如下:\n\n");
        prompt.append("1. 忽略标题、编号、日期等通用信息,不要作为键值对输出;\n");
        prompt.append("2. 仅提取\"字段名 + 字段值\"形式的内容,且字段名可以是如\"钢卷号\"\"规格\"\"净重\"\"材质\"\"下工序\"\"生产班组\"\"生产日期\"等;\n");
        prompt.append("3. 字段不固定,根据图像内容自行判断,但要尽量提取所有;\n");
        prompt.append("4. 返回格式统一为 JSON 数组,格式如下:\n");
        prompt.append("[\n");
        prompt.append(" { \"attrKey\": \"字段名\", \"attrValue\": \"字段值\" },\n");
        prompt.append(" ...\n");
        prompt.append("]\n");
        prompt.append("5. 如有值缺失或为空的字段仍保留字段value 留空字符串;\n");
        prompt.append("6. 严格按照图像中文字布局顺序返回;\n");
        prompt.append("7. 只输出 JSON 结果,不需要解释或说明;\n\n");
        if (bo.getProductId() != null) {
            prompt.append("【产品信息】\n");
            prompt.append("产品ID: ").append(bo.getProductId()).append("\n\n");
        }
        appendCustomPrompt(prompt, bo);
        return prompt.toString();
    }

    /**
     * Builds the prompt for plain text recognition.
     */
    private String buildTextPrompt(ImageRecognitionBo bo) {
        StringBuilder prompt = new StringBuilder();
        prompt.append("请识别图片中的所有文字内容,包括但不限于:\n\n");
        prompt.append("【识别要求】\n");
        prompt.append("1. 识别图片中的所有可见文字\n");
        prompt.append("2. 保持文字的原始格式和顺序\n");
        prompt.append("3. 识别表格、列表等结构化内容\n");
        prompt.append("4. 识别数字、符号等特殊字符\n");
        prompt.append("5. 保持段落和换行格式\n\n");
        appendCustomPrompt(prompt, bo);
        prompt.append("【输出格式】\n");
        prompt.append("请直接输出识别到的文字内容,保持原有格式。");
        return prompt.toString();
    }

    /**
     * Builds the prompt for general image analysis.
     */
    private String buildGeneralPrompt(ImageRecognitionBo bo) {
        StringBuilder prompt = new StringBuilder();
        prompt.append("请分析这张图片的内容,并提供详细描述:\n\n");
        prompt.append("【分析要求】\n");
        prompt.append("1. 描述图片的主要内容和主题\n");
        prompt.append("2. 识别图片中的文字信息\n");
        prompt.append("3. 分析图片的结构和布局\n");
        prompt.append("4. 提取关键信息和数据\n");
        prompt.append("5. 识别图片中的表格、图表等结构化内容\n\n");
        appendCustomPrompt(prompt, bo);
        prompt.append("【输出格式】\n");
        prompt.append("请提供详细的分析结果,包括文字内容、结构分析等。");
        return prompt.toString();
    }

    /** Appends the caller-supplied custom prompt section when one is present and non-empty. */
    private void appendCustomPrompt(StringBuilder prompt, ImageRecognitionBo bo) {
        String custom = bo.getCustomPrompt();
        if (custom != null && !custom.isEmpty()) {
            prompt.append("【自定义要求】\n");
            prompt.append(custom).append("\n\n");
        }
    }

    /**
     * Parses the AI answer into attributes.
     *
     * <p>First tries the whole answer as a JSON array; if that fails, retries on
     * the span from the first '[' to the last ']'. Never throws on malformed
     * input.
     *
     * @param response raw AI answer; may be {@code null} or blank
     * @return the parsed attributes, or an empty mutable list when nothing could be parsed
     */
    private List<AttributeVo> parseAttributesResponse(String response) {
        if (response == null || response.trim().isEmpty()) {
            log.warn("无法解析属性响应: {}", response);
            return new ArrayList<>();
        }
        try {
            return readAttributes(response);
        } catch (JsonProcessingException e) {
            // Direct parse failed; look for an embedded JSON array instead.
            Matcher matcher = JSON_ARRAY_PATTERN.matcher(response);
            if (matcher.find()) {
                try {
                    return readAttributes(matcher.group());
                } catch (JsonProcessingException ex) {
                    log.warn("无法解析属性响应为JSON数组: {}", response);
                    return new ArrayList<>();
                }
            }
            log.warn("无法解析属性响应: {}", response);
            return new ArrayList<>();
        }
    }

    /**
     * Deserializes a JSON array of objects and maps each element's "attrKey" and
     * "attrValue" onto an {@link AttributeVo}. Non-string values (numbers,
     * booleans) are converted with {@code String.valueOf}; JSON null stays null.
     */
    private List<AttributeVo> readAttributes(String json) throws JsonProcessingException {
        List<Map<String, Object>> rows = objectMapper.readValue(json,
                objectMapper.getTypeFactory().constructCollectionType(List.class, Map.class));
        List<AttributeVo> attributes = new ArrayList<>();
        if (rows == null) {
            // The literal JSON value null deserializes to a null list.
            return attributes;
        }
        for (Map<String, Object> row : rows) {
            if (row == null) {
                continue;
            }
            AttributeVo attribute = new AttributeVo();
            attribute.setAttrKey(asText(row.get("attrKey")));
            attribute.setAttrValue(asText(row.get("attrValue")));
            attributes.add(attribute);
        }
        return attributes;
    }

    /** Converts a deserialized JSON value to its string form, keeping {@code null} as {@code null}. */
    private static String asText(Object value) {
        return value == null ? null : String.valueOf(value);
    }

    /**
     * Releases the voting thread pool when the bean is destroyed. Already
     * submitted rounds are allowed to finish; no new rounds are accepted.
     */
    @Override
    public void destroy() {
        executorService.shutdown();
    }
}