feat(oa): 添加AI数据查询功能

- 新增AI数据查询接口和相关服务
- 实现关键词匹配和AI智能识别表功能
- 添加SQL生成和执行逻辑
- 新增动态数据返回格式和字段信息类
- 优化SQL安全性验证
This commit is contained in:
2025-08-05 11:38:17 +08:00
parent 2b30d2186f
commit 845e8cfb1e
9 changed files with 1077 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
package com.ruoyi.oa.service;
import com.ruoyi.oa.domain.bo.OaAiDataQueryBo;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
/**
* AI数据查询服务接口
*
* @author ruoyi
* @date 2024-12-19
*/
public interface IAiDataQueryService {
/**
* 根据用户需求进行AI数据查询
*
* @param queryBo 查询请求
* @return 动态数据格式的查询结果
*/
DynamicDataVo queryDataByAi(OaAiDataQueryBo queryBo);
/**
* 根据关键词匹配相关表
*
* @param keywords 关键词
* @return 匹配的表名列表
*/
java.util.List<String> matchTablesByKeywords(String[] keywords);
/**
* 生成AI提示词
*
* @param userQuery 用户查询需求
* @param tableNames 相关表名
* @param tableColumns 表字段信息
* @return AI提示词
*/
String generateAiPrompt(String userQuery, java.util.List<String> tableNames,
java.util.Map<String, java.util.List<com.ruoyi.oa.domain.vo.TableColumnVo>> tableColumns);
}

View File

@@ -0,0 +1,58 @@
package com.ruoyi.oa.service;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
import com.ruoyi.oa.domain.vo.TableColumnVo;
import java.util.List;
import java.util.Map;
/**
* 数据库查询服务接口
*
* @author ruoyi
* @date 2024-12-19
*/
public interface IDatabaseQueryService {
/**
* 根据表名获取表字段信息
*
* @param tableName 表名
* @return 字段信息列表
*/
List<TableColumnVo> getTableColumns(String tableName);
/**
* 根据多个表名获取表字段信息
*
* @param tableNames 表名列表
* @return 表名到字段信息的映射
*/
Map<String, List<TableColumnVo>> getTableColumns(List<String> tableNames);
/**
* 执行SQL查询
*
* @param sql SQL语句
* @return 查询结果
*/
List<Map<String, Object>> executeQuery(String sql);
/**
* 执行SQL查询并返回动态数据格式
*
* @param sql SQL语句
* @param tableName 表名(用于生成字段元信息)
* @param includeMeta 是否包含元信息
* @return 动态数据格式
*/
DynamicDataVo executeQueryWithMeta(String sql, String tableName, boolean includeMeta);
/**
* 验证SQL是否为查询语句
*
* @param sql SQL语句
* @return 是否为查询语句
*/
boolean isQuerySql(String sql);
}

View File

