update import pkgs and format

commit 8a4d9168fa (parent 22c8f277bf)

@@ -1,56 +0,0 @@
-from langchain.vectorstores import Chroma
-from langchain.prompts import PromptTemplate
-from langchain.chains.question_answering import load_qa_chain
-from langchain import LLMChain
-from langchain.llms import OpenAI
-from configs.model_config import *
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.callbacks import StreamlitCallbackHandler
-
-with open("../knowledge_base/samples/content/test.txt") as f:
-    state_of_the_union = f.read()
-
-# TODO: define params
-# text_splitter = MyTextSplitter()
-text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=200)
-texts = text_splitter.split_text(state_of_the_union)
-
-# TODO: define params
-# embeddings = MyEmbeddings()
-embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
-                                   model_kwargs={'device': EMBEDDING_DEVICE})
-
-docsearch = Chroma.from_texts(
-    texts,
-    embeddings,
-    metadatas=[{"source": str(i)} for i in range(len(texts))]
-).as_retriever()
-
-# test
-query = "什么是Prompt工程"
-docs = docsearch.get_relevant_documents(query)
-# print(docs)
-
-# prompt_template = PROMPT_TEMPLATE
-
-llm = OpenAI(model_name=LLM_MODEL,
-             openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-             streaming=True)
-
-# print(PROMPT)
-prompt = PromptTemplate(input_variables=["input"], template="{input}")
-chain = LLMChain(prompt=prompt, llm=llm)
-resp = chain("你好")
-for x in resp:
-    print(x)
-
-PROMPT = PromptTemplate(
-    template=PROMPT_TEMPLATE,
-    input_variables=["context", "question"]
-)
-chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)
-response = chain({"input_documents": docs, "question": query}, return_only_outputs=False)
-for x in response:
-    print(response["output_text"])

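For reference, the retrieval-QA flow this deleted script exercised can be reproduced in a few lines; a minimal sketch against the same legacy langchain 0.0.x APIs, with the embedding model name as a placeholder for the configs lookup:

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

# split -> embed -> index -> retrieve, as in the removed script
texts = CharacterTextSplitter(chunk_size=500, chunk_overlap=200).split_text(open("test.txt").read())
embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")  # placeholder for embedding_model_dict[EMBEDDING_MODEL]
retriever = Chroma.from_texts(texts, embeddings).as_retriever()
docs = retriever.get_relevant_documents("什么是Prompt工程")
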
@@ -1,17 +0,0 @@
-from typing import List
-import numpy as np
-from pydantic import BaseModel
-from langchain.embeddings.base import Embeddings
-
-
-class MyEmbeddings(Embeddings, BaseModel):
-    size: int
-
-    def _get_embedding(self) -> List[float]:
-        return list(np.random.normal(size=self.size))
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return [self._get_embedding() for _ in texts]
-
-    def embed_query(self, text: str) -> List[float]:
-        return self._get_embedding()

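The removed stub produced random vectors of a fixed size; a quick usage sketch (hypothetical, since the class is gone after this commit):

emb = MyEmbeddings(size=768)                         # hypothetical: class no longer exists
doc_vecs = emb.embed_documents(["第一段", "第二段"])  # two random 768-dim vectors
query_vec = emb.embed_query("查询")                   # one random 768-dim vector
assert len(doc_vecs) == 2 and len(doc_vecs[0]) == 768
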
@@ -1 +0,0 @@
-from .MyEmbeddings import MyEmbeddings

@@ -3,7 +3,6 @@ openai
 sentence_transformers
 fschat==0.2.20
 transformers
 transformers_stream_generator
 torch~=2.0.0
 fastapi~=0.99.1
-fastapi-offline

@@ -1,6 +1,7 @@
+import nltk
 import sys
 import os

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))

 from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN

@@ -13,7 +14,7 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
                                               update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@@ -8,19 +8,20 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
+from typing import List
 from server.chat.utils import History


 def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
          history: List[History] = Body([],
                                        description="历史对话",
                                        examples=[[
                                            {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                            {"role": "assistant", "content": "虎头虎脑"}]]
                                        ),
          ):
     history = [History(**h) if isinstance(h, dict) else h for h in history]

     async def chat_iterator(query: str,
                             history: List[History] = [],
                             ) -> AsyncIterable[str]:

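The signature accepts history entries as plain dicts and normalizes them with History(**h); a hypothetical client-side call (the endpoint path and port are assumptions about a local deployment):

import requests

payload = {
    "query": "虎头虎脑",
    "history": [  # plain dicts; the handler converts them to History objects
        {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
        {"role": "assistant", "content": "虎头虎脑"},
    ],
}
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload)  # assumed URL
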
@@ -20,13 +20,13 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         history: List[History] = Body([],
                                                       description="历史对话",
                                                       examples=[[
                                                           {"role": "user",
                                                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                           {"role": "assistant",
                                                            "content": "虎头虎脑"}]]
                                                       ),
                         stream: bool = Body(False, description="流式输出"),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

@@ -72,14 +72,15 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
     else:
         answer = ""
         async for token in callback.aiter():
             answer += token
         yield json.dumps({"answer": token,
                           "docs": source_documents},
                          ensure_ascii=False)

     await task

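Each yielded chunk is a standalone JSON object; a client-side sketch for the streaming case (the URL is an assumption about a local deployment, and it assumes each chunk arrives whole):

import json
import requests

body = {"query": "什么是Prompt工程", "knowledge_base_name": "samples", "stream": True}
with requests.post("http://127.0.0.1:7861/chat/knowledge_base_chat",  # assumed URL
                   json=body, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        data = json.loads(chunk)  # assumes one complete JSON object per chunk
        print(data["answer"], end="", flush=True)
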
@@ -1,6 +1,5 @@
-from fastapi import Body
 from fastapi.responses import StreamingResponse
-from typing import List, Dict
+from typing import List
 import openai
 from configs.model_config import llm_model_dict, LLM_MODEL
 from pydantic import BaseModel

@@ -61,14 +61,14 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                        search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                        history: List[History] = Body([],
                                                      description="历史对话",
                                                      examples=[[
                                                          {"role": "user",
                                                           "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                          {"role": "assistant",
                                                           "content": "虎头虎脑"}]]
                                                      ),
                        stream: bool = Body(False, description="流式输出"),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -111,7 +111,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         async for token in callback.aiter():
             # Use server-sent-events to stream the response
             yield json.dumps({"answer": token,
-                              "docs": source_documents}, ensure_ascii=False)
+                              "docs": source_documents},
+                             ensure_ascii=False)
     else:
         answer = ""
         async for token in callback.aiter():

@@ -1,9 +1,6 @@
 from functools import wraps

-from sqlalchemy.orm import sessionmaker
 from contextlib import contextmanager

-from server.db.base import engine, SessionLocal
+from server.db.base import SessionLocal


 @contextmanager

@@ -2,13 +2,11 @@ import os
 import urllib
 from fastapi import File, Form, Body, UploadFile
 from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import (get_file_path, validate_kb_name)
+from server.knowledge_base.utils import validate_kb_name
 from fastapi.responses import StreamingResponse
 import json
 from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder
 from server.knowledge_base.kb_service.base import KBServiceFactory
-from server.knowledge_base.kb_service.base import SupportedVSType
-from typing import Union


 async def list_docs(

@@ -26,7 +24,7 @@ async def list_docs(
     return ListResponse(data=all_doc_names)


-async def upload_doc(file: UploadFile = File(description="上传文件"),
+async def upload_doc(file: UploadFile = File(..., description="上传文件"),
                      knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
                      override: bool = Form(False, description="覆盖已有文件"),
                      ):

@@ -8,8 +8,8 @@ from langchain.docstore.document import Document
 from server.db.repository.knowledge_base_repository import add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db
 from server.db.repository.knowledge_file_repository import add_doc_to_db, delete_file_from_db, doc_exists, \
     list_docs_from_db
-from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
-                                  embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
+from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
+                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
 from server.knowledge_base.utils import (get_kb_path, get_doc_path, load_embeddings, KnowledgeFile)
 from typing import List, Union

@@ -6,8 +6,8 @@ from langchain.vectorstores import PGVector
 from sqlalchemy import text

 from configs.model_config import EMBEDDING_DEVICE, kbs_config
-from server.knowledge_base.kb_service.base import SupportedVSType
-from server.knowledge_base.utils import KBService, load_embeddings, KnowledgeFile
+from server.knowledge_base.kb_service.base import SupportedVSType, KBService
+from server.knowledge_base.utils import load_embeddings, KnowledgeFile


 class PGKBService(KBService):

@@ -10,7 +10,6 @@ from configs.model_config import (
 from functools import lru_cache
-import sys
 from text_splitter import zh_title_enhance
 from langchain.document_loaders import UnstructuredFileLoader


 def validate_kb_name(knowledge_base_id: str) -> bool:

@@ -4,7 +4,6 @@ import os

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-import asyncio

 host_ip = "0.0.0.0"
 controller_port = 20001

@@ -6,6 +6,7 @@
 """
 import sys
 import os

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))

 import subprocess

@@ -19,15 +20,14 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)


 parser = argparse.ArgumentParser()
-#------multi worker-----------------
+# ------multi worker-----------------
 parser.add_argument('--model-path-address',
                     default="THUDM/chatglm2-6b@localhost@20002",
                     nargs="+",
                     type=str,
                     help="model path, host, and port, formatted as model-path@host@path")
-#---------------controller-------------------------
+# ---------------controller-------------------------

 parser.add_argument("--controller-host", type=str, default="localhost")
 parser.add_argument("--controller-port", type=int, default=21001)

@@ -37,9 +37,9 @@ parser.add_argument(
     choices=["lottery", "shortest_queue"],
     default="shortest_queue",
 )
-controller_args = ["controller-host","controller-port","dispatch-method"]
+controller_args = ["controller-host", "controller-port", "dispatch-method"]

-#----------------------worker------------------------------------------
+# ----------------------worker------------------------------------------

 parser.add_argument("--worker-host", type=str, default="localhost")
 parser.add_argument("--worker-port", type=int, default=21002)

@@ -125,15 +125,15 @@ parser.add_argument("--stream-interval", type=int, default=2)
 parser.add_argument("--no-register", action="store_true")

 worker_args = [
-    "worker-host","worker-port",
-    "model-path","revision","device","gpus","num-gpus",
-    "max-gpu-memory","load-8bit","cpu-offloading",
-    "gptq-ckpt","gptq-wbits","gptq-groupsize",
-    "gptq-act-order","model-names","limit-worker-concurrency",
-    "stream-interval","no-register",
-    "controller-address"
-]
-#-----------------openai server---------------------------
+    "worker-host", "worker-port",
+    "model-path", "revision", "device", "gpus", "num-gpus",
+    "max-gpu-memory", "load-8bit", "cpu-offloading",
+    "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
+    "gptq-act-order", "model-names", "limit-worker-concurrency",
+    "stream-interval", "no-register",
+    "controller-address"
+]
+# -----------------openai server---------------------------

 parser.add_argument("--server-host", type=str, default="127.0.0.1", help="host name")
 parser.add_argument("--server-port", type=int, default=8888, help="port number")

@@ -154,13 +154,14 @@ parser.add_argument(
     type=lambda s: s.split(","),
     help="Optional list of comma separated API keys",
 )
-server_args = ["server-host","server-port","allow-credentials","api-keys",
+server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
                "controller-address"
                ]

 args = parser.parse_args()
 # the address must include http://, otherwise requests raises InvalidSchema: No connection adapters were found
-args = argparse.Namespace(**vars(args),**{"controller-address":f"http://{args.controller_host}:{str(args.controller_port)}"})
+args = argparse.Namespace(**vars(args),
+                          **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"})

 if args.gpus:
     if len(args.gpus.split(",")) < args.num_gpus:

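The dict-expansion form is needed because "controller-address" contains a hyphen and cannot be written as a normal keyword argument; a minimal illustration:

import argparse

ns = argparse.Namespace(**{"controller-address": "http://localhost:21001"})
print(getattr(ns, "controller-address"))  # hyphenated attributes are reachable only via getattr
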
@@ -176,7 +177,7 @@ if args.gpus:
 base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"

 # 0 log_path
-#! 1 log file name, must match bash_launch_sh
+# ! 1 log file name, must match bash_launch_sh
 # 2 controller, worker, openai_api_server
 base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
                 sleep 1s;

@@ -185,20 +186,20 @@ base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];
                 echo '{2} running' """


-def string_args(args,args_list):
+def string_args(args, args_list):
     """Convert the keys in args into a command-line argument string"""
     args_str = ""
     for key, value in args._get_kwargs():
         # keys in args._get_kwargs() use "_" as the separator; convert first, then check against the given list
-        key = key.replace("_","-")
+        key = key.replace("_", "-")
         if key not in args_list:
             continue
         # port/host options carry no prefix in fastchat, so strip it
-        key = key.split("-")[-1] if re.search("port|host",key) else key
+        key = key.split("-")[-1] if re.search("port|host", key) else key
         if not value:
             pass
         # 1==True -> True
-        elif isinstance(value,bool) and value == True:
+        elif isinstance(value, bool) and value == True:
             args_str += f" --{key} "
         elif isinstance(value, list) or isinstance(value, tuple) or isinstance(value, set):
             value = " ".join(value)

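string_args turns selected Namespace entries into a fastchat-style flag string; a sketch of the reformatted function in use (output shown approximately, and assuming string_args is in scope):

import argparse

ns = argparse.Namespace(controller_host="localhost", controller_port=21001, no_register=True)
print(string_args(ns, ["controller-host", "controller-port", "no-register"]))
# roughly: " --host localhost  --port 21001  --no-register "
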
@@ -208,38 +209,40 @@ def string_args(args, args_list):

     return args_str


 def launch_worker(item):
-    log_name = item.split("/")[-1].split("\\")[-1].replace("-","_").replace("@","_").replace(".","_")
-    # split model-path-address first, then pass it to string_args to build the arguments
-    args.model_path,args.worker_host, args.worker_port = item.split("@")
-    print("*"*80)
-    worker_str_args = string_args(args,worker_args)
-    print(worker_str_args)
-    worker_sh = base_launch_sh.format("model_worker",worker_str_args,LOG_PATH,f"worker_{log_name}")
-    worker_check_sh = base_check_sh.format(LOG_PATH,f"worker_{log_name}","model_worker")
-    subprocess.run(worker_sh,shell=True,check=True)
-    subprocess.run(worker_check_sh,shell=True,check=True)
+    log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_")
+    # split model-path-address first, then pass it to string_args to build the arguments
+    args.model_path, args.worker_host, args.worker_port = item.split("@")
+    print("*" * 80)
+    worker_str_args = string_args(args, worker_args)
+    print(worker_str_args)
+    worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}")
+    worker_check_sh = base_check_sh.format(LOG_PATH, f"worker_{log_name}", "model_worker")
+    subprocess.run(worker_sh, shell=True, check=True)
+    subprocess.run(worker_check_sh, shell=True, check=True)


 def launch_all():
-    controller_str_args = string_args(args,controller_args)
-    controller_sh = base_launch_sh.format("controller",controller_str_args,LOG_PATH,"controller")
-    controller_check_sh = base_check_sh.format(LOG_PATH,"controller","controller")
-    subprocess.run(controller_sh,shell=True,check=True)
-    subprocess.run(controller_check_sh,shell=True,check=True)
+    controller_str_args = string_args(args, controller_args)
+    controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller")
+    controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller")
+    subprocess.run(controller_sh, shell=True, check=True)
+    subprocess.run(controller_check_sh, shell=True, check=True)

     if isinstance(args.model_path_address, str):
         launch_worker(args.model_path_address)
     else:
-        for idx,item in enumerate(args.model_path_address):
+        for idx, item in enumerate(args.model_path_address):
             print(f"开始加载第{idx}个模型:{item}")
             launch_worker(item)

-    server_str_args = string_args(args,server_args)
-    server_sh = base_launch_sh.format("openai_api_server",server_str_args,LOG_PATH,"openai_api_server")
-    server_check_sh = base_check_sh.format(LOG_PATH,"openai_api_server","openai_api_server")
-    subprocess.run(server_sh,shell=True,check=True)
-    subprocess.run(server_check_sh,shell=True,check=True)
+    server_str_args = string_args(args, server_args)
+    server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server")
+    server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server")
+    subprocess.run(server_sh, shell=True, check=True)
+    subprocess.run(server_check_sh, shell=True, check=True)


 if __name__ == "__main__":
     launch_all()

@@ -5,13 +5,14 @@ python llm_api_shutdown.py --serve all
 """
 import sys
 import os

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))

 import subprocess
 import argparse

 parser = argparse.ArgumentParser()
-parser.add_argument("--serve",choices=["all","controller","model_worker","openai_api_server"],default="all")
+parser.add_argument("--serve", choices=["all", "controller", "model_worker", "openai_api_server"], default="all")

 args = parser.parse_args()

@@ -23,5 +24,5 @@ else:
     serve = f".{args.serve}"
 shell_script = base_shell.format(serve)

-subprocess.run(shell_script,shell=True,check=True)
+subprocess.run(shell_script, shell=True, check=True)
 print(f"llm api sever --{args.serve} has been shutdown!")

@@ -48,15 +48,22 @@ class ChatMessage(BaseModel):
         schema_extra = {
             "example": {
                 "question": "工伤保险如何办理?",
-                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
+                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
+                            "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
+                            "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
+                            "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
+                            "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
+                            "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                 "history": [
                     [
                         "工伤保险是什么?",
-                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
+                        "由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                     ]
                 ],
                 "source_documents": [
-                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
+                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t"
+                    "( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                     "出处 [2] ...",
                     "出处 [3] ...",
                 ],

@@ -1,18 +0,0 @@
-from langchain.text_splitter import TextSplitter, _split_text_with_regex
-from typing import Any, List
-
-
-class MyTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at characters."""
-
-    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
-        """Create a new TextSplitter."""
-        super().__init__(**kwargs)
-        self._separator = separator
-
-    def split_text(self, text: str) -> List[str]:
-        """Split incoming text and return chunks."""
-        # First we naively split the large input into a bunch of smaller ones.
-        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
-        _separator = "" if self._keep_separator else self._separator
-        return self._merge_splits(splits, _separator)

@@ -1,2 +1,3 @@
-from .MyTextSplitter import MyTextSplitter
+from .chinese_text_splitter import ChineseTextSplitter
+from .ali_text_splitter import AliTextSplitter
 from .zh_title_enhance import zh_title_enhance

@@ -0,0 +1,27 @@
+from langchain.text_splitter import CharacterTextSplitter
+import re
+from typing import List
+
+
+class AliTextSplitter(CharacterTextSplitter):
+    def __init__(self, pdf: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.pdf = pdf
+
+    def split_text(self, text: str) -> List[str]:
+        # use_document_segmentation specifies whether to split the document semantically; the model used here is DAMO Academy's open-source nlp_bert_document-segmentation_chinese-base, paper: https://arxiv.org/abs/2107.09278
+        # semantic splitting with the model requires modelscope[nlp]: pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+        # since three models are involved this can be unfriendly to low-end GPUs, the model is loaded on CPU here; replace device with your GPU id if needed
+        if self.pdf:
+            text = re.sub(r"\n{3,}", r"\n", text)
+            text = re.sub('\s', " ", text)
+            text = re.sub("\n\n", "", text)
+        from modelscope.pipelines import pipeline
+
+        p = pipeline(
+            task="document-segmentation",
+            model='damo/nlp_bert_document-segmentation_chinese-base',
+            device="cpu")
+        result = p(documents=text)
+        sent_list = [i for i in result["text"].split("\n\t") if i]
+        return sent_list

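A usage sketch for the new splitter (assumes modelscope[nlp] is installed, as the comments above note):

splitter = AliTextSplitter(pdf=False)
for chunk in splitter.split_text("第一段内容。第二段内容。"):
    print(chunk)
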
@@ -0,0 +1,60 @@
+from langchain.text_splitter import CharacterTextSplitter
+import re
+from typing import List
+from configs.model_config import SENTENCE_SIZE
+
+
+class ChineseTextSplitter(CharacterTextSplitter):
+    def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
+        super().__init__(**kwargs)
+        self.pdf = pdf
+        self.sentence_size = sentence_size
+
+    def split_text1(self, text: str) -> List[str]:
+        if self.pdf:
+            text = re.sub(r"\n{3,}", "\n", text)
+            text = re.sub('\s', ' ', text)
+            text = text.replace("\n\n", "")
+        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
+        sent_list = []
+        for ele in sent_sep_pattern.split(text):
+            if sent_sep_pattern.match(ele) and sent_list:
+                sent_list[-1] += ele
+            elif ele:
+                sent_list.append(ele)
+        return sent_list
+
+    def split_text(self, text: str) -> List[str]:  # the logic here still needs further refinement
+        if self.pdf:
+            text = re.sub(r"\n{3,}", r"\n", text)
+            text = re.sub('\s', " ", text)
+            text = re.sub("\n\n", "", text)
+
+        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
+        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
+        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
+        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
+        # a closing quote ends a sentence only when a terminator precedes it, so the \n goes after the quote; the rules above carefully preserve quotes
+        text = text.rstrip()  # drop any extra trailing \n
+        # many rule sets also break on the semicolon (;), but it is ignored here, as are dashes and English double quotes; adjust as needed
+        ls = [i for i in text.split("\n") if i]
+        for ele in ls:
+            if len(ele) > self.sentence_size:
+                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
+                ele1_ls = ele1.split("\n")
+                for ele_ele1 in ele1_ls:
+                    if len(ele_ele1) > self.sentence_size:
+                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
+                        ele2_ls = ele_ele2.split("\n")
+                        for ele_ele2 in ele2_ls:
+                            if len(ele_ele2) > self.sentence_size:
+                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
+                                ele2_id = ele2_ls.index(ele_ele2)
+                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
+                                                                                                       ele2_id + 1:]
+                        ele_id = ele1_ls.index(ele_ele1)
+                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
+
+                id = ls.index(ele)
+                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
+        return ls

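A usage sketch for the new sentence splitter (sentence_size defaults to SENTENCE_SIZE from configs.model_config; 100 below is an illustrative cap):

splitter = ChineseTextSplitter(pdf=False, sentence_size=100)
for sent in splitter.split_text("今天天气很好!我们去公园散步吧。你觉得怎么样?"):
    print(sent)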