375 lines
15 KiB
Java
375 lines
15 KiB
Java
package com.klp.service.impl;
|
||
|
||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||
import com.klp.common.config.ImageRecognitionConfig;
|
||
import com.klp.domain.bo.ImageRecognitionBo;
|
||
import com.klp.domain.vo.AttributeVo;
|
||
import com.klp.domain.vo.ImageRecognitionVo;
|
||
import com.klp.service.IImageRecognitionService;
|
||
import com.klp.utils.ImageProcessingUtils;
|
||
import lombok.RequiredArgsConstructor;
|
||
import lombok.extern.slf4j.Slf4j;
|
||
import org.springframework.beans.factory.annotation.Qualifier;
|
||
import org.springframework.http.*;
|
||
import org.springframework.stereotype.Service;
|
||
import org.springframework.web.client.RestTemplate;
|
||
|
||
import java.util.*;
|
||
import java.util.concurrent.CompletableFuture;
|
||
import java.util.concurrent.ExecutorService;
|
||
import java.util.concurrent.Executors;
|
||
import java.util.regex.Matcher;
|
||
import java.util.regex.Pattern;
|
||
|
||
/**
|
||
* 图片识别服务实现类
|
||
*
|
||
* @author klp
|
||
* @date 2025-01-27
|
||
*/
|
||
@Slf4j
|
||
@RequiredArgsConstructor
|
||
@Service
|
||
public class ImageRecognitionServiceImpl implements IImageRecognitionService {
|
||
|
||
private final ImageRecognitionConfig config;
|
||
private final ImageProcessingUtils imageProcessingUtils;
|
||
|
||
@Qualifier("salesScriptRestTemplate")
|
||
private final RestTemplate restTemplate;
|
||
|
||
private final ObjectMapper objectMapper = new ObjectMapper();
|
||
private final ExecutorService executorService = Executors.newFixedThreadPool(5);
|
||
|
||
@Override
|
||
public ImageRecognitionVo recognizeImage(ImageRecognitionBo bo) {
|
||
long startTime = System.currentTimeMillis();
|
||
ImageRecognitionVo result = new ImageRecognitionVo();
|
||
|
||
try {
|
||
// 验证图片URL
|
||
if (!imageProcessingUtils.isValidImageUrl(bo.getImageUrl())) {
|
||
throw new RuntimeException("无效的图片URL");
|
||
}
|
||
|
||
// 根据识别类型调用不同的识别方法
|
||
switch (bo.getRecognitionType()) {
|
||
case "bom":
|
||
result = recognizeBom(bo);
|
||
break;
|
||
case "text":
|
||
result = recognizeText(bo);
|
||
break;
|
||
default:
|
||
result = recognizeGeneral(bo);
|
||
break;
|
||
}
|
||
|
||
result.setStatus("success");
|
||
result.setProcessingTime(System.currentTimeMillis() - startTime);
|
||
|
||
} catch (Exception e) {
|
||
log.error("图片识别失败", e);
|
||
result.setStatus("failed");
|
||
result.setErrorMessage(e.getMessage());
|
||
result.setProcessingTime(System.currentTimeMillis() - startTime);
|
||
}
|
||
|
||
return result;
|
||
}
|
||
|
||
|
||
public ImageRecognitionVo recognizeBom(ImageRecognitionBo bo) {
|
||
String prompt = buildBomPrompt(bo);
|
||
String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
|
||
|
||
ImageRecognitionVo result = new ImageRecognitionVo();
|
||
result.setImageUrl(bo.getImageUrl());
|
||
result.setRecognitionType("bom");
|
||
|
||
// 直接解析属性数组
|
||
List<AttributeVo> attributes = parseAttributesResponse(aiResponse);
|
||
result.setAttributes(attributes);
|
||
|
||
// 构建结构化结果
|
||
Map<String, Object> structuredResult = new HashMap<>();
|
||
structuredResult.put("attributes", attributes);
|
||
structuredResult.put("summary", "材料质保单识别结果");
|
||
structuredResult.put("totalItems", attributes.size());
|
||
|
||
|
||
|
||
return result;
|
||
}
|
||
|
||
public ImageRecognitionVo recognizeText(ImageRecognitionBo bo) {
|
||
String prompt = buildTextPrompt(bo);
|
||
String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
|
||
|
||
ImageRecognitionVo result = new ImageRecognitionVo();
|
||
result.setImageUrl(bo.getImageUrl());
|
||
result.setRecognitionType("text");
|
||
|
||
return result;
|
||
}
|
||
|
||
|
||
/**
|
||
* 通用识别方法
|
||
*/
|
||
private ImageRecognitionVo recognizeGeneral(ImageRecognitionBo bo) {
|
||
String prompt = buildGeneralPrompt(bo);
|
||
String aiResponse = callAiApi(bo.getImageUrl(), prompt, bo.getEnableVoting(), bo.getVotingRounds());
|
||
|
||
ImageRecognitionVo result = new ImageRecognitionVo();
|
||
result.setImageUrl(bo.getImageUrl());
|
||
result.setRecognitionType("general");
|
||
return result;
|
||
}
|
||
|
||
/**
|
||
* 调用AI API
|
||
*/
|
||
private String callAiApi(String imageUrl, String prompt, Boolean enableVoting, Integer votingRounds) {
|
||
// 转换图片为Data URI
|
||
String dataUri = imageProcessingUtils.imageUrlToDataUri(
|
||
imageUrl, config.getMaxImageDimension(), config.getImageQuality());
|
||
|
||
// 构建请求体
|
||
Map<String, Object> requestBody = new HashMap<>();
|
||
requestBody.put("model", config.getModelName());
|
||
|
||
List<Map<String, Object>> contents = new ArrayList<>();
|
||
|
||
// 添加图片内容
|
||
Map<String, Object> imageContent = new HashMap<>();
|
||
imageContent.put("type", "image_url");
|
||
Map<String, Object> imageUrlObj = new HashMap<>();
|
||
imageUrlObj.put("url", dataUri);
|
||
imageUrlObj.put("detail", "low");
|
||
imageContent.put("image_url", imageUrlObj);
|
||
contents.add(imageContent);
|
||
|
||
// 添加文本内容
|
||
Map<String, Object> textContent = new HashMap<>();
|
||
textContent.put("type", "text");
|
||
textContent.put("text", prompt);
|
||
contents.add(textContent);
|
||
|
||
Map<String, Object> message = new HashMap<>();
|
||
message.put("role", "user");
|
||
message.put("content", contents);
|
||
|
||
requestBody.put("messages", Arrays.asList(message));
|
||
requestBody.put("enable_thinking", true);
|
||
requestBody.put("temperature", config.getTemperature());
|
||
requestBody.put("top_p", 0.7);
|
||
requestBody.put("min_p", 0.05);
|
||
requestBody.put("frequency_penalty", 0.2);
|
||
requestBody.put("max_token", config.getMaxTokens());
|
||
requestBody.put("stream", false);
|
||
requestBody.put("stop", new ArrayList<>());
|
||
Map<String, String> responseFormat = new HashMap<>();
|
||
responseFormat.put("type", "text");
|
||
requestBody.put("response_format", responseFormat);
|
||
|
||
// 多轮投票处理
|
||
if (Boolean.TRUE.equals(enableVoting) && votingRounds > 1) {
|
||
return callAiApiWithVoting(requestBody, votingRounds);
|
||
} else {
|
||
return callAiApiSingle(requestBody);
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 单次调用AI API
|
||
*/
|
||
private String callAiApiSingle(Map<String, Object> requestBody) {
|
||
for (int i = 0; i < config.getMaxRetries(); i++) {
|
||
try {
|
||
HttpHeaders headers = new HttpHeaders();
|
||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||
headers.setBearerAuth(config.getApiKey());
|
||
|
||
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(requestBody, headers);
|
||
ResponseEntity<Map> response = restTemplate.postForEntity(
|
||
config.getApiUrl(), entity, Map.class);
|
||
|
||
if (response.getStatusCode() == HttpStatus.OK && response.getBody() != null) {
|
||
Map<String, Object> body = response.getBody();
|
||
List<Map<String, Object>> choices = (List<Map<String, Object>>) body.get("choices");
|
||
if (choices != null && !choices.isEmpty()) {
|
||
Map<String, Object> choice = choices.get(0);
|
||
Map<String, Object> message = (Map<String, Object>) choice.get("message");
|
||
return (String) message.get("content");
|
||
}
|
||
}
|
||
} catch (Exception e) {
|
||
log.error("AI API调用失败,重试 {}: {}", i + 1, e.getMessage());
|
||
if (i == config.getMaxRetries() - 1) {
|
||
throw new RuntimeException("AI API调用失败", e);
|
||
}
|
||
}
|
||
}
|
||
throw new RuntimeException("AI API调用失败,已达到最大重试次数");
|
||
}
|
||
|
||
/**
|
||
* 多轮投票调用AI API
|
||
*/
|
||
private String callAiApiWithVoting(Map<String, Object> requestBody, int rounds) {
|
||
List<CompletableFuture<String>> futures = new ArrayList<>();
|
||
|
||
for (int i = 0; i < rounds; i++) {
|
||
CompletableFuture<String> future = CompletableFuture.supplyAsync(() ->
|
||
callAiApiSingle(requestBody), executorService);
|
||
futures.add(future);
|
||
}
|
||
|
||
List<String> results = new ArrayList<>();
|
||
for (CompletableFuture<String> future : futures) {
|
||
try {
|
||
results.add(future.get());
|
||
} catch (Exception e) {
|
||
log.error("投票轮次失败: {}", e.getMessage());
|
||
}
|
||
}
|
||
|
||
if (results.isEmpty()) {
|
||
throw new RuntimeException("所有投票轮次都失败了");
|
||
}
|
||
|
||
// 简单投票:返回第一个成功的结果
|
||
return results.get(0);
|
||
}
|
||
|
||
/**
|
||
* 构建BOM识别提示词
|
||
*/
|
||
private String buildBomPrompt(ImageRecognitionBo bo) {
|
||
StringBuilder prompt = new StringBuilder();
|
||
prompt.append("这是一张材料质保单的图片。请从表格、文字排布中提取有用的信息,返回其中所有可以被识别为字段名 + 字段值的键值对。\n\n");
|
||
prompt.append("要求如下:\n\n");
|
||
prompt.append("1. 忽略标题、编号、日期等通用信息,不要作为键值对输出;\n");
|
||
prompt.append("2. 仅提取\"字段名 + 字段值\"形式的内容,且字段名可以是如\"钢卷号\"、\"规格\"、\"净重\"、\"材质\"、\"下工序\"、\"生产班组\"、\"生产日期\"等;\n");
|
||
prompt.append("3. 字段不固定,根据图像内容自行判断,但要尽量提取所有;\n");
|
||
prompt.append("4. 返回格式统一为 JSON 数组,格式如下:\n");
|
||
prompt.append("[\n");
|
||
prompt.append(" { \"attrKey\": \"字段名\", \"attrValue\": \"字段值\" },\n");
|
||
prompt.append(" ...\n");
|
||
prompt.append("]\n");
|
||
prompt.append("5. 如有值缺失或为空的字段,仍保留字段,value 留空字符串;\n");
|
||
prompt.append("6. 严格按照图像中文字布局顺序返回;\n");
|
||
prompt.append("7. 只输出 JSON 结果,不需要解释或说明;\n\n");
|
||
|
||
if (bo.getProductId() != null) {
|
||
prompt.append("【产品信息】\n");
|
||
prompt.append("产品ID: ").append(bo.getProductId()).append("\n\n");
|
||
}
|
||
|
||
if (bo.getCustomPrompt() != null && !bo.getCustomPrompt().isEmpty()) {
|
||
prompt.append("【自定义要求】\n");
|
||
prompt.append(bo.getCustomPrompt()).append("\n\n");
|
||
}
|
||
|
||
return prompt.toString();
|
||
}
|
||
|
||
/**
|
||
* 构建文字识别提示词
|
||
*/
|
||
private String buildTextPrompt(ImageRecognitionBo bo) {
|
||
StringBuilder prompt = new StringBuilder();
|
||
prompt.append("请识别图片中的所有文字内容,包括但不限于:\n\n");
|
||
prompt.append("【识别要求】\n");
|
||
prompt.append("1. 识别图片中的所有可见文字\n");
|
||
prompt.append("2. 保持文字的原始格式和顺序\n");
|
||
prompt.append("3. 识别表格、列表等结构化内容\n");
|
||
prompt.append("4. 识别数字、符号等特殊字符\n");
|
||
prompt.append("5. 保持段落和换行格式\n\n");
|
||
|
||
if (bo.getCustomPrompt() != null && !bo.getCustomPrompt().isEmpty()) {
|
||
prompt.append("【自定义要求】\n");
|
||
prompt.append(bo.getCustomPrompt()).append("\n\n");
|
||
}
|
||
|
||
prompt.append("【输出格式】\n");
|
||
prompt.append("请直接输出识别到的文字内容,保持原有格式。");
|
||
|
||
return prompt.toString();
|
||
}
|
||
|
||
/**
|
||
* 构建通用识别提示词
|
||
*/
|
||
private String buildGeneralPrompt(ImageRecognitionBo bo) {
|
||
StringBuilder prompt = new StringBuilder();
|
||
prompt.append("请分析这张图片的内容,并提供详细描述:\n\n");
|
||
prompt.append("【分析要求】\n");
|
||
prompt.append("1. 描述图片的主要内容和主题\n");
|
||
prompt.append("2. 识别图片中的文字信息\n");
|
||
prompt.append("3. 分析图片的结构和布局\n");
|
||
prompt.append("4. 提取关键信息和数据\n");
|
||
prompt.append("5. 识别图片中的表格、图表等结构化内容\n\n");
|
||
|
||
if (bo.getCustomPrompt() != null && !bo.getCustomPrompt().isEmpty()) {
|
||
prompt.append("【自定义要求】\n");
|
||
prompt.append(bo.getCustomPrompt()).append("\n\n");
|
||
}
|
||
|
||
prompt.append("【输出格式】\n");
|
||
prompt.append("请提供详细的分析结果,包括文字内容、结构分析等。");
|
||
|
||
return prompt.toString();
|
||
}
|
||
|
||
/**
|
||
* 解析属性数组响应
|
||
*/
|
||
private List<AttributeVo> parseAttributesResponse(String response) {
|
||
try {
|
||
// 尝试直接解析JSON数组
|
||
List<Map<String, Object>> attrList = objectMapper.readValue(response,
|
||
objectMapper.getTypeFactory().constructCollectionType(List.class, Map.class));
|
||
|
||
List<AttributeVo> attributes = new ArrayList<>();
|
||
for (Map<String, Object> attr : attrList) {
|
||
AttributeVo attribute = new AttributeVo();
|
||
attribute.setAttrKey((String) attr.get("attrKey"));
|
||
attribute.setAttrValue((String) attr.get("attrValue"));
|
||
attributes.add(attribute);
|
||
}
|
||
return attributes;
|
||
|
||
} catch (JsonProcessingException e) {
|
||
// 如果直接解析失败,尝试提取JSON数组部分
|
||
Pattern jsonArrayPattern = Pattern.compile("\\[[\\s\\S]*\\]");
|
||
Matcher matcher = jsonArrayPattern.matcher(response);
|
||
if (matcher.find()) {
|
||
try {
|
||
List<Map<String, Object>> attrList = objectMapper.readValue(matcher.group(),
|
||
objectMapper.getTypeFactory().constructCollectionType(List.class, Map.class));
|
||
|
||
List<AttributeVo> attributes = new ArrayList<>();
|
||
for (Map<String, Object> attr : attrList) {
|
||
AttributeVo attribute = new AttributeVo();
|
||
attribute.setAttrKey((String) attr.get("attrKey"));
|
||
attribute.setAttrValue((String) attr.get("attrValue"));
|
||
attributes.add(attribute);
|
||
}
|
||
return attributes;
|
||
|
||
} catch (JsonProcessingException ex) {
|
||
log.warn("无法解析属性响应为JSON数组: {}", response);
|
||
return new ArrayList<>();
|
||
}
|
||
}
|
||
log.warn("无法解析属性响应: {}", response);
|
||
return new ArrayList<>();
|
||
}
|
||
}
|
||
|
||
|
||
}
|