Langchain-Chatchat/server/utils.py

78 lines
3.6 KiB
Python
Raw Normal View History

2023-07-27 23:22:07 +08:00
import pydantic
from pydantic import BaseModel
from typing import List
import torch
from fastapi_offline import FastAPIOffline
import fastapi_offline
from pathlib import Path
# patch fastapi_offline to use local static assests
fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"
2023-07-27 23:22:07 +08:00
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field("success", description="HTTP status message")
class Config:
schema_extra = {
"example": {
"code": 200,
"msg": "success",
}
}
class ListResponse(BaseResponse):
data: List[str] = pydantic.Field(..., description="List of names")
class Config:
schema_extra = {
"example": {
"code": 200,
"msg": "success",
"data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
}
}
class ChatMessage(BaseModel):
question: str = pydantic.Field(..., description="Question text")
response: str = pydantic.Field(..., description="Response text")
history: List[List[str]] = pydantic.Field(..., description="History text")
source_documents: List[str] = pydantic.Field(
..., description="List of source documents and their scores"
)
class Config:
schema_extra = {
"example": {
"question": "工伤保险如何办理?",
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
"history": [
[
"工伤保险是什么?",
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
]
],
"source_documents": [
"出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
"出处 [2] ...",
"出处 [3] ...",
],
}
}
def torch_gc():
if torch.cuda.is_available():
# with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif torch.backends.mps.is_available():
try:
from torch.mps import empty_cache
empty_cache()
except Exception as e:
print(e)
print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")