@@ -0,0 +1,350 @@
package com.ruoyi.oa.service.impl;
import com.ruoyi.oa.domain.bo.OaAiDataQueryBo;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
import com.ruoyi.oa.domain.vo.TableColumnVo;
import com.ruoyi.oa.enums.TableMappingEnum;
import com.ruoyi.oa.service.IAiDataQueryService;
import com.ruoyi.oa.service.IDatabaseQueryService;
import com.ruoyi.oa.utils.AiServiceUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* AI数据查询服务实现类
*
* @author ruoyi
* @date 2024-12-19
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class AiDataQueryServiceImpl implements IAiDataQueryService {
private final IDatabaseQueryService databaseQueryService;
private final AiServiceUtil aiServiceUtil;
@Override
public DynamicDataVo queryDataByAi(OaAiDataQueryBo queryBo) {
try {
// 1. 从用户查询中提取关键词
String[] keywords = extractKeywords(queryBo.getQuery());
// 2. 根据关键词匹配相关表
List<String> matchedTables = matchTablesByKeywords(keywords);
// 3. 如果关键词匹配失败使用AI智能识别表
if (matchedTables.isEmpty()) {
log.info("关键词匹配失败使用AI智能识别表");
matchedTables = identifyTablesByAi(queryBo.getQuery());
if (matchedTables.isEmpty()) {
throw new RuntimeException("未找到与查询需求相关的数据表");
}
}
// 4. 获取匹配表的字段信息
Map<String, List<TableColumnVo>> tableColumns = databaseQueryService.getTableColumns(matchedTables);
// 5. 生成AI提示词
String prompt = generateAiPrompt(queryBo.getQuery(), matchedTables, tableColumns);
// 6. 调用AI生成SQL
String aiResponse = aiServiceUtil.callDeepSeek(prompt);
// 7. 提取SQL语句
String sql = extractSqlFromAiResponse(aiResponse);
// 8. 执行SQL查询
String primaryTable = matchedTables.get(0); // 使用第一个匹配的表作为主表
DynamicDataVo result = databaseQueryService.executeQueryWithMeta(sql, primaryTable, queryBo.getIncludeMeta());
// 9. 限制返回记录数
if (result.getData() != null && result.getData().size() > queryBo.getLimit()) {
result.setData(result.getData().subList(0, queryBo.getLimit()));
}
return result;
} catch (Exception e) {
log.error("AI数据查询失败", e);
throw new RuntimeException("AI数据查询失败: " + e.getMessage());
}
}
@Override
public List<String> matchTablesByKeywords(String[] keywords) {
return TableMappingEnum.getTableNamesByKeywords(keywords);
}
@Override
public String generateAiPrompt(String userQuery, List<String> tableNames,
Map<String, List<TableColumnVo>> tableColumns) {
StringBuilder prompt = new StringBuilder();
// 系统角色设定
prompt.append("你是一个专业的数据库查询专家请根据用户的需求和提供的数据库表结构信息生成相应的SQL查询语句。\n\n");
// 重要规则
prompt.append("重要规则:\n");
prompt.append("1. 只能生成SELECT查询语句严禁生成INSERT、UPDATE、DELETE、DROP、CREATE、ALTER等修改数据的语句\n");
prompt.append("2. 生成的SQL必须语法正确可以直接执行\n");
prompt.append("3. 如果涉及多表查询请使用适当的JOIN语句\n");
prompt.append("4. 请根据用户需求添加适当的WHERE条件、ORDER BY、LIMIT等子句\n");
prompt.append("5. 只返回SQL语句不要包含其他解释文字\n\n");
// 用户需求
prompt.append("用户需求:").append(userQuery).append("\n\n");
// 相关表信息
prompt.append("相关数据表信息:\n");
for (String tableName : tableNames) {
prompt.append("表名:").append(tableName).append("\n");
prompt.append("表描述:").append(TableMappingEnum.getDescriptionByTableName(tableName)).append("\n");
List<TableColumnVo> columns = tableColumns.get(tableName);
if (columns != null) {
prompt.append("字段信息:\n");
for (TableColumnVo column : columns) {
prompt.append(" - ").append(column.getColumnName())
.append(" (").append(column.getColumnType()).append(")")
.append(" - ").append(column.getColumnComment() != null ? column.getColumnComment() : "无注释");
if ("1".equals(column.getIsPk())) {
prompt.append(" [主键]");
}
if ("1".equals(column.getIsRequired())) {
prompt.append(" [必填]");
}
prompt.append("\n");
}
}
prompt.append("\n");
}
// 输出要求
prompt.append("请根据以上信息生成SQL查询语句只返回SQL语句本身不要包含任何其他文字。");
return prompt.toString();
}
/**
* 从用户查询中提取关键词
*/
private String[] extractKeywords(String query) {
if (query == null || query.trim().isEmpty()) {
return new String[0];
}
// 移除常见的查询词汇
String cleanedQuery = query.replaceAll("查询|获取|显示|查看|搜索|查找|统计|分析|的|表|情况|信息|数据", "");
// 按空格、标点符号分割
String[] words = cleanedQuery.split("[\\s,,。!?!?]+");
// 过滤掉空字符串和太短的词
List<String> keywords = new ArrayList<>();
for (String word : words) {
if (word != null && word.trim().length() >= 2) {
keywords.add(word.trim());
}
}
// 如果没有提取到关键词,尝试更简单的分割方式
if (keywords.isEmpty()) {
// 直接按字符分割,提取有意义的词组
String[] simpleWords = cleanedQuery.split("");
StringBuilder currentWord = new StringBuilder();
for (String charStr : simpleWords) {
if (charStr.matches("[\\u4e00-\\u9fa5]")) { // 中文字符
currentWord.append(charStr);
} else {
if (currentWord.length() >= 2) {
keywords.add(currentWord.toString());
}
currentWord.setLength(0);
}
}
// 处理最后一个词
if (currentWord.length() >= 2) {
keywords.add(currentWord.toString());
}
}
return keywords.toArray(new String[0]);
}
/**
* 使用AI智能识别相关的数据表
*/
private List<String> identifyTablesByAi(String userQuery) {
try {
// 从数据库中获取所有可用的表信息
Map<String, String> allTableDescriptions = getAvailableTablesFromDatabase();
if (allTableDescriptions.isEmpty()) {
log.warn("未从数据库获取到任何表信息");
return new ArrayList<>();
}
// 构建AI提示词
StringBuilder prompt = new StringBuilder();
prompt.append("你是一个数据库专家,请根据用户的查询需求,从以下数据库表中选择最相关的表名。\n\n");
prompt.append("用户查询:").append(userQuery).append("\n\n");
prompt.append("可用的数据库表:\n");
for (Map.Entry<String, String> entry : allTableDescriptions.entrySet()) {
prompt.append("- ").append(entry.getKey()).append(" (").append(entry.getValue()).append(")\n");
}
prompt.append("\n请根据用户查询需求返回最相关的表名用逗号分隔最多返回3个表");
prompt.append("只返回表名,不要包含其他解释文字。");
// 调用AI获取表名
String aiResponse = aiServiceUtil.callDeepSeek(prompt.toString());
// 解析AI返回的表名
List<String> identifiedTables = parseTableNamesFromAiResponse(aiResponse, allTableDescriptions.keySet());
log.info("AI识别的表{}", identifiedTables);
return identifiedTables;
} catch (Exception e) {
log.error("AI表识别失败", e);
return new ArrayList<>();
}
}
/**
* 从数据库中获取所有可用的表信息
*/
private Map<String, String> getAvailableTablesFromDatabase() {
try {
String sql = "select table_name, table_comment from information_schema.tables " +
"where table_schema = (select database()) " +
"AND table_name NOT LIKE 'ACT_%' " +
"AND table_name NOT LIKE 'xxl_job_%' " +
"AND table_name NOT LIKE 'FLW_%' " +
"AND table_name NOT LIKE 'gen_%' " +
"order by table_name";
List<Map<String, Object>> results = databaseQueryService.executeQuery(sql);
Map<String, String> tableDescriptions = new HashMap<>();
for (Map<String, Object> row : results) {
String tableName = (String) row.get("table_name");
String tableComment = (String) row.get("table_comment");
if (tableName != null) {
// 如果表注释为空,使用枚举中的描述作为备选
String description = (tableComment != null && !tableComment.trim().isEmpty())
? tableComment
: TableMappingEnum.getDescriptionByTableName(tableName);
// 如果枚举中也没有描述,使用表名作为描述
if (description == null || description.trim().isEmpty()) {
description = tableName;
}
tableDescriptions.put(tableName, description);
}
}
log.info("从数据库获取到 {} 个表", tableDescriptions.size());
return tableDescriptions;
} catch (Exception e) {
log.error("从数据库获取表信息失败", e);
// 如果数据库查询失败,回退到枚举中的表
return TableMappingEnum.getAllTableDescriptions();
}
}
/**
* 从AI响应中解析表名
*/
private List<String> parseTableNamesFromAiResponse(String aiResponse, Set<String> validTableNames) {
List<String> identifiedTables = new ArrayList<>();
if (aiResponse == null || aiResponse.trim().isEmpty()) {
return identifiedTables;
}
// 清理AI响应
String cleanedResponse = aiResponse.trim();
// 如果响应包含```标记,提取其中的内容
if (cleanedResponse.contains("```")) {
int start = cleanedResponse.indexOf("```") + 3;
int end = cleanedResponse.indexOf("```", start);
if (end > start) {
cleanedResponse = cleanedResponse.substring(start, end).trim();
}
}
// 按逗号、分号、换行符分割
String[] tableNames = cleanedResponse.split("[,;\\n\\r]+");
for (String tableName : tableNames) {
String trimmedName = tableName.trim();
// 移除可能的括号内容
if (trimmedName.contains("(")) {
trimmedName = trimmedName.substring(0, trimmedName.indexOf("(")).trim();
}
// 验证表名是否有效
if (validTableNames.contains(trimmedName)) {
identifiedTables.add(trimmedName);
}
}
return identifiedTables;
}
/**
* 从AI响应中提取SQL语句
*/
private String extractSqlFromAiResponse(String aiResponse) {
if (aiResponse == null || aiResponse.trim().isEmpty()) {
throw new RuntimeException("AI未返回有效的SQL语句");
}
log.info("AI原始响应: {}", aiResponse);
// 清理AI响应提取SQL语句
String cleanedResponse = aiResponse.trim();
// 如果响应包含```sql和```标记,提取其中的内容
if (cleanedResponse.contains("```sql")) {
int start = cleanedResponse.indexOf("```sql") + 6;
int end = cleanedResponse.indexOf("```", start);
if (end > start) {
cleanedResponse = cleanedResponse.substring(start, end).trim();
}
} else if (cleanedResponse.contains("```")) {
int start = cleanedResponse.indexOf("```") + 3;
int end = cleanedResponse.indexOf("```", start);
if (end > start) {
cleanedResponse = cleanedResponse.substring(start, end).trim();
}
}
log.info("提取的SQL: {}", cleanedResponse);
// 验证是否为有效的SQL查询语句
if (!databaseQueryService.isQuerySql(cleanedResponse)) {
log.error("SQL验证失败: {}", cleanedResponse);
throw new RuntimeException("AI生成的SQL语句不符合安全要求: " + cleanedResponse);
}
log.info("SQL验证通过: {}", cleanedResponse);
return cleanedResponse;
}
}

