From a4e67a67b41fd38b9694cf4b31309ebcfc0879e9 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Tue, 2 May 2023 01:11:05 +0800 Subject: [PATCH] update local_doc_qa.py --- chains/local_doc_qa.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index b3eca15..6b3d1e2 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -82,15 +82,19 @@ def similarity_search_with_score_by_vector( id_set.add(i) docs_len = len(doc.page_content) for k in range(1, max(i, len(docs) - i)): + break_flag = False for l in [i + k, i - k]: if 0 <= l < len(self.index_to_docstore_id): _id0 = self.index_to_docstore_id[l] doc0 = self.docstore.search(_id0) if docs_len + len(doc0.page_content) > self.chunk_size: + break_flag=True break elif doc0.metadata["source"] == doc.metadata["source"]: docs_len += len(doc0.page_content) id_set.add(l) + if break_flag: + break id_list = sorted(list(id_set)) id_lists = seperate_list(id_list) for id_seq in id_lists: @@ -225,8 +229,8 @@ class LocalDocQA: if __name__ == "__main__": local_doc_qa = LocalDocQA() local_doc_qa.init_cfg() - query = "你好" - vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123" + query = "本项目使用的embedding模型是什么,消耗多少显存" + vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa" last_print_len = 0 for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, vs_path=vs_path, @@ -234,9 +238,14 @@ if __name__ == "__main__": streaming=True): print(resp["result"][last_print_len:], end="", flush=True) last_print_len = len(resp["result"]) - for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, - vs_path=vs_path, - chat_history=[], - streaming=False): - print(resp["result"]) + source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" + # f"""相关度:{doc.metadata['score']}\n\n""" + for inum, doc in + enumerate(resp["source_documents"])] + print("\n\n" + "\n\n".join(source_text)) + # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query, + # vs_path=vs_path, + # chat_history=[], + # streaming=False): + # print(resp["result"]) pass