feat(oa): 添加AI数据查询功能

- 新增AI数据查询接口和相关服务
- 实现关键词匹配和AI智能识别表功能
- 添加SQL生成和执行逻辑
- 新增动态数据返回格式和字段信息类
- 优化SQL安全性验证
This commit is contained in:
2025-08-05 11:38:17 +08:00
parent 2b30d2186f
commit 845e8cfb1e
9 changed files with 1077 additions and 0 deletions

View File

@@ -26,8 +26,10 @@ import com.ruoyi.oa.domain.bo.SysOaAiMessageBo;
import com.ruoyi.oa.service.ISysOaAiConversationService;
import com.ruoyi.oa.service.ISysOaAiMessageService;
import com.ruoyi.oa.service.ISysOaAiConfigService;
import com.ruoyi.oa.service.IAiDataQueryService;
import com.ruoyi.oa.utils.AiServiceUtil;
import com.ruoyi.common.core.page.TableDataInfo;
import lombok.extern.slf4j.Slf4j;
/**
* AI对话管理
@@ -39,12 +41,14 @@ import com.ruoyi.common.core.page.TableDataInfo;
@RequiredArgsConstructor
@RestController
@RequestMapping("/oa/ai")
@Slf4j
public class SysOaAiController extends BaseController {
private final ISysOaAiConversationService conversationService;
private final ISysOaAiMessageService messageService;
private final ISysOaAiConfigService configService;
private final AiServiceUtil aiServiceUtil;
private final IAiDataQueryService aiDataQueryService;
/**
* 查询AI对话历史列表
@@ -275,4 +279,39 @@ public class SysOaAiController extends BaseController {
private String callAiService(String message) {
return aiServiceUtil.callDeepSeek(message);
}
/**
* AI数据查询接口
* 根据用户自然语言描述查询数据库数据
*/
@PostMapping("/data-query")
public R<com.ruoyi.oa.domain.vo.DynamicDataVo> queryDataByAi(@Validated @RequestBody com.ruoyi.oa.domain.bo.OaAiDataQueryBo queryBo) {
try {
com.ruoyi.oa.domain.vo.DynamicDataVo result = aiDataQueryService.queryDataByAi(queryBo);
return R.ok(result);
} catch (Exception e) {
log.error("AI数据查询失败", e);
return R.fail("AI数据查询失败: " + e.getMessage());
}
}
/**
* 根据关键词匹配相关表
*/
@PostMapping("/match-tables")
public R<java.util.List<String>> matchTablesByKeywords(@RequestBody java.util.Map<String, String[]> request) {
try {
String[] keywords = request.get("keywords");
if (keywords == null || keywords.length == 0) {
return R.fail("关键词不能为空");
}
java.util.List<String> matchedTables = aiDataQueryService.matchTablesByKeywords(keywords);
return R.ok(matchedTables);
} catch (Exception e) {
log.error("表匹配失败", e);
return R.fail("表匹配失败: " + e.getMessage());
}
}
}

View File

@@ -0,0 +1,31 @@
package com.ruoyi.oa.domain.bo;
import lombok.Data;
import javax.validation.constraints.NotBlank;
/**
* AI数据查询请求
*
* @author ruoyi
* @date 2024-12-19
*/
@Data
public class OaAiDataQueryBo {
/**
* 用户查询需求描述
*/
@NotBlank(message = "查询需求不能为空")
private String query;
/**
* 最大返回记录数默认100
*/
private Integer limit = 100;
/**
* 是否包含字段元信息默认true
*/
private Boolean includeMeta = true;
}

View File

@@ -0,0 +1,63 @@
package com.ruoyi.oa.domain.vo;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
* 动态数据返回格式
*
* @author ruoyi
* @date 2024-12-19
*/
@Data
public class DynamicDataVo {
/**
* 元数据信息
*/
private Meta meta;
/**
* 实际数据
*/
private List<Map<String, Object>> data;
/**
* 元数据类
*/
@Data
public static class Meta {
/**
* 字段信息列表
*/
private List<Field> fields;
}
/**
* 字段信息类
*/
@Data
public static class Field {
/**
* 字段名后端返回的key
*/
private String fieldName;
/**
* 字段中文名称(用于前端显示列名)
*/
private String label;
/**
* 数据类型number/string/date/bool等前端据此处理渲染
*/
private String type;
/**
* 格式化规则(如数字千分位、日期格式等,可选)
*/
private String format;
}
}

View File

@@ -0,0 +1,78 @@
package com.ruoyi.oa.domain.vo;
import lombok.Data;
/**
* 数据库表字段信息
*
* @author ruoyi
* @date 2024-12-19
*/
@Data
public class TableColumnVo {
/**
* 字段名
*/
private String columnName;
/**
* 是否必填
*/
private String isRequired;
/**
* 是否主键
*/
private String isPk;
/**
* 排序
*/
private Integer sort;
/**
* 字段注释
*/
private String columnComment;
/**
* 是否自增
*/
private String isIncrement;
/**
* 字段类型
*/
private String columnType;
/**
* 数据类型
*/
private String dataType;
/**
* 字符最大长度
*/
private Long characterMaximumLength;
/**
* 数值精度
*/
private Integer numericPrecision;
/**
* 数值小数位数
*/
private Integer numericScale;
/**
* 默认值
*/
private String columnDefault;
/**
* 是否可为空
*/
private String isNullable;
}

