update import pkgs and format

imClumsyPanda 2023-08-10 21:26:05 +08:00
parent 22c8f277bf
commit 8a4d9168fa
22 changed files with 192 additions and 190 deletions

View File

@ -1,56 +0,0 @@
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain import LLMChain
from langchain.llms import OpenAI
from configs.model_config import *
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.callbacks import StreamlitCallbackHandler
with open("../knowledge_base/samples/content/test.txt") as f:
    state_of_the_union = f.read()
# TODO: define params
# text_splitter = MyTextSplitter()
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
texts = text_splitter.split_text(state_of_the_union)
# TODO: define params
# embeddings = MyEmbeddings()
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
                                   model_kwargs={'device': EMBEDDING_DEVICE})
docsearch = Chroma.from_texts(
    texts,
    embeddings,
    metadatas=[{"source": str(i)} for i in range(len(texts))]
).as_retriever()
# test
query = "什么是Prompt工程"
docs = docsearch.get_relevant_documents(query)
# print(docs)
# prompt_template = PROMPT_TEMPLATE
llm = OpenAI(model_name=LLM_MODEL,
             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
             streaming=True)
# print(PROMPT)
prompt = PromptTemplate(input_variables=["input"], template="{input}")
chain = LLMChain(prompt=prompt, llm=llm)
resp = chain("你好")
print(resp["text"])  # chain(...) returns a dict; the generation lives under the "text" key
PROMPT = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)
response = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
print(response["output_text"])

View File

@ -1,17 +0,0 @@
from typing import List
import numpy as np
from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
class MyEmbeddings(Embeddings, BaseModel):
    size: int

    def _get_embedding(self) -> List[float]:
        return list(np.random.normal(size=self.size))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self._get_embedding() for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return self._get_embedding()

View File

@ -1 +0,0 @@
from .MyEmbeddings import MyEmbeddings

View File

@ -3,7 +3,6 @@ openai
sentence_transformers
fschat==0.2.20
transformers
transformers_stream_generator
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline

View File

@ -1,6 +1,7 @@
import nltk
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
@ -13,7 +14,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, recreate_vector_store)
update_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

View File

@ -8,19 +8,20 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from typing import List
from server.chat.utils import History
def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
description="历史对话",
examples=[[
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
):
history = [History(**h) if isinstance(h, dict) else h for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
) -> AsyncIterable[str]:

View File

@ -20,13 +20,13 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
@ -72,14 +72,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents}, ensure_ascii=False)
"docs": source_documents},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
"docs": source_documents},
ensure_ascii=False)
await task
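# A hedged client-side sketch for this endpoint (the URL, port, and route are assumptions
# based on this repo's defaults, and chunk boundaries need not align with whole JSON objects):
#     import requests
#     resp = requests.post("http://127.0.0.1:7861/chat/knowledge_base_chat",
#                          json={"query": "什么是工伤保险?", "knowledge_base_name": "samples",
#                                "top_k": 3, "history": [], "stream": True},
#                          stream=True)
#     for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#         print(chunk, end="", flush=True)  # each chunk carries {"answer": ..., "docs": ...}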

View File

@ -1,6 +1,5 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from typing import List, Dict
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL
from pydantic import BaseModel

View File

@ -61,14 +61,14 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -111,7 +111,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents}, ensure_ascii=False)
"docs": source_documents},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():

View File

@ -1,9 +1,6 @@
from functools import wraps
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from server.db.base import engine, SessionLocal
from server.db.base import SessionLocal
@contextmanager

View File

@ -2,13 +2,11 @@ import os
import urllib
from fastapi import File, Form, Body, UploadFile
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import (get_file_path, validate_kb_name)
from server.knowledge_base.utils import validate_kb_name
from fastapi.responses import StreamingResponse
import json
from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.kb_service.base import SupportedVSType
from typing import Union
async def list_docs(
@ -26,7 +24,7 @@ async def list_docs(
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(description="上传文件"),
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
):

View File

@ -8,8 +8,8 @@ from langchain.docstore.document import Document
from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
list_docs_from_db
from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile)
from typing import List, Union

View File

@ -6,8 +6,8 @@ from langchain.vectorstores import PGVector
from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType
from server.knowledge_base.utils import KBService, load_embeddings, KnowledgeFile
from server.knowledge_base.kb_service.base import SupportedVSType, KBService
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
class PGKBService(KBService):

View File

@ -10,7 +10,6 @@ from configs.model_config import (
from functools import lru_cache
import sys
from text_splitter import zh_title_enhance
from langchain.document_loaders import UnstructuredFileLoader
def validate_kb_name(knowledge_base_id: str) -> bool:

View File

@ -4,7 +4,6 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
import asyncio
host_ip = "0.0.0.0"
controller_port = 20001

View File

@ -5,7 +5,8 @@
"""
import sys
import os
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import subprocess
@ -19,15 +20,14 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
parser = argparse.ArgumentParser()
#------multi worker-----------------
# ------multi worker-----------------
parser.add_argument('--model-path-address',
default="THUDM/chatglm2-6b@localhost@20002",
nargs="+",
type=str,
help="model path, host, and port, formatted as model-path@host@path")
#---------------controller-------------------------
# ---------------controller-------------------------
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
@ -37,9 +37,9 @@ parser.add_argument(
choices=["lottery", "shortest_queue"],
default="shortest_queue",
)
controller_args = ["controller-host","controller-port","dispatch-method"]
controller_args = ["controller-host", "controller-port", "dispatch-method"]
#----------------------worker------------------------------------------
# ----------------------worker------------------------------------------
parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
@ -125,15 +125,15 @@ parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
worker_args = [
"worker-host","worker-port",
"model-path","revision","device","gpus","num-gpus",
"max-gpu-memory","load-8bit","cpu-offloading",
"gptq-ckpt","gptq-wbits","gptq-groupsize",
"gptq-act-order","model-names","limit-worker-concurrency",
"stream-interval","no-register",
"controller-address"
]
#-----------------openai server---------------------------
"worker-host", "worker-port",
"model-path", "revision", "device", "gpus", "num-gpus",
"max-gpu-memory", "load-8bit", "cpu-offloading",
"gptq-ckpt", "gptq-wbits", "gptq-groupsize",
"gptq-act-order", "model-names", "limit-worker-concurrency",
"stream-interval", "no-register",
"controller-address"
]
# -----------------openai server---------------------------
parser.add_argument("--server-host", type=str, default="127.0.0.1", help="host name")
parser.add_argument("--server-port", type=int, default=8888, help="port number")
@ -154,13 +154,14 @@ parser.add_argument(
type=lambda s: s.split(","),
help="Optional list of comma separated API keys",
)
server_args = ["server-host","server-port","allow-credentials","api-keys",
server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
"controller-address"
]
args = parser.parse_args()
# the "http://" prefix is required; otherwise requests raises InvalidSchema: No connection adapters were found
args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})
args = argparse.Namespace(**vars(args),
**{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
@ -173,10 +174,10 @@ if args.gpus:
# 1. command-line options
# 2. LOG_PATH
# 3. log file name
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
# 0 log_path
#! 1 the log file name must match base_launch_sh
# ! 1 the log file name must match base_launch_sh
# 2 controller, worker, openai_api_server
base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
sleep 1s;
@ -185,20 +186,20 @@ base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];
echo '{2} running' """
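# A hedged expansion sketch with assumed values (e.g. LOG_PATH = "/tmp/logs"):
# base_launch_sh.format("controller", "--port 21001", "/tmp/logs", "controller") yields
#   nohup python3 -m fastchat.serve.controller --port 21001 >/tmp/logs/controller.log 2>&1 &
# and base_check_sh.format("/tmp/logs", "controller", "controller") polls
# /tmp/logs/controller.log until "Uvicorn running on" appears, then echoes "controller running".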
def string_args(args,args_list):
def string_args(args, args_list):
"""将args中的key转化为字符串"""
args_str = ""
for key, value in args._get_kwargs():
# keys from args._get_kwargs() use "_" as the separator; convert to "-" first, then check membership in the given args list
key = key.replace("_","-")
key = key.replace("_", "-")
if key not in args_list:
continue
continue
# fastchat's port and host options take no prefix, so strip the prefix from the key
key = key.split("-")[-1] if re.search("port|host",key) else key
key = key.split("-")[-1] if re.search("port|host", key) else key
if not value:
pass
# 1==True -> True
elif isinstance(value,bool) and value == True:
elif isinstance(value, bool) and value == True:
args_str += f" --{key} "
elif isinstance(value, list) or isinstance(value, tuple) or isinstance(value, set):
value = " ".join(value)
@ -208,38 +209,40 @@ def string_args(args,args_list):
return args_str
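# A hedged usage sketch (namespace values assumed): keys are matched after the "_" -> "-"
# conversion, and host/port options are emitted without their prefix, e.g.
#   ns = argparse.Namespace(controller_host="localhost", controller_port=21001,
#                           dispatch_method="shortest_queue")
#   string_args(ns, controller_args)
#   # -> roughly " --host localhost  --port 21001  --dispatch-method shortest_queue "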
def launch_worker(item):
log_name = item.split("/")[-1].split("\\")[-1].replace("-","_").replace("@","_").replace(".","_")
# split model-path-address first, then pass the parts to string_args for parsing
args.model_path,args.worker_host, args.worker_port = item.split("@")
print("*"*80)
worker_str_args = string_args(args,worker_args)
print(worker_str_args)
worker_sh = base_launch_sh.format("model_worker",worker_str_args,LOG_PATH,f"worker_{log_name}")
worker_check_sh = base_check_sh.format(LOG_PATH,f"worker_{log_name}","model_worker")
subprocess.run(worker_sh,shell=True,check=True)
subprocess.run(worker_check_sh,shell=True,check=True)
log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
# split model-path-address first, then pass the parts to string_args for parsing
args.model_path, args.worker_host, args.worker_port = item.split("@")
print("*" * 80)
worker_str_args = string_args(args, worker_args)
print(worker_str_args)
worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}")
worker_check_sh = base_check_sh.format(LOG_PATH, f"worker_{log_name}", "model_worker")
subprocess.run(worker_sh, shell=True, check=True)
subprocess.run(worker_check_sh, shell=True, check=True)
def launch_all():
controller_str_args = string_args(args,controller_args)
controller_sh = base_launch_sh.format("controller",controller_str_args,LOG_PATH,"controller")
controller_check_sh = base_check_sh.format(LOG_PATH,"controller","controller")
subprocess.run(controller_sh,shell=True,check=True)
subprocess.run(controller_check_sh,shell=True,check=True)
controller_str_args = string_args(args, controller_args)
controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller")
controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller")
subprocess.run(controller_sh, shell=True, check=True)
subprocess.run(controller_check_sh, shell=True, check=True)
if isinstance(args.model_path_address, str):
launch_worker(args.model_path_address)
else:
for idx,item in enumerate(args.model_path_address):
for idx, item in enumerate(args.model_path_address):
print(f"开始加载第{idx}个模型:{item}")
launch_worker(item)
server_str_args = string_args(args,server_args)
server_sh = base_launch_sh.format("openai_api_server",server_str_args,LOG_PATH,"openai_api_server")
server_check_sh = base_check_sh.format(LOG_PATH,"openai_api_server","openai_api_server")
subprocess.run(server_sh,shell=True,check=True)
subprocess.run(server_check_sh,shell=True,check=True)
server_str_args = string_args(args, server_args)
server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server")
subprocess.run(server_sh, shell=True, check=True)
subprocess.run(server_check_sh, shell=True, check=True)
if __name__ == "__main__":
launch_all()
launch_all()

View File

@ -4,14 +4,15 @@ python llm_api_shutdown.py --serve all
可选"all","controller","model_worker","openai_api_server" all表示停止所有服务
"""
import sys
import os
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import subprocess
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--serve",choices=["all","controller","model_worker","openai_api_server"],default="all")
parser.add_argument("--serve", choices=["all", "controller", "model_worker", "openai_api_server"], default="all")
args = parser.parse_args()
@ -23,5 +24,5 @@ else:
serve = f".{args.serve}"
shell_script = base_shell.format(serve)
subprocess.run(shell_script,shell=True,check=True)
print(f"llm api sever --{args.serve} has been shutdown!")
subprocess.run(shell_script, shell=True, check=True)
print(f"llm api sever --{args.serve} has been shutdown!")

View File

@ -48,15 +48,22 @@ class ChatMessage(BaseModel):
schema_extra = {
"example": {
"question": "工伤保险如何办理?",
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
"response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
"2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
"3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
"4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
"5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
"6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
"history": [
[
"工伤保险是什么?",
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
"工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
"由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
]
],
"source_documents": [
"出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
"出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx\n\n\t"
"( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
"出处 [2] ...",
"出处 [3] ...",
],

View File

@ -1,18 +0,0 @@
from langchain.text_splitter import TextSplitter, _split_text_with_regex
from typing import Any, List
class MyTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)

View File

@ -1,2 +1,3 @@
from .MyTextSplitter import MyTextSplitter
from .chinese_text_splitter import ChineseTextSplitter
from .ali_text_splitter import AliTextSplitter
from .zh_title_enhance import zh_title_enhance

View File

@ -0,0 +1,27 @@
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
class AliTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        # The use_document_segmentation parameter specifies whether to split the document
        # semantically; the segmentation model used here is DAMO Academy's open-source
        # nlp_bert_document-segmentation_chinese-base (paper: https://arxiv.org/abs/2107.09278).
        # Using the model requires modelscope[nlp]:
        #   pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
        # Since three models may be in use at once, which can be unfriendly to low-spec GPUs,
        # the model is loaded on the CPU here; replace device with your own GPU id if needed.
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r'\s', " ", text)
            text = re.sub("\n\n", "", text)
        from modelscope.pipelines import pipeline
        p = pipeline(
            task="document-segmentation",
            model='damo/nlp_bert_document-segmentation_chinese-base',
            device="cpu")
        result = p(documents=text)
        sent_list = [i for i in result["text"].split("\n\t") if i]
        return sent_list
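# A hedged usage sketch (assumes modelscope[nlp] is installed; the sample text is illustrative):
if __name__ == "__main__":
    splitter = AliTextSplitter(pdf=False)
    for seg in splitter.split_text("工伤保险是什么?工伤保险是保障职工在发生工伤时获得相应待遇的社会保险制度。"):
        print(seg)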

View File

@ -0,0 +1,60 @@
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from configs.model_config import SENTENCE_SIZE
class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r'\s', ' ', text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:  ## the logic here still needs further refinement
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r'\s', " ", text)
            text = re.sub("\n\n", "", text)
        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        # A closing quote ends a sentence only when a terminator precedes it, so the \n
        # goes after the quote; note that the rules above carefully preserve closing quotes.
        text = text.rstrip()  # drop any extra trailing \n at the end of the passage
        # Many rule sets also treat the semicolon ";" as a boundary, but it is ignored here,
        # as are dashes and English double quotes; simple tweaks can add them back if needed.
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
        return ls
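# A hedged usage sketch (the sentence_size value below is an arbitrary assumption):
if __name__ == "__main__":
    splitter = ChineseTextSplitter(pdf=False, sentence_size=100)
    for sent in splitter.split_text("今天天气不错。我们出去走走吧!好的,这就出发……"):
        print(sent)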