From 67b8ebef522ac9952897800dda9b4d9a3e713bc7 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Wed, 16 Aug 2023 13:18:58 +0800
Subject: [PATCH 1/2] update api and webui: 1. add a search_docs API that
 returns the raw documents retrieved from the knowledge base, close #1103
 2. add a score_threshold parameter for FAISS retrieval; Milvus and PG do
 not support it yet
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/api.py                            | 11 ++++++++-
 server/chat/knowledge_base_chat.py       |  9 ++++---
 server/knowledge_base/kb_doc_api.py      | 23 ++++++++++++++++--
 server/knowledge_base/kb_service/base.py |  5 ++--
 .../kb_service/faiss_kb_service.py        |  5 ++--
 .../kb_service/milvus_kb_service.py       |  3 ++-
 .../kb_service/pg_kb_service.py           |  3 ++-
 server/static/favicon.png                 | Bin 0 -> 7299 bytes
 webui_pages/dialogue/dialogue.py          |  8 +++---
 webui_pages/utils.py                      |  3 +++
 10 files changed, 53 insertions(+), 17 deletions(-)
 create mode 100644 server/static/favicon.png

diff --git a/server/api.py b/server/api.py
index 458b1d7..d86fcbb 100644
--- a/server/api.py
+++ b/server/api.py
@@ -14,8 +14,11 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
 from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
-                                              update_doc, download_doc, recreate_vector_store)
+                                              update_doc, download_doc, recreate_vector_store,
+                                              search_docs, DocumentWithScore)
 from server.utils import BaseResponse, ListResponse
+from typing import List
+
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
@@ -83,6 +86,12 @@ def create_app():
              summary="获取知识库内的文件列表"
              )(list_docs)
 
+    app.post("/knowledge_base/search_docs",
+             tags=["Knowledge Base Management"],
+             response_model=List[DocumentWithScore],
+             summary="搜索知识库"
+             )(search_docs)
+
     app.post("/knowledge_base/upload_doc",
              tags=["Knowledge Base Management"],
              response_model=BaseResponse,
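
The route registered above is a thin HTTP wrapper around server.knowledge_base.kb_doc_api.search_docs, added later in this patch. A minimal client-side sketch, not part of the patch, assuming the API server is reachable at the address webui.py points at (http://127.0.0.1:7861) and that the stock "samples" knowledge base exists:

import requests

# Illustrative only: exercise the new /knowledge_base/search_docs route.
resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/search_docs",  # assumed host/port
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,
        "score_threshold": 0.5,  # 1 effectively disables filtering
    },
)
for doc in resp.json():  # each item is a serialized DocumentWithScore
    print(doc["score"], doc["page_content"][:50])
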
description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), history: List[History] = Body([], description="历史对话", examples=[[ @@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], model_name=LLM_MODEL ) - docs = kb.search_docs(query, top_k) + docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) chat_prompt = ChatPromptTemplate.from_messages( diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 3f27fb1..0bf2cb7 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,13 +1,32 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile -from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL +from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile from fastapi.responses import StreamingResponse, FileResponse import json from server.knowledge_base.kb_service.base import KBServiceFactory -from typing import List +from typing import List, Dict +from langchain.docstore.document import Document + + +class DocumentWithScore(Document): + score: float = None + + +def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), + ) -> List[DocumentWithScore]: + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []} + docs = kb.search_docs(query, top_k, score_threshold) + data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] + + return data async def list_docs( diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ec1c692..d506f63 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import ( list_docs_from_db, get_file_detail, delete_file_from_db ) -from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, +from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, EMBEDDING_DEVICE, EMBEDDING_MODEL) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, @@ -112,9 +112,10 @@ class KBService(ABC): def search_docs(self, query: str, top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, ): embeddings = self._load_embeddings() - docs = self.do_search(query, top_k, embeddings) + docs = self.do_search(query, top_k, score_threshold, embeddings) return docs @abstractmethod diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 0ef820a..5c8376f 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -81,12 +81,13 @@ class FaissKBService(KBService): def do_search(self, query: str, top_k: int, - embeddings: 
-                  embeddings: Embeddings,
+                  score_threshold: float = SCORE_THRESHOLD,
+                  embeddings: Embeddings = None,
                   ) -> List[Document]:
         search_index = load_vector_store(self.kb_name,
                                          embeddings=embeddings,
                                          tick=_VECTOR_STORE_TICKS.get(self.kb_name))
-        docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
+        docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
         return docs
 
     def do_add_doc(self,
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index 6f1c392..f9c40c0 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -45,7 +45,8 @@ class MilvusKBService(KBService):
     def do_drop_kb(self):
         self.milvus.col.drop()
 
-    def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
+    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
+        # todo: support score threshold
         self._load_milvus(embeddings=embeddings)
         return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
 
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 82511bb..a3126ec 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -43,7 +43,8 @@ class PGKBService(KBService):
             '''))
             connect.commit()
 
-    def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
+    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
+        # todo: support score threshold
         self._load_pg_vector(embeddings=embeddings)
         return self.pg_vector.similarity_search(query, top_k)
 
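
With the changes above, only the FAISS backend actually honors score_threshold (via similarity_search_with_score, which yields (Document, score) pairs); the Milvus and PG backends still ignore it, as their todo comments note. A rough sketch, not part of the patch, of how a backend that returns scored pairs could apply the threshold on the caller's side, following the convention documented in the API (lower score = more relevant, 1 effectively disables filtering):

from typing import List, Tuple
from langchain.docstore.document import Document

def filter_by_score(docs_and_scores: List[Tuple[Document, float]],
                    score_threshold: float = 1.0,
                    ) -> List[Tuple[Document, float]]:
    # Illustrative only: post-filter scored results when the vector store
    # cannot apply score_threshold natively. Lower score == more similar
    # (the distance convention the FAISS path uses).
    return [(doc, score) for doc, score in docs_and_scores if score <= score_threshold]
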
diff --git a/server/static/favicon.png b/server/static/favicon.png
new file mode 100644
index 0000000000000000000000000000000000000000..5de8ee80b5f219484c90b0b4604cfd1e85b58e3d
GIT binary patch
literal 7299
[... base85-encoded binary image data omitted; the webui_pages/dialogue/dialogue.py and webui_pages/utils.py hunks of this patch are not preserved in this copy ...]

Date: Wed, 16 Aug 2023 14:20:09 +0800
Subject: [PATCH 2/2] update llm_api and api server: 1. fastchat's
 controller/model_worker/api_server use swagger UI offline. 2. add custom
 title and icon. 3. remove fastapi-offline dependency

---
 requirements.txt     |  1 -
 requirements_api.txt |  1 -
 server/api.py        |  6 ++--
 server/llm_api.py    |  9 ++++-
 server/utils.py      | 86 ++++++++++++++++++++++++++++++++++++++++----
 webui.py             |  6 +++-
 6 files changed, 96 insertions(+), 13 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index f2e1d65..6c013e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,6 @@ fschat==0.2.20
 transformers
 torch~=2.0.0
 fastapi~=0.99.1
-fastapi-offline
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
diff --git a/requirements_api.txt b/requirements_api.txt
index f077c94..9b45aac 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -5,7 +5,6 @@ fschat==0.2.20
 transformers
 torch~=2.0.0
 fastapi~=0.99.1
-fastapi-offline
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
diff --git a/server/api.py b/server/api.py
index d86fcbb..800680c 100644
--- a/server/api.py
+++ b/server/api.py
@@ -7,7 +7,6 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
-from server.utils import FastAPIOffline as FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
@@ -16,7 +15,7 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
                                               update_doc, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.utils import BaseResponse, ListResponse
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 from typing import List
 
 
@@ -28,7 +27,8 @@ async def document():
 
 
 def create_app():
-    app = FastAPI()
+    app = FastAPI(title="Langchain-Chatchat API Server")
+    MakeFastAPIOffline(app)
     # Add CORS middleware to allow all origins
     # 在config.py中设置OPEN_DOMAIN=True,允许跨域
     # set OPEN_DOMAIN=True in config.py to allow cross-domain
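
MakeFastAPIOffline, defined in server/utils.py later in this patch, replaces the fastapi-offline dependency: it remounts /docs and /redoc so that their JavaScript, CSS and favicon come from the local server/static directory. A minimal usage sketch, not part of the patch, assuming the project (and its bundled static assets) is on the import path:

import uvicorn
from fastapi import FastAPI
from server.utils import MakeFastAPIOffline

# Illustrative only: wrap an arbitrary FastAPI app with the offline-docs helper.
app = FastAPI(title="My Offline-Docs App")
MakeFastAPIOffline(app)  # /docs and /redoc now load assets from /static-offline-docs

@app.get("/ping")
def ping():
    return {"msg": "pong"}

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
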
diff --git a/server/llm_api.py b/server/llm_api.py
index 0a7d3b0..e1013ed 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -4,6 +4,8 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
 
+from server.utils import MakeFastAPIOffline
+
 host_ip = "0.0.0.0"
 controller_port = 20001
 
@@ -30,6 +32,8 @@ def create_controller_app(
     controller = Controller(dispatch_method)
     sys.modules["fastchat.serve.controller"].controller = controller
 
+    MakeFastAPIOffline(app)
+    app.title = "FastChat Controller"
     return app
 
 
@@ -55,7 +59,6 @@ def create_model_worker_app(
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
     from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
-    from fastchat.serve import model_worker
     import argparse
     parser = argparse.ArgumentParser()
 
@@ -117,6 +120,8 @@ def create_model_worker_app(
     sys.modules["fastchat.serve.model_worker"].args = args
     sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
 
+    MakeFastAPIOffline(app)
+    app.title = f"FastChat LLM Server ({LLM_MODEL})"
     return app
 
 
@@ -141,6 +146,8 @@ def create_openai_api_app(
     app_settings.controller_address = controller_address
     app_settings.api_keys = api_keys
 
+    MakeFastAPIOffline(app)
+    app.title = "FastChat OpenAI API Server"
     return app
 
 
diff --git a/server/utils.py b/server/utils.py
index e1a23d1..c0f11a5 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -2,14 +2,10 @@ import pydantic
 from pydantic import BaseModel
 from typing import List
 import torch
-from fastapi_offline import FastAPIOffline
-import fastapi_offline
+from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-
-
-# patch fastapi_offline to use local static assests
-fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static"
+from typing import Any, Optional
 
 
 class BaseResponse(BaseModel):
@@ -112,3 +108,81 @@ def iter_over_async(ait, loop):
         if done:
             break
         yield obj
+
+
+def MakeFastAPIOffline(
+        app: FastAPI,
+        static_dir = Path(__file__).parent / "static",
+        static_url = "/static-offline-docs",
+        docs_url: Optional[str] = "/docs",
+        redoc_url: Optional[str] = "/redoc",
+) -> None:
+    """patch the FastAPI app so that its documentation pages do not rely on a CDN"""
+    from fastapi import Request
+    from fastapi.openapi.docs import (
+        get_redoc_html,
+        get_swagger_ui_html,
+        get_swagger_ui_oauth2_redirect_html,
+    )
+    from fastapi.staticfiles import StaticFiles
+    from starlette.responses import HTMLResponse
+
+    openapi_url = app.openapi_url
+    swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url
+
+    def remove_route(url: str) -> None:
+        '''
+        remove original route from app
+        '''
+        index = None
+        for i, r in enumerate(app.routes):
+            if r.path.lower() == url.lower():
+                index = i
+                break
+        if isinstance(index, int):
+            app.routes.pop(index)
+
+    # Set up static file mount
+    app.mount(
+        static_url,
+        StaticFiles(directory=Path(static_dir).as_posix()),
+        name="static-offline-docs",
+    )
+
+    if docs_url is not None:
+        remove_route(docs_url)
+        remove_route(swagger_ui_oauth2_redirect_url)
+
+        # Define the doc and redoc pages, pointing at the right files
+        @app.get(docs_url, include_in_schema=False)
+        async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
+            root = request.scope.get("root_path")
+            favicon = f"{root}{static_url}/favicon.png"
+            return get_swagger_ui_html(
+                openapi_url=f"{root}{openapi_url}",
+                title=app.title + " - Swagger UI",
+                oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
+                swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
+                swagger_css_url=f"{root}{static_url}/swagger-ui.css",
+                swagger_favicon_url=favicon,
+            )
+
+        @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
+        async def swagger_ui_redirect() -> HTMLResponse:
+            return get_swagger_ui_oauth2_redirect_html()
+
+    if redoc_url is not None:
+        remove_route(redoc_url)
+
+        @app.get(redoc_url, include_in_schema=False)
+        async def redoc_html(request: Request) -> HTMLResponse:
+            root = request.scope.get("root_path")
+            favicon = f"{root}{static_url}/favicon.png"
+
openapi_url=f"{root}{openapi_url}", + title=app.title + " - ReDoc", + redoc_js_url=f"{root}{static_url}/redoc.standalone.js", + with_google_fonts=False, + redoc_favicon_url=favicon, + ) diff --git a/webui.py b/webui.py index d84da42..99db3f6 100644 --- a/webui.py +++ b/webui.py @@ -13,7 +13,11 @@ import os api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) if __name__ == "__main__": - st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded") + st.set_page_config( + "Langchain-Chatchat WebUI", + os.path.join("img", "chatchat_icon_blue_square_v2.png"), + initial_sidebar_state="expanded", + ) if not chat_box.chat_inited: st.toast(