Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev

commit 84b491b8b2
hzg0601, 2023-08-16 16:00:40 +08:00
15 changed files with 149 additions and 30 deletions

View File

@@ -5,7 +5,6 @@ fschat==0.2.20
 transformers
 torch~=2.0.0
 fastapi~=0.99.1
-fastapi-offline
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0

View File

@@ -5,7 +5,6 @@ fschat==0.2.20
 transformers
 torch~=2.0.0
 fastapi~=0.99.1
-fastapi-offline
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0

View File

@@ -7,15 +7,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
-from server.utils import FastAPIOffline as FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
-                                              update_doc, download_doc, recreate_vector_store)
-from server.utils import BaseResponse, ListResponse
+                                              update_doc, download_doc, recreate_vector_store,
+                                              search_docs, DocumentWithScore)
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
+from typing import List

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -25,7 +27,8 @@ async def document():
 def create_app():
-    app = FastAPI()
+    app = FastAPI(title="Langchain-Chatchat API Server")
+    MakeFastAPIOffline(app)
     # Add CORS middleware to allow all origins
     # 在config.py中设置OPEN_DOMAIN=True允许跨域
     # set OPEN_DOMAIN=True in config.py to allow cross-domain
@@ -83,6 +86,12 @@ def create_app():
              summary="获取知识库内的文件列表"
              )(list_docs)

+    app.post("/knowledge_base/search_docs",
+             tags=["Knowledge Base Management"],
+             response_model=List[DocumentWithScore],
+             summary="搜索知识库"
+             )(search_docs)
+
     app.post("/knowledge_base/upload_doc",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,

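For reference, a sketch of exercising the new route from a client. The server address and the "samples" knowledge base are assumptions; the JSON keys mirror the Body parameters of search_docs below:

import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/search_docs",
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,
        "score_threshold": 0.5,
    },
)
# each hit is a serialized DocumentWithScore: page_content, metadata, score
for doc in resp.json():
    print(doc["score"], doc["page_content"][:80])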
View File

@@ -1,26 +1,27 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  VECTOR_SEARCH_TOP_K)
+                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
+from typing import AsyncIterable, List, Optional
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Optional
 from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
 import json
 import os
 from urllib.parse import urlencode
+from server.knowledge_base.kb_doc_api import search_docs


 def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                         knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+                        score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
                         history: List[History] = Body([],
                                                       description="历史对话",
                                                       examples=[[
@@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
         model_name=LLM_MODEL
     )
-    docs = kb.search_docs(query, top_k)
+    docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
     context = "\n".join([doc.page_content for doc in docs])

     chat_prompt = ChatPromptTemplate.from_messages(

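The score_threshold semantics in the parameter description are worth spelling out: scores are distances in [0, 1], so smaller means more similar, and 1.0 effectively disables filtering. A toy illustration of that filtering rule (made-up values, not repo code):

# keep only chunks whose distance is at or below the threshold
hits = [("chunk A", 0.12), ("chunk B", 0.48), ("chunk C", 0.93)]
score_threshold = 0.5
kept = [(text, score) for text, score in hits if score <= score_threshold]
print(kept)  # [('chunk A', 0.12), ('chunk B', 0.48)]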
View File

@@ -1,13 +1,32 @@
 import os
 import urllib
 from fastapi import File, Form, Body, Query, UploadFile
-from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
+from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
 from server.utils import BaseResponse, ListResponse
 from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
 from fastapi.responses import StreamingResponse, FileResponse
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
-from typing import List
+from typing import List, Dict
+from langchain.docstore.document import Document
+
+
+class DocumentWithScore(Document):
+    score: float = None
+
+
+def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
+                knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+                top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+                score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                ) -> List[DocumentWithScore]:
+    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
+    if kb is None:
+        return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
+    docs = kb.search_docs(query, top_k, score_threshold)
+    data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
+    return data
+
+
 async def list_docs(

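The DocumentWithScore pattern above (a pydantic subclass of langchain's Document with one extra field) can be reproduced standalone; the sample document and score here are made up:

from langchain.docstore.document import Document

class DocumentWithScore(Document):
    score: float = None

hits = [(Document(page_content="hello", metadata={"source": "a.md"}), 0.31)]
# same conversion as in search_docs: splat the Document's fields, attach the score
docs = [DocumentWithScore(**d.dict(), score=s) for d, s in hits]
print(docs[0].score, docs[0].page_content)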
View File

@@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
     list_docs_from_db, get_file_detail, delete_file_from_db
 )
-from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
+from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                                   EMBEDDING_DEVICE, EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
@@ -112,9 +112,10 @@ class KBService(ABC):
     def search_docs(self,
                     query: str,
                     top_k: int = VECTOR_SEARCH_TOP_K,
+                    score_threshold: float = SCORE_THRESHOLD,
                     ):
         embeddings = self._load_embeddings()
-        docs = self.do_search(query, top_k, embeddings)
+        docs = self.do_search(query, top_k, score_threshold, embeddings)
         return docs

     @abstractmethod

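The change threads score_threshold from the public search_docs method into each backend's do_search. A self-contained stub of that call path (StubKBService and its data are illustrative; only the method signatures come from the diff):

SCORE_THRESHOLD = 1.0  # assumed default; the real value comes from configs.model_config

class StubKBService:
    def search_docs(self, query, top_k=3, score_threshold=SCORE_THRESHOLD):
        # the base class loads embeddings, then delegates with the threshold
        return self.do_search(query, top_k, score_threshold, embeddings=None)

    def do_search(self, query, top_k, score_threshold, embeddings):
        # each backend decides how (or whether) to honor the threshold
        hits = [("chunk-1", 0.2), ("chunk-2", 0.8)]
        return [h for h in hits if h[1] <= score_threshold][:top_k]

print(StubKBService().search_docs("你好", score_threshold=0.5))  # [('chunk-1', 0.2)]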
View File

@@ -81,12 +81,13 @@ class FaissKBService(KBService):
     def do_search(self,
                   query: str,
                   top_k: int,
-                  embeddings: Embeddings,
+                  score_threshold: float = SCORE_THRESHOLD,
+                  embeddings: Embeddings = None,
                   ) -> List[Document]:
         search_index = load_vector_store(self.kb_name,
                                          embeddings=embeddings,
                                          tick=_VECTOR_STORE_TICKS.get(self.kb_name))
-        docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
+        docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
         return docs

     def do_add_doc(self,

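The key change here is similarity_search -> similarity_search_with_score, which returns (Document, score) tuples rather than bare Documents -- the shape that search_docs unpacks as x[0] and x[1]. A shape-only sketch (FakeEmbeddings returns random vectors, so the scores are meaningless; requires faiss installed):

from langchain.vectorstores import FAISS
from langchain.embeddings.fake import FakeEmbeddings

# build a tiny in-memory index just to show the return shape
store = FAISS.from_texts(["hello world", "goodbye"], FakeEmbeddings(size=128))
for doc, score in store.similarity_search_with_score("hi", k=2):
    print(round(float(score), 3), doc.page_content)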
View File

@@ -45,7 +45,8 @@ class MilvusKBService(KBService):
     def do_drop_kb(self):
         self.milvus.col.drop()

-    def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
+    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
+        # todo: support score threshold
         self._load_milvus(embeddings=embeddings)
         return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)

View File

@@ -43,7 +43,8 @@ class PGKBService(KBService):
             '''))
             connect.commit()

-    def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
+    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
+        # todo: support score threshold
         self._load_pg_vector(embeddings=embeddings)
         return self.pg_vector.similarity_search(query, top_k)

View File

@@ -4,6 +4,8 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline
+

 host_ip = "0.0.0.0"
 controller_port = 20001
@@ -30,6 +32,8 @@ def create_controller_app(
     controller = Controller(dispatch_method)
     sys.modules["fastchat.serve.controller"].controller = controller

+    MakeFastAPIOffline(app)
+    app.title = "FastChat Controller"
     return app
@@ -55,7 +59,6 @@ def create_model_worker_app(
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
-    from fastchat.serve import model_worker
     import argparse

     parser = argparse.ArgumentParser()
@@ -117,6 +120,8 @@ def create_model_worker_app(
     sys.modules["fastchat.serve.model_worker"].args = args
     sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

+    MakeFastAPIOffline(app)
+    app.title = f"FastChat LLM Server ({LLM_MODEL})"
     return app
@@ -141,6 +146,8 @@ def create_openai_api_app(
     app_settings.controller_address = controller_address
     app_settings.api_keys = api_keys

+    MakeFastAPIOffline(app)
+    app.title = "FastChat OpeanAI API Server"
     return app

BIN server/static/favicon.png (new file, 7.1 KiB; binary not shown)

View File

@@ -2,14 +2,10 @@ import pydantic
 from pydantic import BaseModel
 from typing import List
 import torch
-from fastapi_offline import FastAPIOffline
-import fastapi_offline
+from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from typing import Any, Optional
-
-# patch fastapi_offline to use local static assests
-fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"


 class BaseResponse(BaseModel):
@@ -112,3 +108,81 @@ def iter_over_async(ait, loop):
         if done:
             break
         yield obj
+
+
+def MakeFastAPIOffline(
+    app: FastAPI,
+    static_dir = Path(__file__).parent / "static",
+    static_url = "/static-offline-docs",
+    docs_url: Optional[str] = "/docs",
+    redoc_url: Optional[str] = "/redoc",
+) -> None:
+    """patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
+    from fastapi import Request
+    from fastapi.openapi.docs import (
+        get_redoc_html,
+        get_swagger_ui_html,
+        get_swagger_ui_oauth2_redirect_html,
+    )
+    from fastapi.staticfiles import StaticFiles
+    from starlette.responses import HTMLResponse
+
+    openapi_url = app.openapi_url
+    swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
+
+    def remove_route(url: str) -> None:
+        '''
+        remove original route from app
+        '''
+        index = None
+        for i, r in enumerate(app.routes):
+            if r.path.lower() == url.lower():
+                index = i
+                break
+        if isinstance(index, int):
+            app.routes.pop(i)
+
+    # Set up static file mount
+    app.mount(
+        static_url,
+        StaticFiles(directory=Path(static_dir).as_posix()),
+        name="static-offline-docs",
+    )
+
+    if docs_url is not None:
+        remove_route(docs_url)
+        remove_route(swagger_ui_oauth2_redirect_url)
+
+        # Define the doc and redoc pages, pointing at the right files
+        @app.get(docs_url, include_in_schema=False)
+        async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
+            root = request.scope.get("root_path")
+            favicon = f"{root}{static_url}/favicon.png"
+            return get_swagger_ui_html(
+                openapi_url=f"{root}{openapi_url}",
+                title=app.title + " - Swagger UI",
+                oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
+                swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
+                swagger_css_url=f"{root}{static_url}/swagger-ui.css",
+                swagger_favicon_url=favicon,
+            )
+
+        @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
+        async def swagger_ui_redirect() -> HTMLResponse:
+            return get_swagger_ui_oauth2_redirect_html()
+
+    if redoc_url is not None:
+        remove_route(redoc_url)
+
+        @app.get(redoc_url, include_in_schema=False)
+        async def redoc_html(request: Request) -> HTMLResponse:
+            root = request.scope.get("root_path")
+            favicon = f"{root}{static_url}/favicon.png"
+            return get_redoc_html(
+                openapi_url=f"{root}{openapi_url}",
+                title=app.title + " - ReDoc",
+                redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
+                with_google_fonts=False,
+                redoc_favicon_url=favicon,
+            )

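Usage sketch for the helper, mirroring how the api.py and llm_api.py diffs call it; it assumes the swagger/redoc assets (swagger-ui-bundle.js, swagger-ui.css, redoc.standalone.js, favicon.png) have already been copied into server/static:

from fastapi import FastAPI
from server.utils import MakeFastAPIOffline

app = FastAPI(title="My API")
MakeFastAPIOffline(app)  # patches the app in place
# /docs and /redoc now load their JS/CSS from /static-offline-docs instead of a CDN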
View File

@@ -13,7 +13,11 @@ import os
 api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)

 if __name__ == "__main__":
-    st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded")
+    st.set_page_config(
+        "Langchain-Chatchat WebUI",
+        os.path.join("img", "chatchat_icon_blue_square_v2.png"),
+        initial_sidebar_state="expanded",
+    )

     if not chat_box.chat_inited:
         st.toast(

View File

@@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
             key="selected_kb",
         )
         kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
-        # score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
+        score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
         # chunk_content = st.checkbox("关联上下文", False, disabled=True)
         # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
     elif dialogue_mode == "搜索引擎问答":
@@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest):
             Markdown("...", in_expander=True, title="知识库匹配结果"),
         ])
         text = ""
-        for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
-            if error_msg := check_error_msg(t):  # check whether error occured
+        for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
+            if error_msg := check_error_msg(d):  # check whether error occured
                 st.error(error_msg)
             text += d["answer"]
             chat_box.update_msg(text, 0)
@@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest):
         ])
         text = ""
         for d in api.search_engine_chat(prompt, search_engine, se_top_k):
-            if error_msg := check_error_msg(t):  # check whether error occured
+            if error_msg := check_error_msg(d):  # check whether error occured
                 st.error(error_msg)
             text += d["answer"]
             chat_box.update_msg(text, 0)

View File

@@ -6,6 +6,7 @@ from configs.model_config import (
     DEFAULT_VS_TYPE,
     KB_ROOT_PATH,
     LLM_MODEL,
+    SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
     logger,
@@ -312,6 +313,7 @@ class ApiRequest:
         query: str,
         knowledge_base_name: str,
         top_k: int = VECTOR_SEARCH_TOP_K,
+        score_threshold: float = SCORE_THRESHOLD,
         history: List[Dict] = [],
         stream: bool = True,
         no_remote_api: bool = None,
@@ -326,6 +328,7 @@ class ApiRequest:
             "query": query,
             "knowledge_base_name": knowledge_base_name,
             "top_k": top_k,
+            "score_threshold": score_threshold,
             "history": history,
             "stream": stream,
             "local_doc_url": no_remote_api,