View File

@@ -0,0 +1,214 @@
package com.ruoyi.oa.enums;
import lombok.Getter;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 数据库表名映射枚举
* 用于将中文描述映射到对应的数据库表名
*
* @author ruoyi
* @date 2024-12-19
*/
@Getter
public enum TableMappingEnum {
// 员工相关表
EMPLOYEE_FILES("employee_files", "员工档案表", new String[]{"员工文件", "档案"}),
EMPLOYEE_OFFBOARDING("employee_offboarding", "员工离职表", new String[]{"离职", "员工离职"}),
EMPLOYEE_ONBOARDING("employee_onboarding", "员工入职表", new String[]{"入职", "员工入职"}),
// 文章相关表
EXPORT_ARTICLE("export_article", "文章表", new String[]{"文章", "内容"}),
EXPORT_ARTICLE_CATEGORY("export_article_category", "文章分类表", new String[]{"文章分类", "分类"}),
EXPORT_CAROUSEL("export_carousel", "轮播图表", new String[]{"轮播图", "轮播"}),
EXPORT_CATEGORY("export_category", "分类表", new String[]{"分类", "类别"}),
EXPORT_CONTACT("export_contact", "联系方式表", new String[]{"联系方式", "联系"}),
EXPORT_ITEM("export_item", "展示品表", new String[]{"展示品", "展示"}),
EXPORT_LANGUAGE("export_language", "语言管理表", new String[]{"语言", "多语言"}),
// OA应用相关表
OA_APPLICATION("oa_application", "应用集成表", new String[]{"应用", "集成"}),
OA_ATTENDANCE_RECORD("oa_attendance_record", "考勤记录表", new String[]{"考勤", "打卡", "出勤"}),
OA_BINDING_ITEM_DETAIL("oa_binding_item_detail", "绑定记录明细表", new String[]{"绑定", "明细"}),
OA_BUSINESS("oa_business", "CRM商机表", new String[]{"商机", "业务", "CRM"}),
OA_BUSINESS_PRODUCT("oa_business_product", "CRM商机产品关联表", new String[]{"商机产品", "产品关联"}),
OA_CLUE("oa_clue", "CRM线索表", new String[]{"线索", "客户线索"}),
OA_CUSTOMER("oa_customer", "CRM客户表", new String[]{"客户", "客户信息"}),
OA_EMAIL_ACCOUNT("oa_email_account", "发件人邮箱账号管理", new String[]{"邮箱", "邮件账号"}),
OA_EMAIL_TEMPLATE("oa_email_template", "邮件模板表", new String[]{"邮件模板", "模板"}),
OA_EMPLOYEE("oa_employee", "员工基础信息", new String[]{"员工", "员工信息"}),
OA_EMPLOYEE_TEMPLATE_BINDING("oa_employee_template_binding", "员工模板绑定及月度发放记录表", new String[]{"员工模板", "绑定"}),
OA_EXPRESS("oa_express", "快递表", new String[]{"快递", "物流"}),
OA_EXPRESS_QUESTION("oa_express_question", "快递问题表", new String[]{"快递问题", "问题"}),
OA_FEEDBACK("oa_feedback", "反馈表", new String[]{"反馈", "意见"}),
OA_FEEDBACK_ITEM("oa_feedback_item", "反馈项目表", new String[]{"反馈项目", "反馈详情"}),
OA_FOLLOW_UP_RECORD("oa_follow_up_record", "CRM跟进记录", new String[]{"跟进", "跟进记录"}),
OA_FURNITURE_TABLE("oa_furniture_table", "存储家具相关业务数据", new String[]{"家具", "家具数据"}),
OA_INSURANCE_TEMPLATE("oa_insurance_template", "社保公积金模板主表", new String[]{"社保", "公积金", "保险"}),
OA_INSURANCE_TEMPLATE_DETAIL("oa_insurance_template_detail", "社保公积金模板明细表", new String[]{"社保明细", "公积金明细"}),
OA_PAYMENT_PROGRESS("oa_payment_progress", "项目付款进度表", new String[]{"付款", "付款进度"}),
OA_PRODUCT("oa_product", "CRM产品表", new String[]{"产品", "产品信息"}),
OA_PROGRESS("oa_progress", "项目进度主表", new String[]{"项目进度", "进度"}),
OA_PROGRESS_DETAIL("oa_progress_detail", "项目进度付款进度扩展表", new String[]{"进度详情", "付款详情"}),
OA_PROJECT_REPORT("oa_project_report", "报工记录表", new String[]{"报工", "报工记录"}),
OA_PROJECT_SCHEDULE("oa_project_schedule", "项目进度主表", new String[]{"项目排期", "排期"}),
OA_PROJECT_SCHEDULE_STEP("oa_project_schedule_step", "项目进度步骤跟踪表", new String[]{"项目步骤", "步骤跟踪"}),
OA_REPORT_DETAIL("oa_report_detail", "设计项目汇报详情表", new String[]{"汇报详情", "项目汇报"}),
OA_REPORT_SCHEDULE("oa_report_schedule", "项目排产表", new String[]{"排产", "生产排期"}),
OA_REPORT_SUMMARY("oa_report_summary", "设计项目汇报概述表", new String[]{"汇报概述", "项目概述"}),
OA_REQUIREMENTS("oa_requirements", "OA需求表", new String[]{"需求", "OA需求"}),
OA_SALARY("oa_salary", "薪水表", new String[]{"薪水", "工资", "薪资"}),
OA_SALARY_ITEM("oa_salary_item", "薪水详情表", new String[]{"薪水详情", "工资详情"}),
OA_SALARY_TEMPLATE("oa_salary_template", "薪资模板主表", new String[]{"薪资模板", "工资模板"}),
OA_SALARY_TEMPLATE_DETAIL("oa_salary_template_detail", "薪资模板明细表", new String[]{"薪资模板明细", "工资模板明细"}),
OA_SCHEDULE_TEMPLATE("oa_schedule_template", "进度模板主表", new String[]{"进度模板", "排期模板"}),
OA_SCHEDULE_TEMPLATE_STEP("oa_schedule_template_step", "进度模板步骤表", new String[]{"进度模板步骤", "模板步骤"}),
// 通信相关表
SOCKET_CONTACT("socket_contact", "通信目录表", new String[]{"通信", "联系人"}),
SOCKET_MESSAGE("socket_message", "对话信息表", new String[]{"对话", "消息"}),
// 系统相关表
SYS_CONFIG("sys_config", "参数配置表", new String[]{"配置", "系统配置"}),
SYS_DEPT("sys_dept", "部门表", new String[]{"部门", "组织架构"}),
SYS_DICT_DATA("sys_dict_data", "字典数据表", new String[]{"字典数据", "数据字典"}),
SYS_DICT_TYPE("sys_dict_type", "字典类型表", new String[]{"字典类型", "字典"}),
SYS_LOGININFOR("sys_logininfor", "系统访问记录", new String[]{"登录记录", "访问记录"}),
SYS_MENU("sys_menu", "菜单权限表", new String[]{"菜单", "权限"}),
SYS_NOTICE("sys_notice", "通知公告表", new String[]{"通知", "公告"}),
SYS_OA_AI_CONFIG("sys_oa_ai_config", "AI配置表", new String[]{"AI配置", "人工智能配置"}),
SYS_OA_AI_CONVERSATION("sys_oa_ai_conversation", "AI对话历史表", new String[]{"AI对话", "对话历史"}),
SYS_OA_AI_MESSAGE("sys_oa_ai_message", "AI对话详情表", new String[]{"AI消息", "对话详情"}),
SYS_OA_ARTICLE("sys_oa_article", "知识库表", new String[]{"知识库", "知识"}),
SYS_OA_ATTENDANCE("sys_oa_attendance", "考勤表", new String[]{"考勤", "出勤"}),
SYS_OA_BID("sys_oa_bid", "投标管理表", new String[]{"投标", "招标"}),
SYS_OA_CATEGORY("sys_oa_category", "文章分类表", new String[]{"文章分类", "分类管理"}),
SYS_OA_CLAIM("sys_oa_claim", "报销表", new String[]{"报销", "费用报销"}),
SYS_OA_CLAIM_DETAIL("sys_oa_claim_detail", "报销明细表", new String[]{"报销明细", "费用明细"}),
SYS_OA_CONTRACT("sys_oa_contract", "合同表", new String[]{"合同", "合同管理"}),
SYS_OA_COST("sys_oa_cost", "成本表", new String[]{"成本", "费用"}),
SYS_OA_DETAIL("sys_oa_detail", "进出账明细表", new String[]{"进出账明细", "账目明细"}),
SYS_OA_FINANCE("sys_oa_finance", "进出账主表", new String[]{"进出账", "财务"}),
SYS_OA_HOLIDAY("sys_oa_holiday", "节假日表", new String[]{"节假日", "假期"}),
SYS_OA_PROJECT("sys_oa_project", "项目管理表", new String[]{"项目", "项目管理"}),
SYS_OA_PURPOSE("sys_oa_purpose", "采购意向表", new String[]{"采购意向", "采购"}),
SYS_OA_RECEIVE_ACCOUNT("sys_oa_receive_account", "收款账户表", new String[]{"收款账户", "账户"}),
SYS_OA_REMIND("sys_oa_remind", "任务事件提醒表", new String[]{"提醒", "任务提醒"}),
SYS_OA_TASK("sys_oa_task", "任务管理表", new String[]{"任务", "任务管理"}),
SYS_OA_TASK_ITEM("sys_oa_task_item", "报工任务单元", new String[]{"任务单元", "报工单元"}),
SYS_OA_TASK_USER("sys_oa_task_user", "任务用户表", new String[]{"任务用户", "用户任务"}),
SYS_OA_WAREHOUSE("sys_oa_warehouse", "仓库表", new String[]{"仓库", "库存"}),
SYS_OA_WAREHOUSE_DETAIL("sys_oa_warehouse_detail", "仓库明细表", new String[]{"仓库明细", "库存明细"}),
SYS_OA_WAREHOUSE_LOG("sys_oa_warehouse_log", "仓库日志表", new String[]{"仓库日志", "库存日志"}),
SYS_OA_WAREHOUSE_MASTER("sys_oa_warehouse_master", "仓库主表", new String[]{"仓库主表", "库存主表"}),
SYS_OA_WAREHOUSE_TASK("sys_oa_warehouse_task", "采购计划表", new String[]{"采购计划", "采购任务"}),
SYS_OA_WORK("sys_oa_work", "工作表", new String[]{"工作", "工作记录"}),
SYS_OPER_LOG("sys_oper_log", "操作日志记录", new String[]{"操作日志", "日志"}),
SYS_OSS("sys_oss", "OSS对象存储表", new String[]{"文件存储", "对象存储"}),
SYS_OSS_ACL("sys_oss_acl", "OSS文件授权表", new String[]{"文件授权", "存储授权"}),
SYS_OSS_CONFIG("sys_oss_config", "对象存储配置表", new String[]{"存储配置", "OSS配置"}),
SYS_POST("sys_post", "岗位信息表", new String[]{"岗位", "职位"}),
SYS_PREFIX_COUNTER("sys_prefix_counter", "前缀计数器表", new String[]{"计数器", "前缀"}),
SYS_ROLE("sys_role", "角色信息表", new String[]{"角色", "权限角色"}),
SYS_ROLE_DEPT("sys_role_dept", "角色和部门关联表", new String[]{"角色部门", "部门角色"}),
SYS_ROLE_MENU("sys_role_menu", "角色和菜单关联表", new String[]{"角色菜单", "菜单角色"}),
SYS_USER("sys_user", "用户信息表", new String[]{"用户", "用户信息"}),
SYS_USER_POST("sys_user_post", "用户与岗位关联表", new String[]{"用户岗位", "岗位用户"}),
SYS_USER_ROLE("sys_user_role", "用户和角色关联表", new String[]{"用户角色", "角色用户"}),
// 工作流相关表
WF_CATEGORY("wf_category", "流程分类表", new String[]{"流程分类", "工作流分类"}),
WF_COPY("wf_copy", "流程抄送表", new String[]{"流程抄送", "抄送"}),
WF_DEPLOY_FORM("wf_deploy_form", "流程实例关联表单", new String[]{"流程表单", "部署表单"}),
WF_FORM("wf_form", "流程表单信息表", new String[]{"流程表单", "表单信息"});
private final String tableName;
private final String description;
private final String[] keywords;
TableMappingEnum(String tableName, String description, String[] keywords) {
this.tableName = tableName;
this.description = description;
this.keywords = keywords;
}
/**
* 根据关键词匹配表名
*
* @param keyword 关键词
* @return 匹配的表名列表
*/
public static List<String> getTableNamesByKeyword(String keyword) {
if (keyword == null || keyword.trim().isEmpty()) {
return new java.util.ArrayList<>();
}
String lowerKeyword = keyword.toLowerCase();
return Arrays.stream(values())
.filter(table -> {
// 检查表名是否包含关键词
if (table.getTableName().toLowerCase().contains(lowerKeyword)) {
return true;
}
// 检查描述是否包含关键词
if (table.getDescription().toLowerCase().contains(lowerKeyword)) {
return true;
}
// 检查关键词数组是否包含关键词
return Arrays.stream(table.getKeywords())
.anyMatch(kw -> kw.toLowerCase().contains(lowerKeyword));
})
.map(TableMappingEnum::getTableName)
.collect(Collectors.toList());
}
/**
* 根据多个关键词匹配表名
*
* @param keywords 关键词数组
* @return 匹配的表名列表
*/
public static List<String> getTableNamesByKeywords(String[] keywords) {
if (keywords == null || keywords.length == 0) {
return new java.util.ArrayList<>();
}
java.util.Set<String> matchedTables = new java.util.HashSet<>();
for (String keyword : keywords) {
matchedTables.addAll(getTableNamesByKeyword(keyword));
}
return new java.util.ArrayList<>(matchedTables);
}
/**
* 获取所有表名和描述的映射
*
* @return 表名到描述的映射
*/
public static Map<String, String> getAllTableDescriptions() {
return Arrays.stream(values())
.collect(Collectors.toMap(
TableMappingEnum::getTableName,
TableMappingEnum::getDescription
));
}
/**
* 根据表名获取描述
*
* @param tableName 表名
* @return 描述
*/
public static String getDescriptionByTableName(String tableName) {
return Arrays.stream(values())
.filter(table -> table.getTableName().equals(tableName))
.findFirst()
.map(TableMappingEnum::getDescription)
.orElse("");
}
}

