update import pkgs and format

imClumsyPanda 2023-08-10 21:26:05 +08:00
parent 22c8f277bf
commit 8a4d9168fa
22 changed files with 192 additions and 190 deletions

View File

@@ -1,56 +0,0 @@
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain import LLMChain
from langchain.llms import OpenAI
from configs.model_config import *
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.callbacks import StreamlitCallbackHandler

with open("../knowledge_base/samples/content/test.txt") as f:
    state_of_the_union = f.read()

# TODO: define params
# text_splitter = MyTextSplitter()
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
texts = text_splitter.split_text(state_of_the_union)

# TODO: define params
# embeddings = MyEmbeddings()
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
                                   model_kwargs={'device': EMBEDDING_DEVICE})

docsearch = Chroma.from_texts(
    texts,
    embeddings,
    metadatas=[{"source": str(i)} for i in range(len(texts))]
).as_retriever()

# test
query = "什么是Prompt工程"
docs = docsearch.get_relevant_documents(query)
# print(docs)

# prompt_template = PROMPT_TEMPLATE
llm = OpenAI(model_name=LLM_MODEL,
             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
             streaming=True)

# print(PROMPT)
prompt = PromptTemplate(input_variables=["input"], template="{input}")
chain = LLMChain(prompt=prompt, llm=llm)
resp = chain("你好")
for x in resp:
    print(x)

PROMPT = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)
response = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
for x in response:
    print(response["output_text"])

View File

@@ -1,17 +0,0 @@
from typing import List
import numpy as np
from pydantic import BaseModel
from langchain.embeddings.base import Embeddings


class MyEmbeddings(Embeddings, BaseModel):
    size: int

    def _get_embedding(self) -> List[float]:
        return list(np.random.normal(size=self.size))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._get_embedding() for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._get_embedding()
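
The removed stub above satisfies the two methods langchain's Embeddings interface requires. A minimal usage sketch (the dimension 768 and the texts are illustrative, not from the repo):

emb = MyEmbeddings(size=768)
vectors = emb.embed_documents(["foo", "bar"])  # one random vector per text
assert len(vectors) == 2 and len(vectors[0]) == 768
query_vector = emb.embed_query("baz")  # random vector of the same size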

View File

@@ -1 +0,0 @@
from .MyEmbeddings import MyEmbeddings

View File

@@ -3,7 +3,6 @@ openai
 sentence_transformers
 fschat==0.2.20
 transformers
-transformers_stream_generator
 torch~=2.0.0
 fastapi~=0.99.1
 fastapi-offline

View File

@@ -1,6 +1,7 @@
 import nltk
 import sys
 import os
+
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN

View File

@@ -8,7 +8,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
+from typing import List
 from server.chat.utils import History
@@ -21,6 +21,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
                  ),
          ):
     history = [History(**h) if isinstance(h, dict) else h for h in history]
+
     async def chat_iterator(query: str,
                             history: List[History] = [],
                             ) -> AsyncIterable[str]:

View File

@@ -72,7 +72,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
     else:
         answer = ""
         async for token in callback.aiter():
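
Each chunk yielded above is a standalone JSON object streamed to the client. A minimal client sketch, assuming a local server on port 7861, this route name, and that chunk boundaries survive transport (all assumptions, not part of this commit):

import json
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/knowledge_base_chat",  # hypothetical host/port/route
    json={"query": "工伤保险如何办理?", "knowledge_base_name": "samples"},  # illustrative payload
    stream=True,
)
for chunk in resp.iter_lines(decode_unicode=True):
    if chunk:
        data = json.loads(chunk)
        print(data["answer"], end="", flush=True)  # "docs" carries the source documents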

View File

@@ -1,6 +1,5 @@
-from fastapi import Body
 from fastapi.responses import StreamingResponse
-from typing import List, Dict
+from typing import List
 import openai
 from configs.model_config import llm_model_dict, LLM_MODEL
 from pydantic import BaseModel

View File

@@ -111,7 +111,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
     else:
         answer = ""
         async for token in callback.aiter():

View File

@@ -1,9 +1,6 @@
 from functools import wraps
-from sqlalchemy.orm import sessionmaker
 from contextlib import contextmanager
-from server.db.base import engine, SessionLocal
+from server.db.base import SessionLocal


 @contextmanager

View File

@@ -2,13 +2,11 @@ import os
 import urllib
 from fastapi import File, Form, Body, UploadFile
 from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import (get_file_path, validate_kb_name)
+from server.knowledge_base.utils import validate_kb_name
 from fastapi.responses import StreamingResponse
 import json
 from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
 from server.knowledge_base.kb_service.base import KBServiceFactory
-from server.knowledge_base.kb_service.base import SupportedVSType
-from typing import Union


 async def list_docs(
@@ -26,7 +24,7 @@ async def list_docs(
     return ListResponse(data=all_doc_names)


-async def upload_doc(file: UploadFile = File(description="上传文件"),
+async def upload_doc(file: UploadFile = File(..., description="上传文件"),
                      knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
                      override: bool = Form(False, description="覆盖已有文件"),
                      ):

View File

@@ -8,8 +8,8 @@ from langchain.docstore.document import Document
 from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
 from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
     list_docs_from_db
-from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
-                                  embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
+from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
+                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
 from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile)
 from typing import List, Union

View File

@@ -6,8 +6,8 @@ from langchain.vectorstores import PGVector
 from sqlalchemy import text
 from configs.model_config import EMBEDDING_DEVICE, kbs_config
-from server.knowledge_base.kb_service.base import SupportedVSType
-from server.knowledge_base.utils import KBService, load_embeddings, KnowledgeFile
+from server.knowledge_base.kb_service.base import SupportedVSType, KBService
+from server.knowledge_base.utils import load_embeddings, KnowledgeFile


 class PGKBService(KBService):

View File

@@ -10,7 +10,6 @@ from configs.model_config import (
 from functools import lru_cache
 import sys
 from text_splitter import zh_title_enhance
-from langchain.document_loaders import UnstructuredFileLoader


 def validate_kb_name(knowledge_base_id: str) -> bool:

View File

@@ -4,7 +4,6 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-import asyncio

 host_ip = "0.0.0.0"
 controller_port = 20001

View File

@@ -6,6 +6,7 @@
 """
 import sys
 import os
+
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 import subprocess
@@ -19,7 +20,6 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)
 parser = argparse.ArgumentParser()
-
 # ------multi worker-----------------
 parser.add_argument('--model-path-address',
@@ -160,7 +160,8 @@ server_args = ["server-host","server-port","allow-credentials","api-keys",
 args = parser.parse_args()
 # The http:// prefix is required; otherwise InvalidSchema: No connection adapters were found is raised
-args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})
+args = argparse.Namespace(**vars(args),
+                          **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})

 if args.gpus:
     if len(args.gpus.split(",")) < args.num_gpus:
@@ -208,6 +209,7 @@ def string_args(args,args_list):
     return args_str

+
 def launch_worker(item):
     log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
     # Split model-path-address first, then pass it to string_args to parse the parameters
@@ -241,5 +243,6 @@ def launch_all():
     subprocess.run(server_sh, shell=True, check=True)
     subprocess.run(server_check_sh, shell=True, check=True)

+
 if __name__ == "__main__":
     launch_all()
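
A side note on the Namespace construction above: "controller-address" contains a hyphen, so it is not a valid attribute name, and args.controller-address would parse as a subtraction. The value stays reachable through standard argparse mechanisms, e.g.:

controller_address = vars(args)["controller-address"]
# equivalently:
controller_address = getattr(args, "controller-address")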

View File

@@ -5,6 +5,7 @@ python llm_api_shutdown.py --serve all
 """
 import sys
 import os
+
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 import subprocess

View File

@@ -48,15 +48,22 @@ class ChatMessage(BaseModel):
         schema_extra = {
             "example": {
                 "question": "工伤保险如何办理?",
-                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
+                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
+                            "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
+                            "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
+                            "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
+                            "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
+                            "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                 "history": [
                     [
                         "工伤保险是什么?",
-                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
+                        "由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                     ]
                 ],
                 "source_documents": [
-                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
+                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx\n\n\t"
+                    "( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                     "出处 [2] ...",
                     "出处 [3] ...",
                 ],

View File

@@ -1,18 +0,0 @@
from langchain.text_splitter import TextSplitter, _split_text_with_regex
from typing import Any, List


class MyTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)

View File

@@ -1,2 +1,3 @@
-from .MyTextSplitter import MyTextSplitter
+from .chinese_text_splitter import ChineseTextSplitter
+from .ali_text_splitter import AliTextSplitter
 from .zh_title_enhance import zh_title_enhance

View File

@@ -0,0 +1,27 @@
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List


class AliTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        # The use_document_segmentation parameter selects semantic document segmentation;
        # the model used here is nlp_bert_document-segmentation_chinese-base, open-sourced
        # by DAMO Academy (paper: https://arxiv.org/abs/2107.09278).
        # Semantic segmentation requires modelscope[nlp]:
        #   pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
        # Since three models are used in total, which may be unfriendly to low-spec GPUs,
        # the model is loaded on CPU here; replace device with your own GPU id if needed.
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)
        from modelscope.pipelines import pipeline
        p = pipeline(
            task="document-segmentation",
            model='damo/nlp_bert_document-segmentation_chinese-base',
            device="cpu")
        result = p(documents=text)
        sent_list = [i for i in result["text"].split("\n\t") if i]
        return sent_list
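
A minimal usage sketch for the splitter above, assuming modelscope[nlp] is installed (the sample string is illustrative; the pipeline downloads the model on first use):

splitter = AliTextSplitter(pdf=False)
for chunk in splitter.split_text("工伤保险是指用人单位按照国家规定缴纳工伤保险费。由保险机构按照标准给予待遇。"):
    print(chunk)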

View File

@@ -0,0 +1,60 @@
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from configs.model_config import SENTENCE_SIZE


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:  # the logic here needs further optimization
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(r'([;!?。!?\?]["’”」』]{0,2})([^;!?,。!?\?])', r'\1\n\2', text)
        # A double quote ends a sentence only when a terminator precedes it, so the \n break
        # goes after the quote; note the rules above all carefully preserve the quotes.
        text = text.rstrip()  # drop any extra trailing \n at the end of the paragraph
        # Many rule sets also treat the semicolon ; as a terminator, but it is ignored here,
        # as are dashes, English double quotes, etc.; a few simple tweaks will add them if needed.
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,.]["’”」』]{0,2})([^,.])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
        return ls
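
A minimal usage sketch for the rule-based splitter above (the sentence_size value and sample string are illustrative):

splitter = ChineseTextSplitter(pdf=False, sentence_size=100)
for chunk in splitter.split_text("今天天气不错!我们去公园散步吧?好的,现在出发。"):
    print(chunk)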