From 47922d2ee311b8ce8d9195152008828cebfdcc7f Mon Sep 17 00:00:00 2001 From: Winter Date: Thu, 4 May 2023 20:58:15 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=89=A9=E5=B1=95=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E7=9A=84=E4=BB=A3=E7=A0=81=E9=80=BB=E8=BE=91=20(#227)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 扩展文档的代码逻辑 * Update local_doc_qa.py --------- Co-authored-by: imClumsyPanda --- chains/local_doc_qa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 6176bc2..866bedf 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -68,6 +68,7 @@ def similarity_search_with_score_by_vector( scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k) docs = [] id_set = set() + store_len = len(self.index_to_docstore_id) for j, i in enumerate(indices[0]): if i == -1: # This happens when not enough docs are returned. @@ -76,7 +77,7 @@ def similarity_search_with_score_by_vector( doc = self.docstore.search(_id) id_set.add(i) docs_len = len(doc.page_content) - for k in range(1, max(i, len(docs) - i)): + for k in range(1, max(i, store_len-i)): break_flag = False for l in [i + k, i - k]: if 0 <= l < len(self.index_to_docstore_id):