View File

@@ -0,0 +1,40 @@
package com.ruoyi.oa.service;
import com.ruoyi.oa.domain.bo.OaAiDataQueryBo;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
/**
* AI数据查询服务接口
*
* @author ruoyi
* @date 2024-12-19
*/
public interface IAiDataQueryService {
/**
* 根据用户需求进行AI数据查询
*
* @param queryBo 查询请求
* @return 动态数据格式的查询结果
*/
DynamicDataVo queryDataByAi(OaAiDataQueryBo queryBo);
/**
* 根据关键词匹配相关表
*
* @param keywords 关键词
* @return 匹配的表名列表
*/
java.util.List<String> matchTablesByKeywords(String[] keywords);
/**
* 生成AI提示词
*
* @param userQuery 用户查询需求
* @param tableNames 相关表名
* @param tableColumns 表字段信息
* @return AI提示词
*/
String generateAiPrompt(String userQuery, java.util.List<String> tableNames,
java.util.Map<String, java.util.List<com.ruoyi.oa.domain.vo.TableColumnVo>> tableColumns);
}

View File

@@ -0,0 +1,58 @@
package com.ruoyi.oa.service;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
import com.ruoyi.oa.domain.vo.TableColumnVo;
import java.util.List;
import java.util.Map;
/**
* 数据库查询服务接口
*
* @author ruoyi
* @date 2024-12-19
*/
public interface IDatabaseQueryService {
/**
* 根据表名获取表字段信息
*
* @param tableName 表名
* @return 字段信息列表
*/
List<TableColumnVo> getTableColumns(String tableName);
/**
* 根据多个表名获取表字段信息
*
* @param tableNames 表名列表
* @return 表名到字段信息的映射
*/
Map<String, List<TableColumnVo>> getTableColumns(List<String> tableNames);
/**
* 执行SQL查询
*
* @param sql SQL语句
* @return 查询结果
*/
List<Map<String, Object>> executeQuery(String sql);
/**
* 执行SQL查询并返回动态数据格式
*
* @param sql SQL语句
* @param tableName 表名(用于生成字段元信息)
* @param includeMeta 是否包含元信息
* @return 动态数据格式
*/
DynamicDataVo executeQueryWithMeta(String sql, String tableName, boolean includeMeta);
/**
* 验证SQL是否为查询语句
*
* @param sql SQL语句
* @return 是否为查询语句
*/
boolean isQuerySql(String sql);
}

