From 6e7078cf24b6bb341ae7edb275e74027e942500a Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 13 May 2023 11:02:27 +0800 Subject: [PATCH 1/4] update api.py --- api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/api.py b/api.py index 9c93c0e..00d8600 100644 --- a/api.py +++ b/api.py @@ -2,15 +2,12 @@ import argparse import json import os import shutil -import subprocess -import tempfile from typing import List, Optional import nltk import pydantic import uvicorn from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket -from fastapi.openapi.utils import get_openapi from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing_extensions import Annotated @@ -144,7 +141,7 @@ async def upload_files( async def list_docs( - knowledge_base_id: Optional[str] = Query(description="Knowledge Base Name", example="kb1") + knowledge_base_id: Optional[str] = Query(default=None, description="Knowledge Base Name", example="kb1") ): if knowledge_base_id: local_doc_folder = get_folder_path(knowledge_base_id) From 621a0fe686a1826799393371f7727026f79707e2 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 13 May 2023 11:45:57 +0800 Subject: [PATCH 2/4] update cli_demo.py --- chains/local_doc_qa.py | 6 +++++- cli_demo.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index a454931..d4cdcb0 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -39,7 +39,11 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE): def write_check_file(filepath, docs): - fout = open('load_file.txt', 'a') + folder_path = os.path.join(os.path.dirname(filepath), "tmp_files") + if not os.path.exists(folder_path): + os.makedirs(folder_path) + fp = os.path.join(folder_path, 'load_file.txt') + fout = open(fp, 'a') fout.write("filepath=%s,len=%s" % (filepath, len(docs))) fout.write('\n') for i in docs: diff --git a/cli_demo.py b/cli_demo.py index eb0f7e2..ea5a895 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -31,13 +31,13 @@ if __name__ == "__main__": chat_history=history, streaming=STREAMING): if STREAMING: - logger.info(resp["result"][last_print_len:]) + print(resp["result"][last_print_len:], end="", flush=True) last_print_len = len(resp["result"]) else: - logger.info(resp["result"]) + print(resp["result"]) if REPLY_WITH_SOURCE: source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" # f"""相关度:{doc.metadata['score']}\n\n""" for inum, doc in enumerate(resp["source_documents"])] - logger.info("\n\n" + "\n\n".join(source_text)) + print("\n\n" + "\n\n".join(source_text)) From 80854c4bcd52f61ff45fa5c47de23d7b4a3298b7 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 13 May 2023 12:36:48 +0800 Subject: [PATCH 3/4] update loaders --- loader/image_loader.py | 2 +- loader/pdf_loader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/loader/image_loader.py b/loader/image_loader.py index 1013e82..b14899a 100644 --- a/loader/image_loader.py +++ b/loader/image_loader.py @@ -30,7 +30,7 @@ class UnstructuredPaddleImageLoader(UnstructuredFileLoader): if __name__ == "__main__": - filepath = "../content/samples/test.jpg" + filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.jpg") loader = UnstructuredPaddleImageLoader(filepath, mode="elements") docs = loader.load() for doc in docs: diff --git a/loader/pdf_loader.py b/loader/pdf_loader.py index a27eec1..cb972a9 100644 --- a/loader/pdf_loader.py +++ b/loader/pdf_loader.py @@ -46,7 +46,7 @@ class UnstructuredPaddlePDFLoader(UnstructuredFileLoader): if __name__ == "__main__": - filepath = "../content/samples/test.pdf" + filepath = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content", "samples", "test.pdf") loader = UnstructuredPaddlePDFLoader(filepath, mode="elements") docs = loader.load() for doc in docs: From 3ff885d0d38d6645f55e060504d45dce0c7aca26 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 13 May 2023 21:35:17 +0800 Subject: [PATCH 4/4] update api.py --- api.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/api.py b/api.py index 00d8600..c675543 100644 --- a/api.py +++ b/api.py @@ -181,7 +181,7 @@ async def delete_docs( if os.path.exists(doc_path): os.remove(doc_path) else: - return {"code": 1, "msg": f"document {doc_name} not found"} + BaseResponse(code=1, msg=f"document {doc_name} not found") remain_docs = await list_docs(knowledge_base_id) if remain_docs["code"] != 0 or len(remain_docs["data"]) == 0: @@ -211,24 +211,30 @@ async def local_doc_chat( ): vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id) if not os.path.exists(vs_path): - raise ValueError(f"Knowledge base {knowledge_base_id} not found") + # return BaseResponse(code=1, msg=f"Knowledge base {knowledge_base_id} not found") + return ChatMessage( + question=question, + response=f"Knowledge base {knowledge_base_id} not found", + history=history, + source_documents=[], + ) + else: + for resp, history in local_doc_qa.get_knowledge_based_answer( + query=question, vs_path=vs_path, chat_history=history, streaming=True + ): + pass + source_documents = [ + f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" + f"""相关度:{doc.metadata['score']}\n\n""" + for inum, doc in enumerate(resp["source_documents"]) + ] - for resp, history in local_doc_qa.get_knowledge_based_answer( - query=question, vs_path=vs_path, chat_history=history, streaming=True - ): - pass - source_documents = [ - f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" - f"""相关度:{doc.metadata['score']}\n\n""" - for inum, doc in enumerate(resp["source_documents"]) - ] - - return ChatMessage( - question=question, - response=resp["result"], - history=history, - source_documents=source_documents, - ) + return ChatMessage( + question=question, + response=resp["result"], + history=history, + source_documents=source_documents, + ) async def chat(