人脸识别与大模型问答

This commit is contained in:
jiang 2024-10-16 09:23:17 +08:00
parent 96d23b943e
commit cd2120e78c
8 changed files with 247 additions and 65 deletions

View File

@ -41,7 +41,14 @@
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
</dependency>
<!-- Mysql Connector -->
<dependency>
@ -99,7 +106,10 @@
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
</dependencies>
<build>

View File

@ -18,19 +18,23 @@ public class CustomMultipartFile implements MultipartFile {
@Override
public String getOriginalFilename() {
return file.getName(); // 或根据需要返回不同的名称
return file.getName(); // 可以返回上传时的文件名
}
@Override
public String getContentType() {
// 根据文件扩展名返回内容类型
String name = file.getName();
String name = file.getName().toLowerCase();
if (name.endsWith(".jpg") || name.endsWith(".jpeg")) {
return "image/jpeg";
} else if (name.endsWith(".png")) {
return "image/png";
} else if (name.endsWith(".txt")) {
return "text/plain";
} else if (name.endsWith(".pdf")) {
return "application/pdf";
}
return null; // 或者抛出异常
return "application/octet-stream"; // 返回默认的二进制流类型
}
@Override
@ -45,9 +49,12 @@ public class CustomMultipartFile implements MultipartFile {
@Override
public byte[] getBytes() throws IOException {
// 打开文件流并读取字节内容
try (FileInputStream fis = new FileInputStream(file)) {
byte[] bytes = new byte[(int) file.length()];
fis.read(bytes);
if (fis.read(bytes) == -1) {
throw new IOException("Unable to read the file content");
}
return bytes;
}
}
@ -59,13 +66,14 @@ public class CustomMultipartFile implements MultipartFile {
@Override
public void transferTo(File dest) throws IOException, IllegalStateException {
// 实现文件传输的逻辑
try (FileInputStream fis = new FileInputStream(file);
FileOutputStream fos = new FileOutputStream(dest)) {
byte[] buffer = new byte[1024];
int length;
while ((length = fis.read(buffer)) > 0) {
fos.write(buffer, 0, length);
// 可以使用 NIO Files.copy 方法来简化文件传输逻辑
try (InputStream inputStream = new FileInputStream(file)) {
try (OutputStream outputStream = new FileOutputStream(dest)) {
byte[] buffer = new byte[8192];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
}
}
}

View File

@ -1,26 +1,22 @@
package com.bonus.ai.controller;
import cn.hutool.core.io.IoUtil;
import com.bonus.ai.config.CustomMultipartFile;
import com.bonus.ai.domain.*;
import com.bonus.ai.service.DataSetService;
import com.bonus.common.core.utils.StringUtils;
import com.bonus.common.core.web.controller.BaseController;
import com.bonus.common.core.web.domain.AjaxResult;
import com.bonus.common.core.web.page.TableDataInfo;
import com.bonus.common.security.annotation.PreventRepeatSubmit;
import com.bonus.system.api.RemoteFileService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.poi.openxml4j.opc.internal.FileHelper;
import org.aspectj.util.FileUtil;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.commons.CommonsMultipartFile;
import javax.annotation.Resource;
import java.io.*;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
@ -44,6 +40,7 @@ public class DataSetController extends BaseController {
@Resource
private RemoteFileService remoteFileService;
/**
* 根据数据集 ID 查询对应的数据集信息
*
@ -127,6 +124,10 @@ public class DataSetController extends BaseController {
return dataSetService.getCategoryById(categoryId);
}
@PostMapping("/getCategories/{categoryId}")
public AjaxResult getCategoriesById(@PathVariable Long categoryId) {
return dataSetService.getCategoriesById(categoryId);
}
/**
* 获取所有符合条件的类别信息
*
@ -144,6 +145,7 @@ public class DataSetController extends BaseController {
* @param entity 要插入的 DataSetCategoryEntity 对象包含类别的详细信息
* @return 返回插入操作影响的行数成功插入的记录数通常为 1
*/
@PreventRepeatSubmit
@PostMapping("/insertCategory")
public AjaxResult insertCategory(@RequestBody DataSetCategoryEntity entity) {
return dataSetService.insertCategory(entity);
@ -490,41 +492,38 @@ public class DataSetController extends BaseController {
@PostMapping("/uploadZipFiles")
public AjaxResult uploadZipFiles(@RequestParam("files") MultipartFile[] files, @RequestParam("datasetId") Long datasetId) {
List<MultipartFile> multipartFiles = new ArrayList<>();
Path tempDir;
try {
// 创建一个临时目录用于存放解压后的文件
tempDir = Files.createTempDirectory("uploads");
// 遍历上传的压缩文件
for (MultipartFile file : files) {
processZipFile(file, multipartFiles, tempDir);
// 处理每个压缩文件并上传其中的文件
uploadFilesFromZip(file, datasetId, tempDir);
}
// 处理完成后删除临时文件夹
deleteTempDirectory(tempDir);
// List 转换为数组方便后续上传
MultipartFile[] resultFiles = multipartFiles.toArray(new MultipartFile[0]);
return handleFileUpload(resultFiles, datasetId);
return AjaxResult.success("上传成功");
} catch (IOException e) {
// 如果创建临时目录失败返回错误信息
return AjaxResult.error("Failed to create temporary directory: " + e.getMessage());
return AjaxResult.error("上传失败");
}
}
// 处理每个压缩文件解压并将图片文件添加到列表中
private void processZipFile(MultipartFile file, List<MultipartFile> multipartFiles, Path tempDir) {
// 从压缩文件中解压并上传文件
private void uploadFilesFromZip(MultipartFile file, Long datasetId, Path tempDir) {
List<MultipartFile> multipartFiles = new ArrayList<>();
int batchSize = 100; // 每批处理的文件数量
try (ZipInputStream zis = new ZipInputStream(file.getInputStream())) {
ZipEntry entry;
// 逐个读取压缩文件中的条目
while ((entry = zis.getNextEntry()) != null) {
// 如果条目不是目录且是图片文件
if (!entry.isDirectory() && isImageFile(entry.getName())) {
// 创建解压后的图片文
// 创建解压后的图片文
File imgFile = new File(tempDir.toFile(), entry.getName());
imgFile.getParentFile().mkdirs(); // 确保目录存在
// 将条目内容写入到文件
try (FileOutputStream fos = new FileOutputStream(imgFile)) {
byte[] buffer = new byte[1024];
@ -537,8 +536,27 @@ public class DataSetController extends BaseController {
// 创建 MultipartFile 对象并添加到列表中
MultipartFile multipartFile = new CustomMultipartFile(imgFile);
multipartFiles.add(multipartFile);
// 每处理 100 个文件时进行一次上传
if (multipartFiles.size() >= batchSize) {
// 上传当前批次的文件
AjaxResult result = handleFileUpload(multipartFiles.toArray(new MultipartFile[0]), datasetId, tempDir);
multipartFiles.clear(); // 清空已上传的文件列表
// 检查上传结果
if (!result.isSuccess()) {
throw new RuntimeException("File upload failed for " + file.getOriginalFilename());
}
}
}
zis.closeEntry(); // 关闭当前条目
}
// 处理剩余未上传的文件
if (!multipartFiles.isEmpty()) {
AjaxResult result = handleFileUpload(multipartFiles.toArray(new MultipartFile[0]), datasetId, tempDir);
if (!result.isSuccess()) {
throw new RuntimeException("File upload failed for " + file.getOriginalFilename());
}
}else {
deleteTempDirectory(tempDir);
}
} catch (IOException e) {
// 抛出运行时异常包含原始文件名的信息
@ -547,7 +565,7 @@ public class DataSetController extends BaseController {
}
// 处理文件上传及数据库插入
private AjaxResult handleFileUpload(MultipartFile[] resultFiles, Long datasetId) {
private AjaxResult handleFileUpload(MultipartFile[] resultFiles, Long datasetId, Path tempDir) throws IOException {
// 调用远程服务上传文件
AjaxResult ajaxResult = remoteFileService.uploadFile(resultFiles);
if (ajaxResult.isSuccess()) {
@ -564,7 +582,7 @@ public class DataSetController extends BaseController {
entity.setFileName(map.get("name"));
dataSetService.insertFile(entity);
}
// 返回成功结果
// deleteTempDirectory(tempDir); // 删除临时目录放在上传后再处理
return AjaxResult.success();
} else {
// 上传失败返回错误信息
@ -601,4 +619,5 @@ public class DataSetController extends BaseController {
e.printStackTrace(); // 打印错误信息
}
}
}

View File

@ -70,7 +70,7 @@ public interface DataSetMapper {
* @return 返回所有符合条件的类别的列表
*/
List<DataSetCategoryEntity> getCategories(DataSetCategoryEntity entity);
List<DataSetCategoryEntity> getCategoriesById(Long categoryId);
/**
* 插入新的样本类别数据到数据库
*
@ -79,6 +79,15 @@ public interface DataSetMapper {
*/
int insertCategory(DataSetCategoryEntity entity);
// 检查类别名称是否已存在
int countByNameAndParentId(@Param("categoryName") String categoryName, @Param("parentId") Long parentId, @Param("categoryId") Long categoryId);
// 获取所有祖先类别名称
List<DataSetCategoryEntity> getAncestorCategoryNames(@Param("categoryId") Long categoryId);
// 获取所有祖先类别名称
List<DataSetCategoryEntity> getAncestorCategoryName(@Param("parentId") Long parentId);
/**
* 更新指定类别的详细信息
*
@ -93,7 +102,9 @@ public interface DataSetMapper {
* @param categoryId 要删除的类别的唯一标识符
* @return 返回删除操作影响的行数成功删除的记录数通常为 1
*/
int deleteCategory(Long categoryId);
int deleteCategory(@Param("categoryIds") List<Long> categoryIds);
List<Long> getCategoryId(Long categoryId);
/**
* 根据 ID 查询数据集文件信息
@ -281,4 +292,6 @@ public interface DataSetMapper {
* @return 影响的行数
*/
int deleteAlgorithm(Long[] algorithmIds);
}

View File

@ -267,4 +267,6 @@ public interface DataSetService {
* @return 影响的行数
*/
AjaxResult deleteAlgorithm(Long[] algorithmIds);
AjaxResult getCategoriesById(Long categoryId);
}

View File

@ -4,15 +4,19 @@ import com.bonus.ai.domain.*;
import com.bonus.ai.mapper.DataSetMapper;
import com.bonus.ai.service.DataSetService;
import com.bonus.common.core.domain.R;
import com.bonus.common.core.utils.StringUtils;
import com.bonus.common.core.web.domain.AjaxResult;
import com.bonus.common.security.utils.SecurityUtils;
import com.bonus.system.api.RemoteFileService;
import com.bonus.system.api.domain.SysFile;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -138,6 +142,20 @@ public class DataSetServiceImpl implements DataSetService {
}
}
/**
* @param categoryId
* @return
*/
@Override
public AjaxResult getCategoriesById(Long categoryId) {
try {
List<DataSetCategoryEntity> categories = mapper.getCategoriesById(categoryId);
return AjaxResult.success(categories);
} catch (Exception e) {
return AjaxResult.error("获取数据失败");
}
}
/**
* 插入新的样本类别数据到数据库
*
@ -147,7 +165,20 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult insertCategory(DataSetCategoryEntity entity) {
try {
// 创建一个列表来存储所有要删除的类别ID
entity.setCreatedBy(SecurityUtils.getUserId());
List<DataSetCategoryEntity> ancestorCategoryName = mapper.getAncestorCategoryName(entity.getParentId());
List<DataSetCategoryEntity> ancestorNames = new ArrayList<>(ancestorCategoryName);
int num = mapper.countByNameAndParentId(entity.getCategoryName(), entity.getParentId(), entity.getCategoryId());
if (num > 0) {
return AjaxResult.error("新增失败,类别已存在");
}
List<DataSetCategoryEntity> ancestorCategoryNames = findAncestorCategoryNames(entity.getParentId(), ancestorNames);
for (DataSetCategoryEntity e : ancestorCategoryNames) {
if (entity.getCategoryName().equals(e.getCategoryName())) {
return AjaxResult.error("子类别名称不能与任何祖先类别名称相同");
}
}
int i = mapper.insertCategory(entity);
if (i > 0) {
return AjaxResult.success("新增成功");
@ -159,6 +190,18 @@ public class DataSetServiceImpl implements DataSetService {
}
}
public List<DataSetCategoryEntity> findAncestorCategoryNames(Long categoryId, List<DataSetCategoryEntity> ancestorNames) {
// 获取当前类别的所有子类别ID
List<DataSetCategoryEntity> ancestorCategoryNames = mapper.getAncestorCategoryNames(categoryId);
// 将子类别ID加入列表
ancestorNames.addAll(ancestorCategoryNames);
// 递归调用获取每个子类别的子类别
for (DataSetCategoryEntity entity : ancestorCategoryNames) {
findAncestorCategoryNames(entity.getCategoryId(), ancestorNames);
}
return ancestorNames;
}
/**
* 更新指定类别的详细信息
*
@ -168,6 +211,18 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult updateCategory(DataSetCategoryEntity entity) {
try {
int num = mapper.countByNameAndParentId(entity.getCategoryName(), entity.getParentId(), entity.getCategoryId());
if (num > 0) {
return AjaxResult.error("修改失败,类别已存在");
}
List<DataSetCategoryEntity> ancestorCategoryName = mapper.getAncestorCategoryName(entity.getParentId());
List<DataSetCategoryEntity> ancestorNames = new ArrayList<>(ancestorCategoryName);
List<DataSetCategoryEntity> ancestorCategoryNames = findAncestorCategoryNames(entity.getParentId(), ancestorNames);
for (DataSetCategoryEntity e : ancestorCategoryNames) {
if (entity.getCategoryName().equals(e.getCategoryName())) {
return AjaxResult.error("子类别名称不能与任何祖先类别名称相同");
}
}
int i = mapper.updateCategory(entity);
return i > 0 ? AjaxResult.success("修改成功") : AjaxResult.error("修改失败");
} catch (Exception e) {
@ -184,13 +239,31 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult deleteCategory(Long categoryId) {
try {
int i = mapper.deleteCategory(categoryId);
// 创建一个列表来存储所有要删除的类别ID
List<Long> categoryIds = new ArrayList<>();
// 递归获取所有子类别ID
getAllSubcategories(categoryId, categoryIds);
// 将当前类别 ID 加入待删除列表
categoryIds.add(categoryId);
int i = mapper.deleteCategory(categoryIds);
return i > 0 ? AjaxResult.success("删除成功") : AjaxResult.error("删除失败");
} catch (Exception e) {
return AjaxResult.error("删除失败");
}
}
// 递归方法获取所有子类别ID
private void getAllSubcategories(Long categoryId, List<Long> categoryIds) {
// 获取当前类别的所有子类别ID
List<Long> subcategoryIds = mapper.getCategoryId(categoryId);
// 将子类别ID加入列表
categoryIds.addAll(subcategoryIds);
// 递归调用获取每个子类别的子类别
for (Long subcategoryId : subcategoryIds) {
getAllSubcategories(subcategoryId, categoryIds);
}
}
/**
* 根据 ID 查询数据集文件信息
*
@ -390,25 +463,28 @@ public class DataSetServiceImpl implements DataSetService {
public AjaxResult insertModel(AiModelEntity entity) {
try {
entity.setCreateBy(SecurityUtils.getUsername());
AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile());
if (ajaxResult.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) ajaxResult.get("data");
for (Map<String, String> map : data) {
entity.setModelAddress(map.get("url"));
if (ObjectUtils.isNotEmpty(entity.getModelFile())) {
AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile());
if (ajaxResult.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) ajaxResult.get("data");
for (Map<String, String> map : data) {
entity.setModelAddress(map.get("url"));
}
} else {
return AjaxResult.error("上传文件失败");
}
} else {
return AjaxResult.error("上传文件失败");
}
if (ObjectUtils.isNotEmpty(entity.getUserGuideFile())) {
AjaxResult userGuide = remoteFileService.uploadFile(entity.getUserGuideFile());
if (userGuide.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) userGuide.get("data");
for (Map<String, String> map : data) {
entity.setUserGuide(map.get("url"));
}
} else {
AjaxResult userGuide = remoteFileService.uploadFile(entity.getUserGuideFile());
if (userGuide.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) userGuide.get("data");
for (Map<String, String> map : data) {
entity.setUserGuide(map.get("url"));
return AjaxResult.error("上传文件失败");
}
} else {
return AjaxResult.error("上传文件失败");
}
int i = mapper.insertModel(entity);
return i > 0 ? AjaxResult.success("新增成功") : AjaxResult.error("新增失败");
@ -426,7 +502,7 @@ public class DataSetServiceImpl implements DataSetService {
@Override
public AjaxResult updateModel(AiModelEntity entity) {
try {
if (entity.getModelFile().length > 0) {
if (ObjectUtils.isNotEmpty(entity.getModelFile())) {
AjaxResult ajaxResult = remoteFileService.uploadFile(entity.getModelFile());
if (ajaxResult.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) ajaxResult.get("data");
@ -437,7 +513,7 @@ public class DataSetServiceImpl implements DataSetService {
return AjaxResult.error("上传文件失败");
}
}
if (entity.getUserGuideFile().length > 0) {
if (ObjectUtils.isNotEmpty(entity.getModelFile())) {
AjaxResult userGuide = remoteFileService.uploadFile(entity.getUserGuideFile());
if (userGuide.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) userGuide.get("data");
@ -551,4 +627,6 @@ public class DataSetServiceImpl implements DataSetService {
return AjaxResult.error("删除失败");
}
}
}

View File

@ -4,6 +4,7 @@ import cn.hutool.core.io.unit.DataUnit;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.bonus.ai.domain.DataSetFileEntity;
import com.bonus.ai.domain.vo.FaceResultVo;
import com.bonus.ai.domain.vo.FaceVo;
import com.bonus.ai.mapper.AiFaceRecognizeResultMapper;
@ -121,10 +122,13 @@ public class FaceServiceImpl implements FaceService {
}
private AjaxResult handleFileUpload(MultipartFile file, FaceVo face) throws Exception {
R<SysFile> upload = remoteFileService.upload(file);
if (upload.getCode() == 200) {
face.setFaceAddress(upload.getData().getUrl().replaceFirst("http://[^/]+", ""));
return faceMapper.insertFace(face) > 0 ? AjaxResult.success() : AjaxResult.error();
AjaxResult upload = remoteFileService.upload(file);
if (upload.isSuccess()) {
List<Map<String, String>> data = (List<Map<String, String>>) upload.get("data");
for (Map<String, String> map : data) {
face.setFaceAddress(map.get("url").replaceFirst("http://[^/]+", ""));
return faceMapper.insertFace(face) > 0 ? AjaxResult.success() : AjaxResult.error();
}
}
return AjaxResult.error();
}

View File

@ -217,6 +217,50 @@
AND ada.model_id =#{modelId} <!-- 根据 datasetName 进行模糊查询 -->
</if>
</select>
<select id="getCategoryId" resultType="java.lang.Long">
SELECT category_id
FROM ai_dataset_category
WHERE parent_id = #{categoryId}
AND del_flag = '0'
</select>
<select id="countByNameAndParentId" resultType="java.lang.Integer">
SELECT COUNT(*)
FROM ai_dataset_category
WHERE category_name = #{categoryName}
AND parent_id = #{parentId}
AND del_flag = '0'
<if test="categoryId != null and categoryId != ''">
AND category_id not in (#{categoryId}) <!-- 根据 datasetName 进行模糊查询 -->
</if>
</select>
<select id="getAncestorCategoryNames" resultType="com.bonus.ai.domain.DataSetCategoryEntity">
SELECT category_name AS categoryName,
category_id AS categoryId
FROM ai_dataset_category
WHERE category_id IN (SELECT parent_id
FROM ai_dataset_category
WHERE category_id = #{categoryId})
</select>
<select id="getAncestorCategoryName" resultType="com.bonus.ai.domain.DataSetCategoryEntity">
SELECT category_name AS categoryName,
category_id AS categoryId
FROM ai_dataset_category
WHERE category_id =#{parentId}
</select>
<select id="getCategoriesById" resultType="com.bonus.ai.domain.DataSetCategoryEntity">
SELECT adc.category_id AS categoryId,
adc.category_name AS categoryName,
adc.parent_id AS parentId,
adc.enabled AS enabled,
adc.created_by AS createdBy,
adc.description AS description,
adc.create_time AS createTime,
su.user_name AS createName
FROM ai_dataset_category adc
LEFT JOIN sys_user su ON su.user_id = adc.created_by
WHERE adc.del_flag = '0' AND adc.category_id not in (#{categoryId})
</select>
<!--
插入新的数据集到数据库。
@ -232,6 +276,7 @@
插入新的类别数据到数据库。
具体实现尚未提供。
-->
<insert id="insertCategory">
INSERT INTO ai_dataset_category (category_name, parent_id, enabled, created_by, description)
VALUES (#{categoryName}, #{parentId}, #{enabled}, #{createdBy}, #{description})
@ -250,11 +295,11 @@
</insert>
<insert id="insertModel">
INSERT INTO ai_model (model_name, model_version, sub_task_type_id,
INSERT INTO ai_model (model_name, model_version,
infer_language, model_format, deploy_requirement, model_address,
user_guide, description, create_by,
algorithm, dataSetId, model_type)
VALUES (#{modelName}, #{modelVersion}, #{subTaskTypeId}, #{inferLanguage}, #{modelFormat},
VALUES (#{modelName}, #{modelVersion}, #{inferLanguage}, #{modelFormat},
#{deployRequirement}, #{modelAddress}, #{userGuide}, #{description}, #{createBy}, #{algorithm},
#{dataSetId}, #{modelType});
@ -306,7 +351,10 @@
<update id="deleteCategory">
UPDATE ai_dataset_category
SET del_flag = '1'
WHERE category_id = #{categoryId}
WHERE category_id IN
<foreach item="categoryId" collection="categoryIds" open="(" separator="," close=")">
#{categoryId}
</foreach>
</update>
<update id="updateFile">
UPDATE ai_dataset_file