View File

@@ -0,0 +1,204 @@
package com.ruoyi.oa.service.impl;
import com.ruoyi.common.helper.DataBaseHelper;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
import com.ruoyi.oa.domain.vo.TableColumnVo;
import com.ruoyi.oa.service.IDatabaseQueryService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.regex.Pattern;
/**
* 数据库查询服务实现类
*
* @author ruoyi
* @date 2024-12-19
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class DatabaseQueryServiceImpl implements IDatabaseQueryService {
private final JdbcTemplate jdbcTemplate;
// 优化后的SQL查询语句验证正则表达式
// 确保能匹配SELECT/WITH开头无论后面是空格、换行还是其他空白字符
private static final Pattern QUERY_PATTERN = Pattern.compile(
"^\\s*(SELECT|WITH)\\s+",
Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL
);
// 同时检查禁止关键字列表,确保没有包含正常查询所需的关键字
private static final List<String> FORBIDDEN_KEYWORDS = Arrays.asList(
"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE",
"CREATE", "RENAME", "GRANT", "REVOKE", "EXEC", "CALL", "MERGE"
);
@Override
public List<TableColumnVo> getTableColumns(String tableName) {
if (!DataBaseHelper.isMySql()) {
throw new UnsupportedOperationException("目前只支持MySQL数据库");
}
String sql = "select column_name,\n" +
" (case when (is_nullable = 'no' && column_key != 'PRI') then '1' else null end) as is_required,\n" +
" (case when column_key = 'PRI' then '1' else '0' end) as is_pk,\n" +
" ordinal_position as sort,\n" +
" column_comment,\n" +
" (case when extra = 'auto_increment' then '1' else '0' end) as is_increment,\n" +
" column_type,\n" +
" data_type,\n" +
" character_maximum_length,\n" +
" numeric_precision,\n" +
" numeric_scale,\n" +
" column_default,\n" +
" is_nullable\n" +
"from information_schema.columns \n" +
"where table_schema = (select database()) \n" +
"and table_name = ?\n" +
"order by ordinal_position";
return jdbcTemplate.query(sql, (rs, rowNum) -> {
TableColumnVo column = new TableColumnVo();
column.setColumnName(rs.getString("column_name"));
column.setIsRequired(rs.getString("is_required"));
column.setIsPk(rs.getString("is_pk"));
column.setSort(rs.getInt("sort"));
column.setColumnComment(rs.getString("column_comment"));
column.setIsIncrement(rs.getString("is_increment"));
column.setColumnType(rs.getString("column_type"));
column.setDataType(rs.getString("data_type"));
column.setCharacterMaximumLength(rs.getLong("character_maximum_length"));
column.setNumericPrecision(rs.getInt("numeric_precision"));
column.setNumericScale(rs.getInt("numeric_scale"));
column.setColumnDefault(rs.getString("column_default"));
column.setIsNullable(rs.getString("is_nullable"));
return column;
}, tableName);
}
@Override
public Map<String, List<TableColumnVo>> getTableColumns(List<String> tableNames) {
Map<String, List<TableColumnVo>> result = new HashMap<>();
for (String tableName : tableNames) {
result.put(tableName, getTableColumns(tableName));
}
return result;
}
@Override
public List<Map<String, Object>> executeQuery(String sql) {
// 验证SQL安全性
if (!isQuerySql(sql)) {
throw new IllegalArgumentException("只允许执行查询语句");
}
log.info("执行SQL查询: {}", sql);
return jdbcTemplate.queryForList(sql);
}
@Override
public DynamicDataVo executeQueryWithMeta(String sql, String tableName, boolean includeMeta) {
// 执行查询
List<Map<String, Object>> data = executeQuery(sql);
DynamicDataVo result = new DynamicDataVo();
result.setData(data);
// 如果需要元信息,则生成字段元信息
if (includeMeta && tableName != null) {
List<TableColumnVo> columns = getTableColumns(tableName);
DynamicDataVo.Meta meta = new DynamicDataVo.Meta();
List<DynamicDataVo.Field> fields = new ArrayList<>();
for (TableColumnVo column : columns) {
DynamicDataVo.Field field = new DynamicDataVo.Field();
field.setFieldName(column.getColumnName());
field.setLabel(column.getColumnComment() != null ? column.getColumnComment() : column.getColumnName());
field.setType(getFieldType(column));
field.setFormat(getFieldFormat(column));
fields.add(field);
}
meta.setFields(fields);
result.setMeta(meta);
}
return result;
}
@Override
public boolean isQuerySql(String sql) {
if (sql == null || sql.trim().isEmpty()) {
log.warn("SQL为空");
return false;
}
// 首先尝试正则表达式匹配
if (QUERY_PATTERN.matcher(sql).matches()) {
log.debug("正则表达式匹配成功");
} else {
// 如果正则表达式匹配失败,使用简单的字符串检查作为备选
String trimmedSql = sql.trim().toUpperCase();
if (!trimmedSql.startsWith("SELECT") && !trimmedSql.startsWith("WITH")) {
log.warn("SQL不以SELECT或WITH开头: {}", sql);
return false;
}
log.debug("使用字符串检查匹配成功");
}
String upperSql = sql.toUpperCase();
// 检查是否包含禁止的关键字(使用单词边界避免部分匹配)
for (String keyword : FORBIDDEN_KEYWORDS) {
if (Pattern.compile("\\b" + keyword + "\\b").matcher(upperSql).find()) {
log.warn("SQL包含禁止关键字 {}: {}", keyword, sql);
return false;
}
}
log.debug("SQL验证通过: {}", sql);
return true;
}
/**
* 根据字段信息推断字段类型
*/
private String getFieldType(TableColumnVo column) {
String dataType = column.getDataType().toLowerCase();
if (dataType.contains("int") || dataType.contains("decimal") ||
dataType.contains("float") || dataType.contains("double")) {
return "number";
} else if (dataType.contains("date") || dataType.contains("time")) {
return "date";
} else if (dataType.contains("bool")) {
return "bool";
} else {
return "string";
}
}
/**
* 根据字段信息生成格式化规则
*/
private String getFieldFormat(TableColumnVo column) {
String dataType = column.getDataType().toLowerCase();
if (dataType.contains("int")) {
return "0,0";
} else if (dataType.contains("decimal") || dataType.contains("float") || dataType.contains("double")) {
return "0.00";
} else if (dataType.contains("date")) {
return "YYYY-MM-DD";
} else if (dataType.contains("datetime") || dataType.contains("timestamp")) {
return "YYYY-MM-DD HH:mm:ss";
}
return null;
}
}