update local_doc_qa.py
This commit is contained in:
parent
c98c4888e9
commit
a4e67a67b4
|
|
@ -82,15 +82,19 @@ def similarity_search_with_score_by_vector(
|
||||||
id_set.add(i)
|
id_set.add(i)
|
||||||
docs_len = len(doc.page_content)
|
docs_len = len(doc.page_content)
|
||||||
for k in range(1, max(i, len(docs) - i)):
|
for k in range(1, max(i, len(docs) - i)):
|
||||||
|
break_flag = False
|
||||||
for l in [i + k, i - k]:
|
for l in [i + k, i - k]:
|
||||||
if 0 <= l < len(self.index_to_docstore_id):
|
if 0 <= l < len(self.index_to_docstore_id):
|
||||||
_id0 = self.index_to_docstore_id[l]
|
_id0 = self.index_to_docstore_id[l]
|
||||||
doc0 = self.docstore.search(_id0)
|
doc0 = self.docstore.search(_id0)
|
||||||
if docs_len + len(doc0.page_content) > self.chunk_size:
|
if docs_len + len(doc0.page_content) > self.chunk_size:
|
||||||
|
break_flag=True
|
||||||
break
|
break
|
||||||
elif doc0.metadata["source"] == doc.metadata["source"]:
|
elif doc0.metadata["source"] == doc.metadata["source"]:
|
||||||
docs_len += len(doc0.page_content)
|
docs_len += len(doc0.page_content)
|
||||||
id_set.add(l)
|
id_set.add(l)
|
||||||
|
if break_flag:
|
||||||
|
break
|
||||||
id_list = sorted(list(id_set))
|
id_list = sorted(list(id_set))
|
||||||
id_lists = seperate_list(id_list)
|
id_lists = seperate_list(id_list)
|
||||||
for id_seq in id_lists:
|
for id_seq in id_lists:
|
||||||
|
|
@ -225,8 +229,8 @@ class LocalDocQA:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
local_doc_qa.init_cfg()
|
local_doc_qa.init_cfg()
|
||||||
query = "你好"
|
query = "本项目使用的embedding模型是什么,消耗多少显存"
|
||||||
vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/123"
|
vs_path = "/Users/liuqian/Downloads/glm-dev/vector_store/aaa"
|
||||||
last_print_len = 0
|
last_print_len = 0
|
||||||
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
||||||
vs_path=vs_path,
|
vs_path=vs_path,
|
||||||
|
|
@ -234,9 +238,14 @@ if __name__ == "__main__":
|
||||||
streaming=True):
|
streaming=True):
|
||||||
print(resp["result"][last_print_len:], end="", flush=True)
|
print(resp["result"][last_print_len:], end="", flush=True)
|
||||||
last_print_len = len(resp["result"])
|
last_print_len = len(resp["result"])
|
||||||
for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
source_text = [f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
|
||||||
vs_path=vs_path,
|
# f"""相关度:{doc.metadata['score']}\n\n"""
|
||||||
chat_history=[],
|
for inum, doc in
|
||||||
streaming=False):
|
enumerate(resp["source_documents"])]
|
||||||
print(resp["result"])
|
print("\n\n" + "\n\n".join(source_text))
|
||||||
|
# for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
|
||||||
|
# vs_path=vs_path,
|
||||||
|
# chat_history=[],
|
||||||
|
# streaming=False):
|
||||||
|
# print(resp["result"])
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue