diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
deleted file mode 100644
index cad30b0..0000000
--- a/chains/local_doc_qa.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from langchain.vectorstores import Chroma
-from langchain.prompts import PromptTemplate
-from langchain.chains.question_answering import load_qa_chain
-from langchain import LLMChain
-from langchain.llms import OpenAI
-from configs.model_config import *
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.callbacks import StreamlitCallbackHandler
-
-with open("../knowledge_base/samples/content/test.txt") as f:
-    state_of_the_union = f.read()
-
-# TODO: define params
-# text_splitter = MyTextSplitter()
-text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
-texts = text_splitter.split_text(state_of_the_union)
-
-# TODO: define params
-# embeddings = MyEmbeddings()
-embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
-                                   model_kwargs={'device': EMBEDDING_DEVICE})
-
-docsearch = Chroma.from_texts(
-    texts,
-    embeddings,
-    metadatas=[{"source": str(i)} for i in range(len(texts))]
-).as_retriever()
-
-# test
-query = "什么是Prompt工程"
-docs = docsearch.get_relevant_documents(query)
-# print(docs)
-
-# prompt_template = PROMPT_TEMPLATE
-
-llm = OpenAI(model_name=LLM_MODEL,
-             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-             streaming=True)
-
-# print(PROMPT)
-prompt = PromptTemplate(input_variables=["input"], template="{input}")
-chain = LLMChain(prompt=prompt, llm=llm)
-resp = chain("你好")
-for x in resp:
-    print(x)
-
-PROMPT = PromptTemplate(
-    template=PROMPT_TEMPLATE,
-    input_variables=["context", "question"]
-)
-chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)
-response = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
-for x in response:
-    print(response["output_text"])
\ No newline at end of file
diff --git a/embeddings/MyEmbeddings.py b/embeddings/MyEmbeddings.py
deleted file mode 100644
index 86fc12e..0000000
--- a/embeddings/MyEmbeddings.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import List
-import numpy as np
-from pydantic import BaseModel
-from langchain.embeddings.base import Embeddings
-
-
-class MyEmbeddings(Embeddings, BaseModel):
-    size: int
-
-    def _get_embedding(self) -> List[float]:
-        return list(np.random.normal(size=self.size))
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return [self._get_embedding() for _ in texts]
-
-    def embed_query(self, text: str) -> List[float]:
-        return self._get_embedding()
diff --git a/embeddings/__init__.py b/embeddings/__init__.py
index 9c8bef8..e69de29 100644
--- a/embeddings/__init__.py
+++ b/embeddings/__init__.py
@@ -1 +0,0 @@
-from .MyEmbeddings import MyEmbeddings
diff --git a/requirements.txt b/requirements.txt
index 5b366c0..b79f3e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,6 @@ openai
 sentence_transformers
 fschat==0.2.20
 transformers
-transformers_stream_generator
 torch~=2.0.0
 fastapi~=0.99.1
 fastapi-offline
diff --git a/server/api.py b/server/api.py
index b83c838..e5d9830 100644
--- a/server/api.py
+++ b/server/api.py
@@ -1,6 +1,7 @@
 import nltk
 import sys
 import os
+
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
@@ -13,7 +14,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
-                                               update_doc, recreate_vector_store)
+                                              update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
diff --git a/server/chat/chat.py b/server/chat/chat.py
index fd7e0db..b864609 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -8,19 +8,20 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
+from typing import List
 from server.chat.utils import History
 
 
 def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
          history: List[History] = Body([],
-                                      description="历史对话",
-                                      examples=[[
-                                          {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                          {"role": "assistant", "content": "虎头虎脑"}]]
-                                      ),
+                                       description="历史对话",
+                                       examples=[[
+                                           {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                           {"role": "assistant", "content": "虎头虎脑"}]]
+                                       ),
         ):
     history = [History(**h) if isinstance(h, dict) else h for h in history]
+
     async def chat_iterator(query: str,
                             history: List[History] = [],
                             ) -> AsyncIterable[str]:
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 7eb7be2..d1adf59 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -20,13 +20,13 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         history: List[History] = Body([],
-                                                      description="历史对话",
-                                                      examples=[[
-                                                          {"role": "user",
-                                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                                          {"role": "assistant",
-                                                           "content": "虎头虎脑"}]]
-                                                      ),
+                                                       description="历史对话",
+                                                       examples=[[
+                                                           {"role": "user",
+                                                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                           {"role": "assistant",
+                                                            "content": "虎头虎脑"}]]
+                                                       ),
                         stream: bool = Body(False, description="流式输出"),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
@@ -72,14 +72,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
     else:
         answer = ""
         async for token in callback.aiter():
             answer += token
         yield json.dumps({"answer": token,
-                          "docs": source_documents},
-                          ensure_ascii=False)
+                          "docs": source_documents},
+                         ensure_ascii=False)
 
     await task
 
diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py
index 2f55895..7b69265 100644
--- a/server/chat/openai_chat.py
+++ b/server/chat/openai_chat.py
@@ -1,6 +1,5 @@
-from fastapi import Body
 from fastapi.responses import StreamingResponse
-from typing import List, Dict
+from typing import List
 import openai
 from configs.model_config import llm_model_dict, LLM_MODEL
 from pydantic import BaseModel
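Review note: the `Body(...)` changes in `chat.py` and `knowledge_base_chat.py` above are whitespace-only re-indentation of the OpenAPI examples; runtime behavior is unchanged. A quick smoke test of the chat endpoint after applying the patch could look like the sketch below. The base URL and route (`/chat/chat` on port 7861) are assumptions about the deployment, not something this diff specifies.

```python
# Hypothetical client for the chat endpoint; URL and route are assumed, not from the patch.
import requests

payload = {
    "query": "你好",
    "history": [
        {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
        {"role": "assistant", "content": "虎头虎脑"},
    ],
}

# stream=True because the endpoint yields tokens incrementally via StreamingResponse
with requests.post("http://127.0.0.1:7861/chat/chat", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```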
description="检索结果数量"), history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -111,7 +111,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl async for token in callback.aiter(): # Use server-sent-events to stream the response yield json.dumps({"answer": token, - "docs": source_documents}, ensure_ascii=False) + "docs": source_documents}, + ensure_ascii=False) else: answer = "" async for token in callback.aiter(): diff --git a/server/db/session.py b/server/db/session.py index 9ac50e6..51ec55d 100644 --- a/server/db/session.py +++ b/server/db/session.py @@ -1,9 +1,6 @@ from functools import wraps - -from sqlalchemy.orm import sessionmaker from contextlib import contextmanager - -from server.db.base import engine, SessionLocal +from server.db.base import SessionLocal @contextmanager diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index bfdfd87..0249691 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -2,13 +2,11 @@ import os import urllib from fastapi import File, Form, Body, UploadFile from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import (get_file_path, validate_kb_name) +from server.knowledge_base.utils import validate_kb_name from fastapi.responses import StreamingResponse import json from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder from server.knowledge_base.kb_service.base import KBServiceFactory -from server.knowledge_base.kb_service.base import SupportedVSType -from typing import Union async def list_docs( @@ -26,7 +24,7 @@ async def list_docs( return ListResponse(data=all_doc_names) -async def upload_doc(file: UploadFile = File(description="上传文件"), +async def upload_doc(file: UploadFile = File(..., description="上传文件"), knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), override: bool = Form(False, description="覆盖已有文件"), ): diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ce6909d..a9b102f 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -8,8 +8,8 @@ from langchain.docstore.document import Document from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \ list_docs_from_db -from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K, - embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL) +from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, + EMBEDDING_DEVICE, EMBEDDING_MODEL) from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile) from typing import List, Union diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 33fb38c..7e4bfef 100644 --- 
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 33fb38c..7e4bfef 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -6,8 +6,8 @@ from langchain.vectorstores import PGVector
 from sqlalchemy import text
 
 from configs.model_config import EMBEDDING_DEVICE, kbs_config
-from server.knowledge_base.kb_service.base import SupportedVSType
-from server.knowledge_base.utils import KBService, load_embeddings, KnowledgeFile
+from server.knowledge_base.kb_service.base import SupportedVSType, KBService
+from server.knowledge_base.utils import load_embeddings, KnowledgeFile
 
 
 class PGKBService(KBService):
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 004c748..cc10646 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -10,7 +10,6 @@ from configs.model_config import (
 from functools import lru_cache
 import sys
 from text_splitter import zh_title_enhance
-from langchain.document_loaders import UnstructuredFileLoader
 
 
 def validate_kb_name(knowledge_base_id: str) -> bool:
diff --git a/server/llm_api.py b/server/llm_api.py
index 6cc6a6c..d26a935 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -4,7 +4,6 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-import asyncio
 
 host_ip = "0.0.0.0"
 controller_port = 20001
"gptq-wbits", "gptq-groupsize", + "gptq-act-order", "model-names", "limit-worker-concurrency", + "stream-interval", "no-register", + "controller-address" +] +# -----------------openai server--------------------------- parser.add_argument("--server-host", type=str, default="127.0.0.1", help="host name") parser.add_argument("--server-port", type=int, default=8888, help="port number") @@ -154,13 +154,14 @@ parser.add_argument( type=lambda s: s.split(","), help="Optional list of comma separated API keys", ) -server_args = ["server-host","server-port","allow-credentials","api-keys", +server_args = ["server-host", "server-port", "allow-credentials", "api-keys", "controller-address" ] args = parser.parse_args() # 必须要加http//:,否则InvalidSchema: No connection adapters were found -args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"}) +args = argparse.Namespace(**vars(args), + **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"}) if args.gpus: if len(args.gpus.split(",")) < args.num_gpus: @@ -173,10 +174,10 @@ if args.gpus: # 1, 命令行选项 # 2,LOG_PATH # 3, log的文件名 -base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &" +base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &" # 0 log_path -#! 1 log的文件名,必须与bash_launch_sh一致 +# ! 1 log的文件名,必须与bash_launch_sh一致 # 2 controller, worker, openai_api_server base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do sleep 1s; @@ -185,20 +186,20 @@ base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ]; echo '{2} running' """ -def string_args(args,args_list): +def string_args(args, args_list): """将args中的key转化为字符串""" args_str = "" for key, value in args._get_kwargs(): # args._get_kwargs中的key以_为分隔符,先转换,再判断是否在指定的args列表中 - key = key.replace("_","-") + key = key.replace("_", "-") if key not in args_list: - continue + continue # fastchat中port,host没有前缀,去除前缀 - key = key.split("-")[-1] if re.search("port|host",key) else key + key = key.split("-")[-1] if re.search("port|host", key) else key if not value: pass # 1==True -> True - elif isinstance(value,bool) and value == True: + elif isinstance(value, bool) and value == True: args_str += f" --{key} " elif isinstance(value, list) or isinstance(value, tuple) or isinstance(value, set): value = " ".join(value) @@ -208,38 +209,40 @@ def string_args(args,args_list): return args_str + def launch_worker(item): - log_name = item.split("/")[-1].split("\\")[-1].replace("-","_").replace("@","_").replace(".","_") - # 先分割model-path-address,在传到string_args中分析参数 - args.model_path,args.worker_host, args.worker_port = item.split("@") - print("*"*80) - worker_str_args = string_args(args,worker_args) - print(worker_str_args) - worker_sh = base_launch_sh.format("model_worker",worker_str_args,LOG_PATH,f"worker_{log_name}") - worker_check_sh = base_check_sh.format(LOG_PATH,f"worker_{log_name}","model_worker") - subprocess.run(worker_sh,shell=True,check=True) - subprocess.run(worker_check_sh,shell=True,check=True) + log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_") + # 先分割model-path-address,在传到string_args中分析参数 + args.model_path, args.worker_host, args.worker_port = item.split("@") + print("*" * 80) + worker_str_args = string_args(args, worker_args) + print(worker_str_args) + worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}") + worker_check_sh = 
diff --git a/server/llm_api_shutdown.py b/server/llm_api_shutdown.py
index bc65a9a..1ac1404 100644
--- a/server/llm_api_shutdown.py
+++ b/server/llm_api_shutdown.py
@@ -4,14 +4,15 @@
 python llm_api_shutdown.py --serve all
 可选"all","controller","model_worker","openai_api_server", all表示停止所有服务
 """
 import sys
-import os
+import os
+
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 import subprocess
 import argparse
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--serve",choices=["all","controller","model_worker","openai_api_server"],default="all")
+parser.add_argument("--serve", choices=["all", "controller", "model_worker", "openai_api_server"], default="all")
 
 args = parser.parse_args()
@@ -23,5 +24,5 @@ else:
     serve = f".{args.serve}"
 
 shell_script = base_shell.format(serve)
-subprocess.run(shell_script,shell=True,check=True)
-print(f"llm api sever --{args.serve} has been shutdown!")
\ No newline at end of file
+subprocess.run(shell_script, shell=True, check=True)
+print(f"llm api sever --{args.serve} has been shutdown!")
diff --git a/server/utils.py b/server/utils.py
index e200aa7..7228930 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -48,15 +48,22 @@ class ChatMessage(BaseModel):
         schema_extra = {
             "example": {
                 "question": "工伤保险如何办理?",
-                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
+                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
+                            "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
+                            "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
+                            "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
+                            "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
+                            "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                 "history": [
                     [
                         "工伤保险是什么?",
-                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
+                        "由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                     ]
                 ],
                 "source_documents": [
-                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
+                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t"
+                    "( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                     "出处 [2] ...",
                     "出处 [3] ...",
                 ],
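Review note: the `schema_extra` rewrite relies on Python's implicit concatenation of adjacent string literals, so the OpenAPI example payload is byte-for-byte identical; only the source layout changes. A one-line check, using a string taken from the hunk above:

```python
# Adjacent string literals concatenate at compile time; the split form used in
# the patch produces an identical value to the original single-line literal.
single = "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。"
split = ("工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
         "由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。")
assert single == split
```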
diff --git a/text_splitter/MyTextSplitter.py b/text_splitter/MyTextSplitter.py
deleted file mode 100644
index 15b179c..0000000
--- a/text_splitter/MyTextSplitter.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from langchain.text_splitter import TextSplitter, _split_text_with_regex
-from typing import Any, List
-
-
-class MyTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at characters."""
-
-    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
-        """Create a new TextSplitter."""
-        super().__init__(**kwargs)
-        self._separator = separator
-
-    def split_text(self, text: str) -> List[str]:
-        """Split incoming text and return chunks."""
-        # First we naively split the large input into a bunch of smaller ones.
-        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
-        _separator = "" if self._keep_separator else self._separator
-        return self._merge_splits(splits, _separator)
\ No newline at end of file
diff --git a/text_splitter/__init__.py b/text_splitter/__init__.py
index 1c4b665..f059ccb 100644
--- a/text_splitter/__init__.py
+++ b/text_splitter/__init__.py
@@ -1,2 +1,3 @@
-from .MyTextSplitter import MyTextSplitter
+from .chinese_text_splitter import ChineseTextSplitter
+from .ali_text_splitter import AliTextSplitter
 from .zh_title_enhance import zh_title_enhance
\ No newline at end of file
diff --git a/text_splitter/ali_text_splitter.py b/text_splitter/ali_text_splitter.py
new file mode 100644
index 0000000..1e62eb6
--- /dev/null
+++ b/text_splitter/ali_text_splitter.py
@@ -0,0 +1,27 @@
+from langchain.text_splitter import CharacterTextSplitter
+import re
+from typing import List
+
+
+class AliTextSplitter(CharacterTextSplitter):
+    def __init__(self, pdf: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.pdf = pdf
+
+    def split_text(self, text: str) -> List[str]:
+        # use_document_segmentation参数指定是否用语义切分文档,此处采取的文档语义分割模型为达摩院开源的nlp_bert_document-segmentation_chinese-base,论文见https://arxiv.org/abs/2107.09278
+        # 如果使用模型进行文档语义切分,那么需要安装modelscope[nlp]:pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+        # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id
+        if self.pdf:
+            text = re.sub(r"\n{3,}", r"\n", text)
+            text = re.sub('\s', " ", text)
+            text = re.sub("\n\n", "", text)
+        from modelscope.pipelines import pipeline
+
+        p = pipeline(
+            task="document-segmentation",
+            model='damo/nlp_bert_document-segmentation_chinese-base',
+            device="cpu")
+        result = p(documents=text)
+        sent_list = [i for i in result["text"].split("\n\t") if i]
+        return sent_list
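Review note: per the inline comments, `AliTextSplitter` segments text semantically with DAMO Academy's open-source `damo/nlp_bert_document-segmentation_chinese-base` model (paper: https://arxiv.org/abs/2107.09278), which requires `pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`; the pipeline is pinned to CPU for low-spec GPUs, and the device string can be replaced with a GPU id. A usage sketch, assuming modelscope is installed (the sample text is illustrative):

```python
# AliTextSplitter is exported by the updated text_splitter/__init__.py above.
from text_splitter import AliTextSplitter

splitter = AliTextSplitter(pdf=False)  # pdf=True additionally collapses PDF line noise
for i, chunk in enumerate(splitter.split_text("Prompt工程是一门新兴方向。它研究如何设计提示词。")):
    print(i, chunk)
```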
diff --git a/text_splitter/chinese_text_splitter.py b/text_splitter/chinese_text_splitter.py
new file mode 100644
index 0000000..b6e7940
--- /dev/null
+++ b/text_splitter/chinese_text_splitter.py
@@ -0,0 +1,60 @@
+from langchain.text_splitter import CharacterTextSplitter
+import re
+from typing import List
+from configs.model_config import SENTENCE_SIZE
+
+
+class ChineseTextSplitter(CharacterTextSplitter):
+    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
+        super().__init__(**kwargs)
+        self.pdf = pdf
+        self.sentence_size = sentence_size
+
+    def split_text1(self, text: str) -> List[str]:
+        if self.pdf:
+            text = re.sub(r"\n{3,}", "\n", text)
+            text = re.sub('\s', ' ', text)
+            text = text.replace("\n\n", "")
+        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
+        sent_list = []
+        for ele in sent_sep_pattern.split(text):
+            if sent_sep_pattern.match(ele) and sent_list:
+                sent_list[-1] += ele
+            elif ele:
+                sent_list.append(ele)
+        return sent_list
+
+    def split_text(self, text: str) -> List[str]:  ##此处需要进一步优化逻辑
+        if self.pdf:
+            text = re.sub(r"\n{3,}", r"\n", text)
+            text = re.sub('\s', " ", text)
+            text = re.sub("\n\n", "", text)
+
+        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # 单字符断句符
+        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # 英文省略号
+        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # 中文省略号
+        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
+        # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
+        text = text.rstrip()  # 段尾如果有多余的\n就去掉它
+        # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
+        ls = [i for i in text.split("\n") if i]
+        for ele in ls:
+            if len(ele) > self.sentence_size:
+                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
+                ele1_ls = ele1.split("\n")
+                for ele_ele1 in ele1_ls:
+                    if len(ele_ele1) > self.sentence_size:
+                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
+                        ele2_ls = ele_ele2.split("\n")
+                        for ele_ele2 in ele2_ls:
+                            if len(ele_ele2) > self.sentence_size:
+                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
+                                ele2_id = ele2_ls.index(ele_ele2)
+                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
+                                                                                                       ele2_id + 1:]
+                        ele_id = ele1_ls.index(ele_ele1)
+                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
+
+                id = ls.index(ele)
+                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
+        return ls
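Review note: `ChineseTextSplitter.split_text` first inserts `\n` after sentence terminators (。!?, ellipses, closing quotes), then repeatedly re-splits any sentence longer than `sentence_size` at progressively weaker boundaries (commas, then runs of whitespace); the `##此处需要进一步优化逻辑` comment ("the logic here needs further optimization") flags the nested-loop re-splitting. A quick check, with `sentence_size` overridden so the example does not depend on the `configs.model_config.SENTENCE_SIZE` default:

```python
# Illustrative only; actual chunk boundaries depend on sentence_size.
from text_splitter import ChineseTextSplitter

splitter = ChineseTextSplitter(pdf=False, sentence_size=50)
text = "今天天气不错。我们去公园散步吧!你觉得怎么样?如果下雨,我们就改天再去。"
for sent in splitter.split_text(text):
    print(sent)
# Expected: one line per sentence, split after 。!?;
# sentences longer than sentence_size would be split again at commas.
```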