diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/controller/SysOaAiController.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/controller/SysOaAiController.java index 6650af4..94de943 100644 --- a/ruoyi-oa/src/main/java/com/ruoyi/oa/controller/SysOaAiController.java +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/controller/SysOaAiController.java @@ -26,8 +26,10 @@ import com.ruoyi.oa.domain.bo.SysOaAiMessageBo; import com.ruoyi.oa.service.ISysOaAiConversationService; import com.ruoyi.oa.service.ISysOaAiMessageService; import com.ruoyi.oa.service.ISysOaAiConfigService; +import com.ruoyi.oa.service.IAiDataQueryService; import com.ruoyi.oa.utils.AiServiceUtil; import com.ruoyi.common.core.page.TableDataInfo; +import lombok.extern.slf4j.Slf4j; /** * AI对话管理 @@ -39,12 +41,14 @@ import com.ruoyi.common.core.page.TableDataInfo; @RequiredArgsConstructor @RestController @RequestMapping("/oa/ai") +@Slf4j public class SysOaAiController extends BaseController { private final ISysOaAiConversationService conversationService; private final ISysOaAiMessageService messageService; private final ISysOaAiConfigService configService; private final AiServiceUtil aiServiceUtil; + private final IAiDataQueryService aiDataQueryService; /** * 查询AI对话历史列表 @@ -275,4 +279,39 @@ public class SysOaAiController extends BaseController { private String callAiService(String message) { return aiServiceUtil.callDeepSeek(message); } + + /** + * AI数据查询接口 + * 根据用户自然语言描述查询数据库数据 + */ + @PostMapping("/data-query") + public R queryDataByAi(@Validated @RequestBody com.ruoyi.oa.domain.bo.OaAiDataQueryBo queryBo) { + try { + com.ruoyi.oa.domain.vo.DynamicDataVo result = aiDataQueryService.queryDataByAi(queryBo); + return R.ok(result); + } catch (Exception e) { + log.error("AI数据查询失败", e); + return R.fail("AI数据查询失败: " + e.getMessage()); + } + } + + /** + * 根据关键词匹配相关表 + */ + @PostMapping("/match-tables") + public R> matchTablesByKeywords(@RequestBody java.util.Map request) { + try { + String[] keywords = request.get("keywords"); + if (keywords == null || keywords.length == 0) { + return R.fail("关键词不能为空"); + } + java.util.List matchedTables = aiDataQueryService.matchTablesByKeywords(keywords); + return R.ok(matchedTables); + } catch (Exception e) { + log.error("表匹配失败", e); + return R.fail("表匹配失败: " + e.getMessage()); + } + } + + } \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/bo/OaAiDataQueryBo.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/bo/OaAiDataQueryBo.java new file mode 100644 index 0000000..db44241 --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/bo/OaAiDataQueryBo.java @@ -0,0 +1,31 @@ +package com.ruoyi.oa.domain.bo; + +import lombok.Data; + +import javax.validation.constraints.NotBlank; + +/** + * AI数据查询请求 + * + * @author ruoyi + * @date 2024-12-19 + */ +@Data +public class OaAiDataQueryBo { + + /** + * 用户查询需求描述 + */ + @NotBlank(message = "查询需求不能为空") + private String query; + + /** + * 最大返回记录数(默认100) + */ + private Integer limit = 100; + + /** + * 是否包含字段元信息(默认true) + */ + private Boolean includeMeta = true; +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/vo/DynamicDataVo.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/vo/DynamicDataVo.java new file mode 100644 index 0000000..2375fa5 --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/vo/DynamicDataVo.java @@ -0,0 +1,63 @@ +package com.ruoyi.oa.domain.vo; + +import lombok.Data; + +import java.util.List; +import java.util.Map; + +/** + * 动态数据返回格式 + * + * @author ruoyi + * @date 2024-12-19 + */ +@Data +public class DynamicDataVo { + + /** + * 元数据信息 + */ + private Meta meta; + + /** + * 实际数据 + */ + private List> data; + + /** + * 元数据类 + */ + @Data + public static class Meta { + /** + * 字段信息列表 + */ + private List fields; + } + + /** + * 字段信息类 + */ + @Data + public static class Field { + /** + * 字段名(后端返回的key) + */ + private String fieldName; + + /** + * 字段中文名称(用于前端显示列名) + */ + private String label; + + /** + * 数据类型(number/string/date/bool等,前端据此处理渲染) + */ + private String type; + + /** + * 格式化规则(如数字千分位、日期格式等,可选) + */ + private String format; + } +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/vo/TableColumnVo.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/vo/TableColumnVo.java new file mode 100644 index 0000000..da7cec3 --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/domain/vo/TableColumnVo.java @@ -0,0 +1,78 @@ +package com.ruoyi.oa.domain.vo; + +import lombok.Data; + +/** + * 数据库表字段信息 + * + * @author ruoyi + * @date 2024-12-19 + */ +@Data +public class TableColumnVo { + + /** + * 字段名 + */ + private String columnName; + + /** + * 是否必填 + */ + private String isRequired; + + /** + * 是否主键 + */ + private String isPk; + + /** + * 排序 + */ + private Integer sort; + + /** + * 字段注释 + */ + private String columnComment; + + /** + * 是否自增 + */ + private String isIncrement; + + /** + * 字段类型 + */ + private String columnType; + + /** + * 数据类型 + */ + private String dataType; + + /** + * 字符最大长度 + */ + private Long characterMaximumLength; + + /** + * 数值精度 + */ + private Integer numericPrecision; + + /** + * 数值小数位数 + */ + private Integer numericScale; + + /** + * 默认值 + */ + private String columnDefault; + + /** + * 是否可为空 + */ + private String isNullable; +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/enums/TableMappingEnum.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/enums/TableMappingEnum.java new file mode 100644 index 0000000..e71717a --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/enums/TableMappingEnum.java @@ -0,0 +1,214 @@ +package com.ruoyi.oa.enums; + +import lombok.Getter; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * 数据库表名映射枚举 + * 用于将中文描述映射到对应的数据库表名 + * + * @author ruoyi + * @date 2024-12-19 + */ +@Getter +public enum TableMappingEnum { + + // 员工相关表 + EMPLOYEE_FILES("employee_files", "员工档案表", new String[]{"员工文件", "档案"}), + EMPLOYEE_OFFBOARDING("employee_offboarding", "员工离职表", new String[]{"离职", "员工离职"}), + EMPLOYEE_ONBOARDING("employee_onboarding", "员工入职表", new String[]{"入职", "员工入职"}), + + // 文章相关表 + EXPORT_ARTICLE("export_article", "文章表", new String[]{"文章", "内容"}), + EXPORT_ARTICLE_CATEGORY("export_article_category", "文章分类表", new String[]{"文章分类", "分类"}), + EXPORT_CAROUSEL("export_carousel", "轮播图表", new String[]{"轮播图", "轮播"}), + EXPORT_CATEGORY("export_category", "分类表", new String[]{"分类", "类别"}), + EXPORT_CONTACT("export_contact", "联系方式表", new String[]{"联系方式", "联系"}), + EXPORT_ITEM("export_item", "展示品表", new String[]{"展示品", "展示"}), + EXPORT_LANGUAGE("export_language", "语言管理表", new String[]{"语言", "多语言"}), + + // OA应用相关表 + OA_APPLICATION("oa_application", "应用集成表", new String[]{"应用", "集成"}), + OA_ATTENDANCE_RECORD("oa_attendance_record", "考勤记录表", new String[]{"考勤", "打卡", "出勤"}), + OA_BINDING_ITEM_DETAIL("oa_binding_item_detail", "绑定记录明细表", new String[]{"绑定", "明细"}), + OA_BUSINESS("oa_business", "CRM商机表", new String[]{"商机", "业务", "CRM"}), + OA_BUSINESS_PRODUCT("oa_business_product", "CRM商机产品关联表", new String[]{"商机产品", "产品关联"}), + OA_CLUE("oa_clue", "CRM线索表", new String[]{"线索", "客户线索"}), + OA_CUSTOMER("oa_customer", "CRM客户表", new String[]{"客户", "客户信息"}), + OA_EMAIL_ACCOUNT("oa_email_account", "发件人邮箱账号管理", new String[]{"邮箱", "邮件账号"}), + OA_EMAIL_TEMPLATE("oa_email_template", "邮件模板表", new String[]{"邮件模板", "模板"}), + OA_EMPLOYEE("oa_employee", "员工基础信息", new String[]{"员工", "员工信息"}), + OA_EMPLOYEE_TEMPLATE_BINDING("oa_employee_template_binding", "员工模板绑定及月度发放记录表", new String[]{"员工模板", "绑定"}), + OA_EXPRESS("oa_express", "快递表", new String[]{"快递", "物流"}), + OA_EXPRESS_QUESTION("oa_express_question", "快递问题表", new String[]{"快递问题", "问题"}), + OA_FEEDBACK("oa_feedback", "反馈表", new String[]{"反馈", "意见"}), + OA_FEEDBACK_ITEM("oa_feedback_item", "反馈项目表", new String[]{"反馈项目", "反馈详情"}), + OA_FOLLOW_UP_RECORD("oa_follow_up_record", "CRM跟进记录", new String[]{"跟进", "跟进记录"}), + OA_FURNITURE_TABLE("oa_furniture_table", "存储家具相关业务数据", new String[]{"家具", "家具数据"}), + OA_INSURANCE_TEMPLATE("oa_insurance_template", "社保公积金模板主表", new String[]{"社保", "公积金", "保险"}), + OA_INSURANCE_TEMPLATE_DETAIL("oa_insurance_template_detail", "社保公积金模板明细表", new String[]{"社保明细", "公积金明细"}), + OA_PAYMENT_PROGRESS("oa_payment_progress", "项目付款进度表", new String[]{"付款", "付款进度"}), + OA_PRODUCT("oa_product", "CRM产品表", new String[]{"产品", "产品信息"}), + OA_PROGRESS("oa_progress", "项目进度主表", new String[]{"项目进度", "进度"}), + OA_PROGRESS_DETAIL("oa_progress_detail", "项目进度付款进度扩展表", new String[]{"进度详情", "付款详情"}), + OA_PROJECT_REPORT("oa_project_report", "报工记录表", new String[]{"报工", "报工记录"}), + OA_PROJECT_SCHEDULE("oa_project_schedule", "项目进度主表", new String[]{"项目排期", "排期"}), + OA_PROJECT_SCHEDULE_STEP("oa_project_schedule_step", "项目进度步骤跟踪表", new String[]{"项目步骤", "步骤跟踪"}), + OA_REPORT_DETAIL("oa_report_detail", "设计项目汇报详情表", new String[]{"汇报详情", "项目汇报"}), + OA_REPORT_SCHEDULE("oa_report_schedule", "项目排产表", new String[]{"排产", "生产排期"}), + OA_REPORT_SUMMARY("oa_report_summary", "设计项目汇报概述表", new String[]{"汇报概述", "项目概述"}), + OA_REQUIREMENTS("oa_requirements", "OA需求表", new String[]{"需求", "OA需求"}), + OA_SALARY("oa_salary", "薪水表", new String[]{"薪水", "工资", "薪资"}), + OA_SALARY_ITEM("oa_salary_item", "薪水详情表", new String[]{"薪水详情", "工资详情"}), + OA_SALARY_TEMPLATE("oa_salary_template", "薪资模板主表", new String[]{"薪资模板", "工资模板"}), + OA_SALARY_TEMPLATE_DETAIL("oa_salary_template_detail", "薪资模板明细表", new String[]{"薪资模板明细", "工资模板明细"}), + OA_SCHEDULE_TEMPLATE("oa_schedule_template", "进度模板主表", new String[]{"进度模板", "排期模板"}), + OA_SCHEDULE_TEMPLATE_STEP("oa_schedule_template_step", "进度模板步骤表", new String[]{"进度模板步骤", "模板步骤"}), + + // 通信相关表 + SOCKET_CONTACT("socket_contact", "通信目录表", new String[]{"通信", "联系人"}), + SOCKET_MESSAGE("socket_message", "对话信息表", new String[]{"对话", "消息"}), + + // 系统相关表 + SYS_CONFIG("sys_config", "参数配置表", new String[]{"配置", "系统配置"}), + SYS_DEPT("sys_dept", "部门表", new String[]{"部门", "组织架构"}), + SYS_DICT_DATA("sys_dict_data", "字典数据表", new String[]{"字典数据", "数据字典"}), + SYS_DICT_TYPE("sys_dict_type", "字典类型表", new String[]{"字典类型", "字典"}), + SYS_LOGININFOR("sys_logininfor", "系统访问记录", new String[]{"登录记录", "访问记录"}), + SYS_MENU("sys_menu", "菜单权限表", new String[]{"菜单", "权限"}), + SYS_NOTICE("sys_notice", "通知公告表", new String[]{"通知", "公告"}), + SYS_OA_AI_CONFIG("sys_oa_ai_config", "AI配置表", new String[]{"AI配置", "人工智能配置"}), + SYS_OA_AI_CONVERSATION("sys_oa_ai_conversation", "AI对话历史表", new String[]{"AI对话", "对话历史"}), + SYS_OA_AI_MESSAGE("sys_oa_ai_message", "AI对话详情表", new String[]{"AI消息", "对话详情"}), + SYS_OA_ARTICLE("sys_oa_article", "知识库表", new String[]{"知识库", "知识"}), + SYS_OA_ATTENDANCE("sys_oa_attendance", "考勤表", new String[]{"考勤", "出勤"}), + SYS_OA_BID("sys_oa_bid", "投标管理表", new String[]{"投标", "招标"}), + SYS_OA_CATEGORY("sys_oa_category", "文章分类表", new String[]{"文章分类", "分类管理"}), + SYS_OA_CLAIM("sys_oa_claim", "报销表", new String[]{"报销", "费用报销"}), + SYS_OA_CLAIM_DETAIL("sys_oa_claim_detail", "报销明细表", new String[]{"报销明细", "费用明细"}), + SYS_OA_CONTRACT("sys_oa_contract", "合同表", new String[]{"合同", "合同管理"}), + SYS_OA_COST("sys_oa_cost", "成本表", new String[]{"成本", "费用"}), + SYS_OA_DETAIL("sys_oa_detail", "进出账明细表", new String[]{"进出账明细", "账目明细"}), + SYS_OA_FINANCE("sys_oa_finance", "进出账主表", new String[]{"进出账", "财务"}), + SYS_OA_HOLIDAY("sys_oa_holiday", "节假日表", new String[]{"节假日", "假期"}), + SYS_OA_PROJECT("sys_oa_project", "项目管理表", new String[]{"项目", "项目管理"}), + SYS_OA_PURPOSE("sys_oa_purpose", "采购意向表", new String[]{"采购意向", "采购"}), + SYS_OA_RECEIVE_ACCOUNT("sys_oa_receive_account", "收款账户表", new String[]{"收款账户", "账户"}), + SYS_OA_REMIND("sys_oa_remind", "任务事件提醒表", new String[]{"提醒", "任务提醒"}), + SYS_OA_TASK("sys_oa_task", "任务管理表", new String[]{"任务", "任务管理"}), + SYS_OA_TASK_ITEM("sys_oa_task_item", "报工任务单元", new String[]{"任务单元", "报工单元"}), + SYS_OA_TASK_USER("sys_oa_task_user", "任务用户表", new String[]{"任务用户", "用户任务"}), + SYS_OA_WAREHOUSE("sys_oa_warehouse", "仓库表", new String[]{"仓库", "库存"}), + SYS_OA_WAREHOUSE_DETAIL("sys_oa_warehouse_detail", "仓库明细表", new String[]{"仓库明细", "库存明细"}), + SYS_OA_WAREHOUSE_LOG("sys_oa_warehouse_log", "仓库日志表", new String[]{"仓库日志", "库存日志"}), + SYS_OA_WAREHOUSE_MASTER("sys_oa_warehouse_master", "仓库主表", new String[]{"仓库主表", "库存主表"}), + SYS_OA_WAREHOUSE_TASK("sys_oa_warehouse_task", "采购计划表", new String[]{"采购计划", "采购任务"}), + SYS_OA_WORK("sys_oa_work", "工作表", new String[]{"工作", "工作记录"}), + SYS_OPER_LOG("sys_oper_log", "操作日志记录", new String[]{"操作日志", "日志"}), + SYS_OSS("sys_oss", "OSS对象存储表", new String[]{"文件存储", "对象存储"}), + SYS_OSS_ACL("sys_oss_acl", "OSS文件授权表", new String[]{"文件授权", "存储授权"}), + SYS_OSS_CONFIG("sys_oss_config", "对象存储配置表", new String[]{"存储配置", "OSS配置"}), + SYS_POST("sys_post", "岗位信息表", new String[]{"岗位", "职位"}), + SYS_PREFIX_COUNTER("sys_prefix_counter", "前缀计数器表", new String[]{"计数器", "前缀"}), + SYS_ROLE("sys_role", "角色信息表", new String[]{"角色", "权限角色"}), + SYS_ROLE_DEPT("sys_role_dept", "角色和部门关联表", new String[]{"角色部门", "部门角色"}), + SYS_ROLE_MENU("sys_role_menu", "角色和菜单关联表", new String[]{"角色菜单", "菜单角色"}), + SYS_USER("sys_user", "用户信息表", new String[]{"用户", "用户信息"}), + SYS_USER_POST("sys_user_post", "用户与岗位关联表", new String[]{"用户岗位", "岗位用户"}), + SYS_USER_ROLE("sys_user_role", "用户和角色关联表", new String[]{"用户角色", "角色用户"}), + + // 工作流相关表 + WF_CATEGORY("wf_category", "流程分类表", new String[]{"流程分类", "工作流分类"}), + WF_COPY("wf_copy", "流程抄送表", new String[]{"流程抄送", "抄送"}), + WF_DEPLOY_FORM("wf_deploy_form", "流程实例关联表单", new String[]{"流程表单", "部署表单"}), + WF_FORM("wf_form", "流程表单信息表", new String[]{"流程表单", "表单信息"}); + + private final String tableName; + private final String description; + private final String[] keywords; + + TableMappingEnum(String tableName, String description, String[] keywords) { + this.tableName = tableName; + this.description = description; + this.keywords = keywords; + } + + /** + * 根据关键词匹配表名 + * + * @param keyword 关键词 + * @return 匹配的表名列表 + */ + public static List getTableNamesByKeyword(String keyword) { + if (keyword == null || keyword.trim().isEmpty()) { + return new java.util.ArrayList<>(); + } + + String lowerKeyword = keyword.toLowerCase(); + return Arrays.stream(values()) + .filter(table -> { + // 检查表名是否包含关键词 + if (table.getTableName().toLowerCase().contains(lowerKeyword)) { + return true; + } + // 检查描述是否包含关键词 + if (table.getDescription().toLowerCase().contains(lowerKeyword)) { + return true; + } + // 检查关键词数组是否包含关键词 + return Arrays.stream(table.getKeywords()) + .anyMatch(kw -> kw.toLowerCase().contains(lowerKeyword)); + }) + .map(TableMappingEnum::getTableName) + .collect(Collectors.toList()); + } + + /** + * 根据多个关键词匹配表名 + * + * @param keywords 关键词数组 + * @return 匹配的表名列表 + */ + public static List getTableNamesByKeywords(String[] keywords) { + if (keywords == null || keywords.length == 0) { + return new java.util.ArrayList<>(); + } + + java.util.Set matchedTables = new java.util.HashSet<>(); + for (String keyword : keywords) { + matchedTables.addAll(getTableNamesByKeyword(keyword)); + } + return new java.util.ArrayList<>(matchedTables); + } + + /** + * 获取所有表名和描述的映射 + * + * @return 表名到描述的映射 + */ + public static Map getAllTableDescriptions() { + return Arrays.stream(values()) + .collect(Collectors.toMap( + TableMappingEnum::getTableName, + TableMappingEnum::getDescription + )); + } + + /** + * 根据表名获取描述 + * + * @param tableName 表名 + * @return 描述 + */ + public static String getDescriptionByTableName(String tableName) { + return Arrays.stream(values()) + .filter(table -> table.getTableName().equals(tableName)) + .findFirst() + .map(TableMappingEnum::getDescription) + .orElse(""); + } +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/service/IAiDataQueryService.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/IAiDataQueryService.java new file mode 100644 index 0000000..e50f96c --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/IAiDataQueryService.java @@ -0,0 +1,40 @@ +package com.ruoyi.oa.service; + +import com.ruoyi.oa.domain.bo.OaAiDataQueryBo; +import com.ruoyi.oa.domain.vo.DynamicDataVo; + +/** + * AI数据查询服务接口 + * + * @author ruoyi + * @date 2024-12-19 + */ +public interface IAiDataQueryService { + + /** + * 根据用户需求进行AI数据查询 + * + * @param queryBo 查询请求 + * @return 动态数据格式的查询结果 + */ + DynamicDataVo queryDataByAi(OaAiDataQueryBo queryBo); + + /** + * 根据关键词匹配相关表 + * + * @param keywords 关键词 + * @return 匹配的表名列表 + */ + java.util.List matchTablesByKeywords(String[] keywords); + + /** + * 生成AI提示词 + * + * @param userQuery 用户查询需求 + * @param tableNames 相关表名 + * @param tableColumns 表字段信息 + * @return AI提示词 + */ + String generateAiPrompt(String userQuery, java.util.List tableNames, + java.util.Map> tableColumns); +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/service/IDatabaseQueryService.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/IDatabaseQueryService.java new file mode 100644 index 0000000..9106e79 --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/IDatabaseQueryService.java @@ -0,0 +1,58 @@ +package com.ruoyi.oa.service; + +import com.ruoyi.oa.domain.vo.DynamicDataVo; +import com.ruoyi.oa.domain.vo.TableColumnVo; + +import java.util.List; +import java.util.Map; + +/** + * 数据库查询服务接口 + * + * @author ruoyi + * @date 2024-12-19 + */ +public interface IDatabaseQueryService { + + /** + * 根据表名获取表字段信息 + * + * @param tableName 表名 + * @return 字段信息列表 + */ + List getTableColumns(String tableName); + + /** + * 根据多个表名获取表字段信息 + * + * @param tableNames 表名列表 + * @return 表名到字段信息的映射 + */ + Map> getTableColumns(List tableNames); + + /** + * 执行SQL查询 + * + * @param sql SQL语句 + * @return 查询结果 + */ + List> executeQuery(String sql); + + /** + * 执行SQL查询并返回动态数据格式 + * + * @param sql SQL语句 + * @param tableName 表名(用于生成字段元信息) + * @param includeMeta 是否包含元信息 + * @return 动态数据格式 + */ + DynamicDataVo executeQueryWithMeta(String sql, String tableName, boolean includeMeta); + + /** + * 验证SQL是否为查询语句 + * + * @param sql SQL语句 + * @return 是否为查询语句 + */ + boolean isQuerySql(String sql); +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/service/impl/AiDataQueryServiceImpl.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/impl/AiDataQueryServiceImpl.java new file mode 100644 index 0000000..5363490 --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/impl/AiDataQueryServiceImpl.java @@ -0,0 +1,350 @@ +package com.ruoyi.oa.service.impl; + +import com.ruoyi.oa.domain.bo.OaAiDataQueryBo; +import com.ruoyi.oa.domain.vo.DynamicDataVo; +import com.ruoyi.oa.domain.vo.TableColumnVo; +import com.ruoyi.oa.enums.TableMappingEnum; +import com.ruoyi.oa.service.IAiDataQueryService; +import com.ruoyi.oa.service.IDatabaseQueryService; +import com.ruoyi.oa.utils.AiServiceUtil; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * AI数据查询服务实现类 + * + * @author ruoyi + * @date 2024-12-19 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class AiDataQueryServiceImpl implements IAiDataQueryService { + + private final IDatabaseQueryService databaseQueryService; + private final AiServiceUtil aiServiceUtil; + + @Override + public DynamicDataVo queryDataByAi(OaAiDataQueryBo queryBo) { + try { + // 1. 从用户查询中提取关键词 + String[] keywords = extractKeywords(queryBo.getQuery()); + + // 2. 根据关键词匹配相关表 + List matchedTables = matchTablesByKeywords(keywords); + + // 3. 如果关键词匹配失败,使用AI智能识别表 + if (matchedTables.isEmpty()) { + log.info("关键词匹配失败,使用AI智能识别表"); + matchedTables = identifyTablesByAi(queryBo.getQuery()); + + if (matchedTables.isEmpty()) { + throw new RuntimeException("未找到与查询需求相关的数据表"); + } + } + + // 4. 获取匹配表的字段信息 + Map> tableColumns = databaseQueryService.getTableColumns(matchedTables); + + // 5. 生成AI提示词 + String prompt = generateAiPrompt(queryBo.getQuery(), matchedTables, tableColumns); + + // 6. 调用AI生成SQL + String aiResponse = aiServiceUtil.callDeepSeek(prompt); + + // 7. 提取SQL语句 + String sql = extractSqlFromAiResponse(aiResponse); + + // 8. 执行SQL查询 + String primaryTable = matchedTables.get(0); // 使用第一个匹配的表作为主表 + DynamicDataVo result = databaseQueryService.executeQueryWithMeta(sql, primaryTable, queryBo.getIncludeMeta()); + + // 9. 限制返回记录数 + if (result.getData() != null && result.getData().size() > queryBo.getLimit()) { + result.setData(result.getData().subList(0, queryBo.getLimit())); + } + + return result; + + } catch (Exception e) { + log.error("AI数据查询失败", e); + throw new RuntimeException("AI数据查询失败: " + e.getMessage()); + } + } + + @Override + public List matchTablesByKeywords(String[] keywords) { + return TableMappingEnum.getTableNamesByKeywords(keywords); + } + + @Override + public String generateAiPrompt(String userQuery, List tableNames, + Map> tableColumns) { + StringBuilder prompt = new StringBuilder(); + + // 系统角色设定 + prompt.append("你是一个专业的数据库查询专家,请根据用户的需求和提供的数据库表结构信息,生成相应的SQL查询语句。\n\n"); + + // 重要规则 + prompt.append("重要规则:\n"); + prompt.append("1. 只能生成SELECT查询语句,严禁生成INSERT、UPDATE、DELETE、DROP、CREATE、ALTER等修改数据的语句\n"); + prompt.append("2. 生成的SQL必须语法正确,可以直接执行\n"); + prompt.append("3. 如果涉及多表查询,请使用适当的JOIN语句\n"); + prompt.append("4. 请根据用户需求添加适当的WHERE条件、ORDER BY、LIMIT等子句\n"); + prompt.append("5. 只返回SQL语句,不要包含其他解释文字\n\n"); + + // 用户需求 + prompt.append("用户需求:").append(userQuery).append("\n\n"); + + // 相关表信息 + prompt.append("相关数据表信息:\n"); + for (String tableName : tableNames) { + prompt.append("表名:").append(tableName).append("\n"); + prompt.append("表描述:").append(TableMappingEnum.getDescriptionByTableName(tableName)).append("\n"); + + List columns = tableColumns.get(tableName); + if (columns != null) { + prompt.append("字段信息:\n"); + for (TableColumnVo column : columns) { + prompt.append(" - ").append(column.getColumnName()) + .append(" (").append(column.getColumnType()).append(")") + .append(" - ").append(column.getColumnComment() != null ? column.getColumnComment() : "无注释"); + + if ("1".equals(column.getIsPk())) { + prompt.append(" [主键]"); + } + if ("1".equals(column.getIsRequired())) { + prompt.append(" [必填]"); + } + prompt.append("\n"); + } + } + prompt.append("\n"); + } + + // 输出要求 + prompt.append("请根据以上信息生成SQL查询语句,只返回SQL语句本身,不要包含任何其他文字。"); + + return prompt.toString(); + } + + /** + * 从用户查询中提取关键词 + */ + private String[] extractKeywords(String query) { + if (query == null || query.trim().isEmpty()) { + return new String[0]; + } + + // 移除常见的查询词汇 + String cleanedQuery = query.replaceAll("查询|获取|显示|查看|搜索|查找|统计|分析|的|表|情况|信息|数据", ""); + + // 按空格、标点符号分割 + String[] words = cleanedQuery.split("[\\s,,。!?!?]+"); + + // 过滤掉空字符串和太短的词 + List keywords = new ArrayList<>(); + for (String word : words) { + if (word != null && word.trim().length() >= 2) { + keywords.add(word.trim()); + } + } + + // 如果没有提取到关键词,尝试更简单的分割方式 + if (keywords.isEmpty()) { + // 直接按字符分割,提取有意义的词组 + String[] simpleWords = cleanedQuery.split(""); + StringBuilder currentWord = new StringBuilder(); + for (String charStr : simpleWords) { + if (charStr.matches("[\\u4e00-\\u9fa5]")) { // 中文字符 + currentWord.append(charStr); + } else { + if (currentWord.length() >= 2) { + keywords.add(currentWord.toString()); + } + currentWord.setLength(0); + } + } + // 处理最后一个词 + if (currentWord.length() >= 2) { + keywords.add(currentWord.toString()); + } + } + + return keywords.toArray(new String[0]); + } + + /** + * 使用AI智能识别相关的数据表 + */ + private List identifyTablesByAi(String userQuery) { + try { + // 从数据库中获取所有可用的表信息 + Map allTableDescriptions = getAvailableTablesFromDatabase(); + + if (allTableDescriptions.isEmpty()) { + log.warn("未从数据库获取到任何表信息"); + return new ArrayList<>(); + } + + // 构建AI提示词 + StringBuilder prompt = new StringBuilder(); + prompt.append("你是一个数据库专家,请根据用户的查询需求,从以下数据库表中选择最相关的表名。\n\n"); + prompt.append("用户查询:").append(userQuery).append("\n\n"); + prompt.append("可用的数据库表:\n"); + + for (Map.Entry entry : allTableDescriptions.entrySet()) { + prompt.append("- ").append(entry.getKey()).append(" (").append(entry.getValue()).append(")\n"); + } + + prompt.append("\n请根据用户查询需求,返回最相关的表名(用逗号分隔,最多返回3个表)。"); + prompt.append("只返回表名,不要包含其他解释文字。"); + + // 调用AI获取表名 + String aiResponse = aiServiceUtil.callDeepSeek(prompt.toString()); + + // 解析AI返回的表名 + List identifiedTables = parseTableNamesFromAiResponse(aiResponse, allTableDescriptions.keySet()); + + log.info("AI识别的表:{}", identifiedTables); + return identifiedTables; + + } catch (Exception e) { + log.error("AI表识别失败", e); + return new ArrayList<>(); + } + } + + /** + * 从数据库中获取所有可用的表信息 + */ + private Map getAvailableTablesFromDatabase() { + try { + String sql = "select table_name, table_comment from information_schema.tables " + + "where table_schema = (select database()) " + + "AND table_name NOT LIKE 'ACT_%' " + + "AND table_name NOT LIKE 'xxl_job_%' " + + "AND table_name NOT LIKE 'FLW_%' " + + "AND table_name NOT LIKE 'gen_%' " + + "order by table_name"; + + List> results = databaseQueryService.executeQuery(sql); + Map tableDescriptions = new HashMap<>(); + + for (Map row : results) { + String tableName = (String) row.get("table_name"); + String tableComment = (String) row.get("table_comment"); + + if (tableName != null) { + // 如果表注释为空,使用枚举中的描述作为备选 + String description = (tableComment != null && !tableComment.trim().isEmpty()) + ? tableComment + : TableMappingEnum.getDescriptionByTableName(tableName); + + // 如果枚举中也没有描述,使用表名作为描述 + if (description == null || description.trim().isEmpty()) { + description = tableName; + } + + tableDescriptions.put(tableName, description); + } + } + + log.info("从数据库获取到 {} 个表", tableDescriptions.size()); + return tableDescriptions; + + } catch (Exception e) { + log.error("从数据库获取表信息失败", e); + // 如果数据库查询失败,回退到枚举中的表 + return TableMappingEnum.getAllTableDescriptions(); + } + } + + /** + * 从AI响应中解析表名 + */ + private List parseTableNamesFromAiResponse(String aiResponse, Set validTableNames) { + List identifiedTables = new ArrayList<>(); + + if (aiResponse == null || aiResponse.trim().isEmpty()) { + return identifiedTables; + } + + // 清理AI响应 + String cleanedResponse = aiResponse.trim(); + + // 如果响应包含```标记,提取其中的内容 + if (cleanedResponse.contains("```")) { + int start = cleanedResponse.indexOf("```") + 3; + int end = cleanedResponse.indexOf("```", start); + if (end > start) { + cleanedResponse = cleanedResponse.substring(start, end).trim(); + } + } + + // 按逗号、分号、换行符分割 + String[] tableNames = cleanedResponse.split("[,,;;\\n\\r]+"); + + for (String tableName : tableNames) { + String trimmedName = tableName.trim(); + // 移除可能的括号内容 + if (trimmedName.contains("(")) { + trimmedName = trimmedName.substring(0, trimmedName.indexOf("(")).trim(); + } + + // 验证表名是否有效 + if (validTableNames.contains(trimmedName)) { + identifiedTables.add(trimmedName); + } + } + + return identifiedTables; + } + + /** + * 从AI响应中提取SQL语句 + */ + private String extractSqlFromAiResponse(String aiResponse) { + if (aiResponse == null || aiResponse.trim().isEmpty()) { + throw new RuntimeException("AI未返回有效的SQL语句"); + } + + log.info("AI原始响应: {}", aiResponse); + + // 清理AI响应,提取SQL语句 + String cleanedResponse = aiResponse.trim(); + + // 如果响应包含```sql和```标记,提取其中的内容 + if (cleanedResponse.contains("```sql")) { + int start = cleanedResponse.indexOf("```sql") + 6; + int end = cleanedResponse.indexOf("```", start); + if (end > start) { + cleanedResponse = cleanedResponse.substring(start, end).trim(); + } + } else if (cleanedResponse.contains("```")) { + int start = cleanedResponse.indexOf("```") + 3; + int end = cleanedResponse.indexOf("```", start); + if (end > start) { + cleanedResponse = cleanedResponse.substring(start, end).trim(); + } + } + + log.info("提取的SQL: {}", cleanedResponse); + + // 验证是否为有效的SQL查询语句 + if (!databaseQueryService.isQuerySql(cleanedResponse)) { + log.error("SQL验证失败: {}", cleanedResponse); + throw new RuntimeException("AI生成的SQL语句不符合安全要求: " + cleanedResponse); + } + + log.info("SQL验证通过: {}", cleanedResponse); + return cleanedResponse; + } +} \ No newline at end of file diff --git a/ruoyi-oa/src/main/java/com/ruoyi/oa/service/impl/DatabaseQueryServiceImpl.java b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/impl/DatabaseQueryServiceImpl.java new file mode 100644 index 0000000..45c7bfe --- /dev/null +++ b/ruoyi-oa/src/main/java/com/ruoyi/oa/service/impl/DatabaseQueryServiceImpl.java @@ -0,0 +1,204 @@ +package com.ruoyi.oa.service.impl; + +import com.ruoyi.common.helper.DataBaseHelper; +import com.ruoyi.oa.domain.vo.DynamicDataVo; +import com.ruoyi.oa.domain.vo.TableColumnVo; +import com.ruoyi.oa.service.IDatabaseQueryService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Service; + +import java.util.*; +import java.util.regex.Pattern; + +/** + * 数据库查询服务实现类 + * + * @author ruoyi + * @date 2024-12-19 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class DatabaseQueryServiceImpl implements IDatabaseQueryService { + + private final JdbcTemplate jdbcTemplate; + + // 优化后的SQL查询语句验证正则表达式 + // 确保能匹配SELECT/WITH开头,无论后面是空格、换行还是其他空白字符 + private static final Pattern QUERY_PATTERN = Pattern.compile( + "^\\s*(SELECT|WITH)\\s+", + Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL + ); + + // 同时检查禁止关键字列表,确保没有包含正常查询所需的关键字 + private static final List FORBIDDEN_KEYWORDS = Arrays.asList( + "INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", + "CREATE", "RENAME", "GRANT", "REVOKE", "EXEC", "CALL", "MERGE" + ); + + @Override + public List getTableColumns(String tableName) { + if (!DataBaseHelper.isMySql()) { + throw new UnsupportedOperationException("目前只支持MySQL数据库"); + } + + String sql = "select column_name,\n" + + " (case when (is_nullable = 'no' && column_key != 'PRI') then '1' else null end) as is_required,\n" + + " (case when column_key = 'PRI' then '1' else '0' end) as is_pk,\n" + + " ordinal_position as sort,\n" + + " column_comment,\n" + + " (case when extra = 'auto_increment' then '1' else '0' end) as is_increment,\n" + + " column_type,\n" + + " data_type,\n" + + " character_maximum_length,\n" + + " numeric_precision,\n" + + " numeric_scale,\n" + + " column_default,\n" + + " is_nullable\n" + + "from information_schema.columns \n" + + "where table_schema = (select database()) \n" + + "and table_name = ?\n" + + "order by ordinal_position"; + + return jdbcTemplate.query(sql, (rs, rowNum) -> { + TableColumnVo column = new TableColumnVo(); + column.setColumnName(rs.getString("column_name")); + column.setIsRequired(rs.getString("is_required")); + column.setIsPk(rs.getString("is_pk")); + column.setSort(rs.getInt("sort")); + column.setColumnComment(rs.getString("column_comment")); + column.setIsIncrement(rs.getString("is_increment")); + column.setColumnType(rs.getString("column_type")); + column.setDataType(rs.getString("data_type")); + column.setCharacterMaximumLength(rs.getLong("character_maximum_length")); + column.setNumericPrecision(rs.getInt("numeric_precision")); + column.setNumericScale(rs.getInt("numeric_scale")); + column.setColumnDefault(rs.getString("column_default")); + column.setIsNullable(rs.getString("is_nullable")); + return column; + }, tableName); + } + + @Override + public Map> getTableColumns(List tableNames) { + Map> result = new HashMap<>(); + for (String tableName : tableNames) { + result.put(tableName, getTableColumns(tableName)); + } + return result; + } + + @Override + public List> executeQuery(String sql) { + // 验证SQL安全性 + if (!isQuerySql(sql)) { + throw new IllegalArgumentException("只允许执行查询语句"); + } + + log.info("执行SQL查询: {}", sql); + return jdbcTemplate.queryForList(sql); + } + + @Override + public DynamicDataVo executeQueryWithMeta(String sql, String tableName, boolean includeMeta) { + // 执行查询 + List> data = executeQuery(sql); + + DynamicDataVo result = new DynamicDataVo(); + result.setData(data); + + // 如果需要元信息,则生成字段元信息 + if (includeMeta && tableName != null) { + List columns = getTableColumns(tableName); + DynamicDataVo.Meta meta = new DynamicDataVo.Meta(); + List fields = new ArrayList<>(); + + for (TableColumnVo column : columns) { + DynamicDataVo.Field field = new DynamicDataVo.Field(); + field.setFieldName(column.getColumnName()); + field.setLabel(column.getColumnComment() != null ? column.getColumnComment() : column.getColumnName()); + field.setType(getFieldType(column)); + field.setFormat(getFieldFormat(column)); + fields.add(field); + } + + meta.setFields(fields); + result.setMeta(meta); + } + + return result; + } + + @Override + public boolean isQuerySql(String sql) { + if (sql == null || sql.trim().isEmpty()) { + log.warn("SQL为空"); + return false; + } + + // 首先尝试正则表达式匹配 + if (QUERY_PATTERN.matcher(sql).matches()) { + log.debug("正则表达式匹配成功"); + } else { + // 如果正则表达式匹配失败,使用简单的字符串检查作为备选 + String trimmedSql = sql.trim().toUpperCase(); + if (!trimmedSql.startsWith("SELECT") && !trimmedSql.startsWith("WITH")) { + log.warn("SQL不以SELECT或WITH开头: {}", sql); + return false; + } + log.debug("使用字符串检查匹配成功"); + } + + String upperSql = sql.toUpperCase(); + // 检查是否包含禁止的关键字(使用单词边界避免部分匹配) + for (String keyword : FORBIDDEN_KEYWORDS) { + if (Pattern.compile("\\b" + keyword + "\\b").matcher(upperSql).find()) { + log.warn("SQL包含禁止关键字 {}: {}", keyword, sql); + return false; + } + } + + log.debug("SQL验证通过: {}", sql); + return true; + } + + + /** + * 根据字段信息推断字段类型 + */ + private String getFieldType(TableColumnVo column) { + String dataType = column.getDataType().toLowerCase(); + + if (dataType.contains("int") || dataType.contains("decimal") || + dataType.contains("float") || dataType.contains("double")) { + return "number"; + } else if (dataType.contains("date") || dataType.contains("time")) { + return "date"; + } else if (dataType.contains("bool")) { + return "bool"; + } else { + return "string"; + } + } + + /** + * 根据字段信息生成格式化规则 + */ + private String getFieldFormat(TableColumnVo column) { + String dataType = column.getDataType().toLowerCase(); + + if (dataType.contains("int")) { + return "0,0"; + } else if (dataType.contains("decimal") || dataType.contains("float") || dataType.contains("double")) { + return "0.00"; + } else if (dataType.contains("date")) { + return "YYYY-MM-DD"; + } else if (dataType.contains("datetime") || dataType.contains("timestamp")) { + return "YYYY-MM-DD HH:mm:ss"; + } + + return null; + } +} \ No newline at end of file