Merge branch 'dev' of github.com:chatchat-space/Langchain-Chatchat into dev

commit 84b491b8b2
hzg0601, 2023-08-16 16:00:40 +08:00
15 changed files with 149 additions and 30 deletions

View File

@@ -5,7 +5,6 @@ fschat==0.2.20
transformers
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0

View File

@@ -5,7 +5,6 @@ fschat==0.2.20
transformers
torch~=2.0.0
fastapi~=0.99.1
fastapi-offline
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0

View File

@@ -7,15 +7,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from server.utils import FastAPIOffline as FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -25,7 +27,8 @@ async def document():
def create_app():
app = FastAPI()
app = FastAPI(title="Langchain-Chatchat API Server")
MakeFastAPIOffline(app)
# Add CORS middleware to allow all origins
# set OPEN_CROSS_DOMAIN=True in config.py to allow cross-origin requests
@@ -83,6 +86,12 @@ def create_app():
summary="获取知识库内的文件列表"
)(list_docs)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,

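For reference, a minimal sketch of exercising the new route with requests (assuming the server runs at the default 127.0.0.1:7861 address the WebUI uses later in this diff):

import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/search_docs",
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,             # number of matched vectors
        "score_threshold": 0.5, # smaller scores mean higher relevance
    },
)
for hit in resp.json():  # a list of serialized DocumentWithScore objects
    print(hit["score"], hit["page_content"][:80])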
View File

@@ -1,26 +1,27 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K)
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
def knowledge_base_chat(query: str = Body(..., description="User query", examples=["你好"]),
knowledge_base_name: str = Body(..., description="Knowledge base name", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of matched vectors"),
score_threshold: float = Body(SCORE_THRESHOLD, description="Relevance score threshold, in the range 0-1; a smaller score means higher relevance, and 1 is equivalent to no filtering. Around 0.5 is recommended.", ge=0, le=1),
history: List[History] = Body([],
description="Chat history",
examples=[[
examples=[[
@@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="User query", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = kb.search_docs(query, top_k)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages(

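In effect the endpoint now delegates retrieval to the shared search_docs helper and stitches the hits into a prompt context; a condensed sketch of that flow (the final prompt assembly is paraphrased, not the literal endpoint body):

# condensed illustration of the retrieval-then-prompt flow above
docs = search_docs("你好", "samples", VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
context = "\n".join(doc.page_content for doc in docs)
# PROMPT_TEMPLATE (from configs.model_config) is then filled with the context
# and the user query, and streamed through LLMChain / ChatOpenAI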
View File

@@ -1,13 +1,32 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List
from typing import List, Dict, Optional
from langchain.docstore.document import Document
class DocumentWithScore(Document):
score: Optional[float] = None
def search_docs(query: str = Body(..., description="User query", examples=["你好"]),
knowledge_base_name: str = Body(..., description="Knowledge base name", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of matched vectors"),
score_threshold: float = Body(SCORE_THRESHOLD, description="Relevance score threshold, in the range 0-1; a smaller score means higher relevance, and 1 is equivalent to no filtering. Around 0.5 is recommended.", ge=0, le=1),
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**doc.dict(), score=score) for doc, score in docs]
return data
async def list_docs(

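DocumentWithScore simply extends langchain's Document with the similarity score, so each hit serializes as a normal document plus its score; a small sketch (the metadata content is illustrative):

from langchain.docstore.document import Document

doc = Document(page_content="example passage", metadata={"source": "samples/test.txt"})
hit = DocumentWithScore(**doc.dict(), score=0.31)
print(hit.page_content, hit.metadata, hit.score)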
View File

@@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
list_docs_from_db, get_file_detail, delete_file_from_db
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
@@ -112,9 +112,10 @@ class KBService(ABC):
def search_docs(self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
):
embeddings = self._load_embeddings()
docs = self.do_search(query, top_k, embeddings)
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
@abstractmethod

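Every vector-store backend now has to accept the widened do_search signature; a skeletal sketch of the contract (a hypothetical subclass, other abstract methods omitted):

class InMemoryKBService(KBService):
    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings):
        # return (Document, score) pairs, honoring score_threshold where the store supports it
        return []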
View File

@@ -81,12 +81,13 @@ class FaissKBService(KBService):
def do_search(self,
query: str,
top_k: int,
embeddings: Embeddings,
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,

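Note that similarity_search_with_score returns (Document, score) pairs rather than bare Documents, which is what lets scores flow back to the API layer; with FAISS the score is a distance, so smaller means more relevant, matching the threshold semantics described above. A rough sketch of the filtering the score_threshold kwarg effectively performs (illustrative only):

# illustrative: manual equivalent of the score_threshold filter
pairs = search_index.similarity_search_with_score(query, k=top_k)
docs = [(doc, score) for doc, score in pairs if score <= score_threshold]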
View File

@@ -45,7 +45,8 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)

View File

@@ -43,7 +43,8 @@ class PGKBService(KBService):
'''))
connect.commit()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k)

View File

@@ -4,6 +4,8 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from server.utils import MakeFastAPIOffline
host_ip = "0.0.0.0"
controller_port = 20001
@@ -30,6 +32,8 @@ def create_controller_app(
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
return app
@@ -55,7 +59,6 @@ def create_model_worker_app(
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
from fastchat.serve import model_worker
import argparse
parser = argparse.ArgumentParser()
@@ -117,6 +120,8 @@ def create_model_worker_app(
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
return app
@@ -141,6 +146,8 @@ def create_openai_api_app(
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app

BIN  server/static/favicon.png (new file; binary not shown, 7.1 KiB)

View File

@@ -2,14 +2,10 @@ import pydantic
from pydantic import BaseModel
from typing import List
import torch
from fastapi_offline import FastAPIOffline
import fastapi_offline
from fastapi import FastAPI
from pathlib import Path
import asyncio
# patch fastapi_offline to use local static assets
fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"
from typing import Any, Optional
class BaseResponse(BaseModel):
@@ -112,3 +108,81 @@ def iter_over_async(ait, loop):
if done:
break
yield obj
def MakeFastAPIOffline(
app: FastAPI,
static_dir = Path(__file__).parent / "static",
static_url = "/static-offline-docs",
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
) -> None:
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
from fastapi import Request
from fastapi.openapi.docs import (
get_redoc_html,
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
from fastapi.staticfiles import StaticFiles
from starlette.responses import HTMLResponse
openapi_url = app.openapi_url
swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
def remove_route(url: str) -> None:
'''
remove original route from app
'''
index = None
for i, r in enumerate(app.routes):
if r.path.lower() == url.lower():
index = i
break
if isinstance(index, int):
app.routes.pop(index)
# Set up static file mount
app.mount(
static_url,
StaticFiles(directory=Path(static_dir).as_posix()),
name="static-offline-docs",
)
if docs_url is not None:
remove_route(docs_url)
remove_route(swagger_ui_oauth2_redirect_url)
# Define the doc and redoc pages, pointing at the right files
@app.get(docs_url, include_in_schema=False)
async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
root = request.scope.get("root_path")
favicon = f"{root}{static_url}/favicon.png"
return get_swagger_ui_html(
openapi_url=f"{root}{openapi_url}",
title=app.title + " - Swagger UI",
oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
swagger_css_url=f"{root}{static_url}/swagger-ui.css",
swagger_favicon_url=favicon,
)
@app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
async def swagger_ui_redirect() -> HTMLResponse:
return get_swagger_ui_oauth2_redirect_html()
if redoc_url is not None:
remove_route(redoc_url)
@app.get(redoc_url, include_in_schema=False)
async def redoc_html(request: Request) -> HTMLResponse:
root = request.scope.get("root_path")
favicon = f"{root}{static_url}/favicon.png"
return get_redoc_html(
openapi_url=f"{root}{openapi_url}",
title=app.title + " - ReDoc",
redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
with_google_fonts=False,
redoc_favicon_url=favicon,
)

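Usage is a one-line patch applied after constructing any FastAPI app, which is exactly how api.py and the FastChat wrappers earlier in this diff use it:

from fastapi import FastAPI
from server.utils import MakeFastAPIOffline

app = FastAPI(title="Example")
MakeFastAPIOffline(app)  # /docs and /redoc now load their JS/CSS from /static-offline-docs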
View File

@@ -13,7 +13,11 @@ import os
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
if __name__ == "__main__":
st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded")
st.set_page_config(
"Langchain-Chatchat WebUI",
os.path.join("img", "chatchat_icon_blue_square_v2.png"),
initial_sidebar_state="expanded",
)
if not chat_box.chat_inited:
st.toast(

View File

@@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
key="selected_kb",
)
kb_top_k = st.number_input("Number of matched results:", 1, 20, 3)
# score_threshold = st.slider("Knowledge match score threshold:", 0, 1, 0, disabled=True)
score_threshold = st.number_input("Knowledge match score threshold:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("Attach surrounding context", False, disabled=True)
# chunk_size = st.slider("Context length:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答":
@@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
if error_msg := check_error_msg(t): # check whether an error occurred
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
if error_msg := check_error_msg(d): # check whether an error occurred
st.error(error_msg)
text += d["answer"]
chat_box.update_msg(text, 0)
@@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest):
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(t): # check whether an error occurred
if error_msg := check_error_msg(d): # check whether an error occurred
st.error(error_msg)
text += d["answer"]
chat_box.update_msg(text, 0)

View File

@@ -6,6 +6,7 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
@@ -312,6 +313,7 @@ class ApiRequest:
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
no_remote_api: bool = None,
@@ -326,6 +328,7 @@ class ApiRequest:
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"local_doc_url": no_remote_api,