View File

@@ -0,0 +1,350 @@
package com.ruoyi.oa.service.impl;
import com.ruoyi.oa.domain.bo.OaAiDataQueryBo;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
import com.ruoyi.oa.domain.vo.TableColumnVo;
import com.ruoyi.oa.enums.TableMappingEnum;
import com.ruoyi.oa.service.IAiDataQueryService;
import com.ruoyi.oa.service.IDatabaseQueryService;
import com.ruoyi.oa.utils.AiServiceUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* AI数据查询服务实现类
*
* @author ruoyi
* @date 2024-12-19
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class AiDataQueryServiceImpl implements IAiDataQueryService {
private final IDatabaseQueryService databaseQueryService;
private final AiServiceUtil aiServiceUtil;
@Override
public DynamicDataVo queryDataByAi(OaAiDataQueryBo queryBo) {
try {
// 1. 从用户查询中提取关键词
String[] keywords = extractKeywords(queryBo.getQuery());
// 2. 根据关键词匹配相关表
List<String> matchedTables = matchTablesByKeywords(keywords);
// 3. 如果关键词匹配失败使用AI智能识别表
if (matchedTables.isEmpty()) {
log.info("关键词匹配失败使用AI智能识别表");
matchedTables = identifyTablesByAi(queryBo.getQuery());
if (matchedTables.isEmpty()) {
throw new RuntimeException("未找到与查询需求相关的数据表");
}
}
// 4. 获取匹配表的字段信息
Map<String, List<TableColumnVo>> tableColumns = databaseQueryService.getTableColumns(matchedTables);
// 5. 生成AI提示词
String prompt = generateAiPrompt(queryBo.getQuery(), matchedTables, tableColumns);
// 6. 调用AI生成SQL
String aiResponse = aiServiceUtil.callDeepSeek(prompt);
// 7. 提取SQL语句
String sql = extractSqlFromAiResponse(aiResponse);
// 8. 执行SQL查询
String primaryTable = matchedTables.get(0); // 使用第一个匹配的表作为主表
DynamicDataVo result = databaseQueryService.executeQueryWithMeta(sql, primaryTable, queryBo.getIncludeMeta());
// 9. 限制返回记录数
if (result.getData() != null && result.getData().size() > queryBo.getLimit()) {
result.setData(result.getData().subList(0, queryBo.getLimit()));
}
return result;
} catch (Exception e) {
log.error("AI数据查询失败", e);
throw new RuntimeException("AI数据查询失败: " + e.getMessage());
}
}
@Override
public List<String> matchTablesByKeywords(String[] keywords) {
return TableMappingEnum.getTableNamesByKeywords(keywords);
}
@Override
public String generateAiPrompt(String userQuery, List<String> tableNames,
Map<String, List<TableColumnVo>> tableColumns) {
StringBuilder prompt = new StringBuilder();
// 系统角色设定
prompt.append("你是一个专业的数据库查询专家请根据用户的需求和提供的数据库表结构信息生成相应的SQL查询语句。\n\n");
// 重要规则
prompt.append("重要规则:\n");
prompt.append("1. 只能生成SELECT查询语句严禁生成INSERT、UPDATE、DELETE、DROP、CREATE、ALTER等修改数据的语句\n");
prompt.append("2. 生成的SQL必须语法正确可以直接执行\n");
prompt.append("3. 如果涉及多表查询请使用适当的JOIN语句\n");
prompt.append("4. 请根据用户需求添加适当的WHERE条件、ORDER BY、LIMIT等子句\n");
prompt.append("5. 只返回SQL语句不要包含其他解释文字\n\n");
// 用户需求
prompt.append("用户需求:").append(userQuery).append("\n\n");
// 相关表信息
prompt.append("相关数据表信息:\n");
for (String tableName : tableNames) {
prompt.append("表名:").append(tableName).append("\n");
prompt.append("表描述:").append(TableMappingEnum.getDescriptionByTableName(tableName)).append("\n");
List<TableColumnVo> columns = tableColumns.get(tableName);
if (columns != null) {
prompt.append("字段信息:\n");
for (TableColumnVo column : columns) {
prompt.append(" - ").append(column.getColumnName())
.append(" (").append(column.getColumnType()).append(")")
.append(" - ").append(column.getColumnComment() != null ? column.getColumnComment() : "无注释");
if ("1".equals(column.getIsPk())) {
prompt.append(" [主键]");
}
if ("1".equals(column.getIsRequired())) {
prompt.append(" [必填]");
}
prompt.append("\n");
}
}
prompt.append("\n");
}
// 输出要求
prompt.append("请根据以上信息生成SQL查询语句只返回SQL语句本身不要包含任何其他文字。");
return prompt.toString();
}
/**
* 从用户查询中提取关键词
*/
private String[] extractKeywords(String query) {
if (query == null || query.trim().isEmpty()) {
return new String[0];
}
// 移除常见的查询词汇
String cleanedQuery = query.replaceAll("查询|获取|显示|查看|搜索|查找|统计|分析|的|表|情况|信息|数据", "");
// 按空格、标点符号分割
String[] words = cleanedQuery.split("[\\s,,。!?!?]+");
// 过滤掉空字符串和太短的词
List<String> keywords = new ArrayList<>();
for (String word : words) {
if (word != null && word.trim().length() >= 2) {
keywords.add(word.trim());
}
}
// 如果没有提取到关键词,尝试更简单的分割方式
if (keywords.isEmpty()) {
// 直接按字符分割,提取有意义的词组
String[] simpleWords = cleanedQuery.split("");
StringBuilder currentWord = new StringBuilder();
for (String charStr : simpleWords) {
if (charStr.matches("[\\u4e00-\\u9fa5]")) { // 中文字符
currentWord.append(charStr);
} else {
if (currentWord.length() >= 2) {
keywords.add(currentWord.toString());
}
currentWord.setLength(0);
}
}
// 处理最后一个词
if (currentWord.length() >= 2) {
keywords.add(currentWord.toString());
}
}
return keywords.toArray(new String[0]);
}
/**
* 使用AI智能识别相关的数据表
*/
private List<String> identifyTablesByAi(String userQuery) {
try {
// 从数据库中获取所有可用的表信息
Map<String, String> allTableDescriptions = getAvailableTablesFromDatabase();
if (allTableDescriptions.isEmpty()) {
log.warn("未从数据库获取到任何表信息");
return new ArrayList<>();
}
// 构建AI提示词
StringBuilder prompt = new StringBuilder();
prompt.append("你是一个数据库专家,请根据用户的查询需求,从以下数据库表中选择最相关的表名。\n\n");
prompt.append("用户查询:").append(userQuery).append("\n\n");
prompt.append("可用的数据库表:\n");
for (Map.Entry<String, String> entry : allTableDescriptions.entrySet()) {
prompt.append("- ").append(entry.getKey()).append(" (").append(entry.getValue()).append(")\n");
}
prompt.append("\n请根据用户查询需求返回最相关的表名用逗号分隔最多返回3个表");
prompt.append("只返回表名,不要包含其他解释文字。");
// 调用AI获取表名
String aiResponse = aiServiceUtil.callDeepSeek(prompt.toString());
// 解析AI返回的表名
List<String> identifiedTables = parseTableNamesFromAiResponse(aiResponse, allTableDescriptions.keySet());
log.info("AI识别的表{}", identifiedTables);
return identifiedTables;
} catch (Exception e) {
log.error("AI表识别失败", e);
return new ArrayList<>();
}
}
/**
* 从数据库中获取所有可用的表信息
*/
private Map<String, String> getAvailableTablesFromDatabase() {
try {
String sql = "select table_name, table_comment from information_schema.tables " +
"where table_schema = (select database()) " +
"AND table_name NOT LIKE 'ACT_%' " +
"AND table_name NOT LIKE 'xxl_job_%' " +
"AND table_name NOT LIKE 'FLW_%' " +
"AND table_name NOT LIKE 'gen_%' " +
"order by table_name";
List<Map<String, Object>> results = databaseQueryService.executeQuery(sql);
Map<String, String> tableDescriptions = new HashMap<>();
for (Map<String, Object> row : results) {
String tableName = (String) row.get("table_name");
String tableComment = (String) row.get("table_comment");
if (tableName != null) {
// 如果表注释为空,使用枚举中的描述作为备选
String description = (tableComment != null && !tableComment.trim().isEmpty())
? tableComment
: TableMappingEnum.getDescriptionByTableName(tableName);
// 如果枚举中也没有描述,使用表名作为描述
if (description == null || description.trim().isEmpty()) {
description = tableName;
}
tableDescriptions.put(tableName, description);
}
}
log.info("从数据库获取到 {} 个表", tableDescriptions.size());
return tableDescriptions;
} catch (Exception e) {
log.error("从数据库获取表信息失败", e);
// 如果数据库查询失败,回退到枚举中的表
return TableMappingEnum.getAllTableDescriptions();
}
}
/**
* 从AI响应中解析表名
*/
private List<String> parseTableNamesFromAiResponse(String aiResponse, Set<String> validTableNames) {
List<String> identifiedTables = new ArrayList<>();
if (aiResponse == null || aiResponse.trim().isEmpty()) {
return identifiedTables;
}
// 清理AI响应
String cleanedResponse = aiResponse.trim();
// 如果响应包含```标记,提取其中的内容
if (cleanedResponse.contains("```")) {
int start = cleanedResponse.indexOf("```") + 3;
int end = cleanedResponse.indexOf("```", start);
if (end > start) {
cleanedResponse = cleanedResponse.substring(start, end).trim();
}
}
// 按逗号、分号、换行符分割
String[] tableNames = cleanedResponse.split("[,;\\n\\r]+");
for (String tableName : tableNames) {
String trimmedName = tableName.trim();
// 移除可能的括号内容
if (trimmedName.contains("(")) {
trimmedName = trimmedName.substring(0, trimmedName.indexOf("(")).trim();
}
// 验证表名是否有效
if (validTableNames.contains(trimmedName)) {
identifiedTables.add(trimmedName);
}
}
return identifiedTables;
}
/**
* 从AI响应中提取SQL语句
*/
private String extractSqlFromAiResponse(String aiResponse) {
if (aiResponse == null || aiResponse.trim().isEmpty()) {
throw new RuntimeException("AI未返回有效的SQL语句");
}
log.info("AI原始响应: {}", aiResponse);
// 清理AI响应提取SQL语句
String cleanedResponse = aiResponse.trim();
// 如果响应包含```sql和```标记,提取其中的内容
if (cleanedResponse.contains("```sql")) {
int start = cleanedResponse.indexOf("```sql") + 6;
int end = cleanedResponse.indexOf("```", start);
if (end > start) {
cleanedResponse = cleanedResponse.substring(start, end).trim();
}
} else if (cleanedResponse.contains("```")) {
int start = cleanedResponse.indexOf("```") + 3;
int end = cleanedResponse.indexOf("```", start);
if (end > start) {
cleanedResponse = cleanedResponse.substring(start, end).trim();
}
}
log.info("提取的SQL: {}", cleanedResponse);
// 验证是否为有效的SQL查询语句
if (!databaseQueryService.isQuerySql(cleanedResponse)) {
log.error("SQL验证失败: {}", cleanedResponse);
throw new RuntimeException("AI生成的SQL语句不符合安全要求: " + cleanedResponse);
}
log.info("SQL验证通过: {}", cleanedResponse);
return cleanedResponse;
}
}

View File

@@ -0,0 +1,204 @@
package com.ruoyi.oa.service.impl;
import com.ruoyi.common.helper.DataBaseHelper;
import com.ruoyi.oa.domain.vo.DynamicDataVo;
import com.ruoyi.oa.domain.vo.TableColumnVo;
import com.ruoyi.oa.service.IDatabaseQueryService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.regex.Pattern;
/**
* 数据库查询服务实现类
*
* @author ruoyi
* @date 2024-12-19
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class DatabaseQueryServiceImpl implements IDatabaseQueryService {
private final JdbcTemplate jdbcTemplate;
// 优化后的SQL查询语句验证正则表达式
// 确保能匹配SELECT/WITH开头无论后面是空格、换行还是其他空白字符
private static final Pattern QUERY_PATTERN = Pattern.compile(
"^\\s*(SELECT|WITH)\\s+",
Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL
);
// 同时检查禁止关键字列表,确保没有包含正常查询所需的关键字
private static final List<String> FORBIDDEN_KEYWORDS = Arrays.asList(
"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE",
"CREATE", "RENAME", "GRANT", "REVOKE", "EXEC", "CALL", "MERGE"
);
@Override
public List<TableColumnVo> getTableColumns(String tableName) {
if (!DataBaseHelper.isMySql()) {
throw new UnsupportedOperationException("目前只支持MySQL数据库");
}
String sql = "select column_name,\n" +
" (case when (is_nullable = 'no' && column_key != 'PRI') then '1' else null end) as is_required,\n" +
" (case when column_key = 'PRI' then '1' else '0' end) as is_pk,\n" +
" ordinal_position as sort,\n" +
" column_comment,\n" +
" (case when extra = 'auto_increment' then '1' else '0' end) as is_increment,\n" +
" column_type,\n" +
" data_type,\n" +
" character_maximum_length,\n" +
" numeric_precision,\n" +
" numeric_scale,\n" +
" column_default,\n" +
" is_nullable\n" +
"from information_schema.columns \n" +
"where table_schema = (select database()) \n" +
"and table_name = ?\n" +
"order by ordinal_position";
return jdbcTemplate.query(sql, (rs, rowNum) -> {
TableColumnVo column = new TableColumnVo();
column.setColumnName(rs.getString("column_name"));
column.setIsRequired(rs.getString("is_required"));
column.setIsPk(rs.getString("is_pk"));
column.setSort(rs.getInt("sort"));
column.setColumnComment(rs.getString("column_comment"));
column.setIsIncrement(rs.getString("is_increment"));
column.setColumnType(rs.getString("column_type"));
column.setDataType(rs.getString("data_type"));
column.setCharacterMaximumLength(rs.getLong("character_maximum_length"));
column.setNumericPrecision(rs.getInt("numeric_precision"));
column.setNumericScale(rs.getInt("numeric_scale"));
column.setColumnDefault(rs.getString("column_default"));
column.setIsNullable(rs.getString("is_nullable"));
return column;
}, tableName);
}
@Override
public Map<String, List<TableColumnVo>> getTableColumns(List<String> tableNames) {
Map<String, List<TableColumnVo>> result = new HashMap<>();
for (String tableName : tableNames) {
result.put(tableName, getTableColumns(tableName));
}
return result;
}
@Override
public List<Map<String, Object>> executeQuery(String sql) {
// 验证SQL安全性
if (!isQuerySql(sql)) {
throw new IllegalArgumentException("只允许执行查询语句");
}
log.info("执行SQL查询: {}", sql);
return jdbcTemplate.queryForList(sql);
}
@Override
public DynamicDataVo executeQueryWithMeta(String sql, String tableName, boolean includeMeta) {
// 执行查询
List<Map<String, Object>> data = executeQuery(sql);
DynamicDataVo result = new DynamicDataVo();
result.setData(data);
// 如果需要元信息,则生成字段元信息
if (includeMeta && tableName != null) {
List<TableColumnVo> columns = getTableColumns(tableName);
DynamicDataVo.Meta meta = new DynamicDataVo.Meta();
List<DynamicDataVo.Field> fields = new ArrayList<>();
for (TableColumnVo column : columns) {
DynamicDataVo.Field field = new DynamicDataVo.Field();
field.setFieldName(column.getColumnName());
field.setLabel(column.getColumnComment() != null ? column.getColumnComment() : column.getColumnName());
field.setType(getFieldType(column));
field.setFormat(getFieldFormat(column));
fields.add(field);
}
meta.setFields(fields);
result.setMeta(meta);
}
return result;
}
@Override
public boolean isQuerySql(String sql) {
if (sql == null || sql.trim().isEmpty()) {
log.warn("SQL为空");
return false;
}
// 首先尝试正则表达式匹配
if (QUERY_PATTERN.matcher(sql).matches()) {
log.debug("正则表达式匹配成功");
} else {
// 如果正则表达式匹配失败,使用简单的字符串检查作为备选
String trimmedSql = sql.trim().toUpperCase();
if (!trimmedSql.startsWith("SELECT") && !trimmedSql.startsWith("WITH")) {
log.warn("SQL不以SELECT或WITH开头: {}", sql);
return false;
}
log.debug("使用字符串检查匹配成功");
}
String upperSql = sql.toUpperCase();
// 检查是否包含禁止的关键字(使用单词边界避免部分匹配)
for (String keyword : FORBIDDEN_KEYWORDS) {
if (Pattern.compile("\\b" + keyword + "\\b").matcher(upperSql).find()) {
log.warn("SQL包含禁止关键字 {}: {}", keyword, sql);
return false;
}
}
log.debug("SQL验证通过: {}", sql);
return true;
}
/**
* 根据字段信息推断字段类型
*/
private String getFieldType(TableColumnVo column) {
String dataType = column.getDataType().toLowerCase();
if (dataType.contains("int") || dataType.contains("decimal") ||
dataType.contains("float") || dataType.contains("double")) {
return "number";
} else if (dataType.contains("date") || dataType.contains("time")) {
return "date";
} else if (dataType.contains("bool")) {
return "bool";
} else {
return "string";
}
}
/**
* 根据字段信息生成格式化规则
*/
private String getFieldFormat(TableColumnVo column) {
String dataType = column.getDataType().toLowerCase();
if (dataType.contains("int")) {
return "0,0";
} else if (dataType.contains("decimal") || dataType.contains("float") || dataType.contains("double")) {
return "0.00";
} else if (dataType.contains("date")) {
return "YYYY-MM-DD";
} else if (dataType.contains("datetime") || dataType.contains("timestamp")) {
return "YYYY-MM-DD HH:mm:ss";
}
return null;
}
}