Spaces:
Sleeping
Sleeping
功能优化: 添加双栏pdf识别选项到页面,并优化config文件中关于文档解析的设置
Browse files- ChuanhuChatbot.py +6 -0
- config_example.json +5 -2
- modules/chat_func.py +8 -6
- modules/config.py +12 -2
- modules/llama_func.py +8 -6
ChuanhuChatbot.py
CHANGED
|
@@ -78,6 +78,10 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
| 78 |
value=REPLY_LANGUAGES[0],
|
| 79 |
)
|
| 80 |
index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
with gr.Tab(label="Prompt"):
|
| 83 |
systemPromptTxt = gr.Textbox(
|
|
@@ -295,6 +299,8 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
| 295 |
)
|
| 296 |
reduceTokenBtn.click(**get_usage_args)
|
| 297 |
|
|
|
|
|
|
|
| 298 |
# ChatGPT
|
| 299 |
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
|
| 300 |
keyTxt.submit(**get_usage_args)
|
|
|
|
| 78 |
value=REPLY_LANGUAGES[0],
|
| 79 |
)
|
| 80 |
index_files = gr.Files(label="上传索引文件", type="file", multiple=True)
|
| 81 |
+
two_column = gr.Checkbox(label="双栏pdf", value=advance_docs["pdf"].get("two_column", False))
|
| 82 |
+
# TODO: 公式ocr
|
| 83 |
+
# formula_ocr = gr.Checkbox(label="识别公式", value=advance_docs["pdf"].get("formula_ocr", False))
|
| 84 |
+
updateDocConfigBtn = gr.Button("更新解析文件参数")
|
| 85 |
|
| 86 |
with gr.Tab(label="Prompt"):
|
| 87 |
systemPromptTxt = gr.Textbox(
|
|
|
|
| 299 |
)
|
| 300 |
reduceTokenBtn.click(**get_usage_args)
|
| 301 |
|
| 302 |
+
updateDocConfigBtn.click(update_doc_config, [two_column], None)
|
| 303 |
+
|
| 304 |
# ChatGPT
|
| 305 |
keyTxt.change(submit_key, keyTxt, [user_api_key, status_display]).then(**get_usage_args)
|
| 306 |
keyTxt.submit(**get_usage_args)
|
config_example.json
CHANGED
|
@@ -2,8 +2,11 @@
|
|
| 2 |
"openai_api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxx",
|
| 3 |
"https_proxy": "http://127.0.0.1:1079",
|
| 4 |
"http_proxy": "http://127.0.0.1:1079",
|
| 5 |
-
"
|
| 6 |
-
"
|
|
|
|
|
|
|
|
|
|
| 7 |
},
|
| 8 |
"users": [
|
| 9 |
["root", "root"]
|
|
|
|
| 2 |
"openai_api_key": "sk-xxxxxxxxxxxxxxxxxxxxxxxxx",
|
| 3 |
"https_proxy": "http://127.0.0.1:1079",
|
| 4 |
"http_proxy": "http://127.0.0.1:1079",
|
| 5 |
+
"advance_docs": {
|
| 6 |
+
"pdf": {
|
| 7 |
+
"two_column": true,
|
| 8 |
+
"formula_ocr": true
|
| 9 |
+
}
|
| 10 |
},
|
| 11 |
"users": [
|
| 12 |
["root", "root"]
|
modules/chat_func.py
CHANGED
|
@@ -291,12 +291,14 @@ def predict(
|
|
| 291 |
msg = "索引构建完成,获取回答中……"
|
| 292 |
logging.info(msg)
|
| 293 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
| 300 |
reference_results = [n.node.text for n in nodes]
|
| 301 |
reference_results = add_source_numbers(reference_results, use_source=False)
|
| 302 |
display_reference = add_details(reference_results)
|
|
|
|
| 291 |
msg = "索引构建完成,获取回答中……"
|
| 292 |
logging.info(msg)
|
| 293 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
| 294 |
+
with retrieve_proxy():
|
| 295 |
+
llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
|
| 296 |
+
prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
|
| 297 |
+
from llama_index import ServiceContext
|
| 298 |
+
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
|
| 299 |
+
query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
|
| 300 |
+
query_bundle = QueryBundle(inputs)
|
| 301 |
+
nodes = query_object.retrieve(query_bundle)
|
| 302 |
reference_results = [n.node.text for n in nodes]
|
| 303 |
reference_results = add_source_numbers(reference_results, use_source=False)
|
| 304 |
display_reference = add_details(reference_results)
|
modules/config.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from contextlib import contextmanager
|
| 2 |
import os
|
| 3 |
import logging
|
|
@@ -11,6 +12,8 @@ __all__ = [
|
|
| 11 |
"dockerflag",
|
| 12 |
"retrieve_proxy",
|
| 13 |
"log_level",
|
|
|
|
|
|
|
| 14 |
]
|
| 15 |
|
| 16 |
# 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
|
|
@@ -109,5 +112,12 @@ def retrieve_proxy(proxy=None):
|
|
| 109 |
os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
|
| 110 |
|
| 111 |
|
| 112 |
-
## 处理advance
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
from contextlib import contextmanager
|
| 3 |
import os
|
| 4 |
import logging
|
|
|
|
| 12 |
"dockerflag",
|
| 13 |
"retrieve_proxy",
|
| 14 |
"log_level",
|
| 15 |
+
"advance_docs",
|
| 16 |
+
"update_doc_config"
|
| 17 |
]
|
| 18 |
|
| 19 |
# 添加一个统一的config文件,避免文件过多造成的疑惑(优先级最低)
|
|
|
|
| 112 |
os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
|
| 113 |
|
| 114 |
|
| 115 |
+
## 处理advance docs
|
| 116 |
+
advance_docs = defaultdict(lambda: defaultdict(dict))
|
| 117 |
+
advance_docs.update(config.get("advance_docs", {}))
|
| 118 |
+
def update_doc_config(two_column_pdf):
    """Apply the document-parsing options chosen in the UI to ``advance_docs``.

    This is the click handler for the "update parsing parameters" button and
    receives the current state of the two-column PDF checkbox.

    Args:
        two_column_pdf: Truthy to parse PDFs as two-column layouts, falsy to
            parse them as single-column.

    Returns:
        None. The module-level ``advance_docs`` mapping is updated in place.
    """
    # Always store the checkbox state, not only when it is ticked; otherwise
    # a setting that was enabled once (or enabled via the config file) could
    # never be switched off again from the UI.
    # No ``global`` statement is needed: the mapping is mutated, not rebound.
    advance_docs["pdf"]["two_column"] = bool(two_column_pdf)

    logging.info(f"更新后的文件参数为:{advance_docs}")
|
modules/llama_func.py
CHANGED
|
@@ -45,8 +45,9 @@ def get_documents(file_src):
|
|
| 45 |
logging.debug("Loading PDF...")
|
| 46 |
try:
|
| 47 |
from modules.pdf_func import parse_pdf
|
| 48 |
-
from modules.config import
|
| 49 |
-
|
|
|
|
| 50 |
except:
|
| 51 |
pdftext = ""
|
| 52 |
with open(file.name, 'rb') as pdfFileObj:
|
|
@@ -106,10 +107,11 @@ def construct_index(
|
|
| 106 |
try:
|
| 107 |
documents = get_documents(file_src)
|
| 108 |
logging.info("构建索引中……")
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
| 113 |
logging.debug("索引构建完成!")
|
| 114 |
os.makedirs("./index", exist_ok=True)
|
| 115 |
index.save_to_disk(f"./index/{index_name}.json")
|
|
|
|
| 45 |
logging.debug("Loading PDF...")
|
| 46 |
try:
|
| 47 |
from modules.pdf_func import parse_pdf
|
| 48 |
+
from modules.config import advance_docs
|
| 49 |
+
two_column = advance_docs["pdf"].get("two_column", False)
|
| 50 |
+
pdftext = parse_pdf(file.name, two_column).text
|
| 51 |
except:
|
| 52 |
pdftext = ""
|
| 53 |
with open(file.name, 'rb') as pdfFileObj:
|
|
|
|
| 107 |
try:
|
| 108 |
documents = get_documents(file_src)
|
| 109 |
logging.info("构建索引中……")
|
| 110 |
+
with retrieve_proxy():
|
| 111 |
+
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
|
| 112 |
+
index = GPTSimpleVectorIndex.from_documents(
|
| 113 |
+
documents, service_context=service_context
|
| 114 |
+
)
|
| 115 |
logging.debug("索引构建完成!")
|
| 116 |
os.makedirs("./index", exist_ok=True)
|
| 117 |
index.save_to_disk(f"./index/{index_name}.json")
|