Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
a543a3d
1
Parent(s):
b69c7d1
加强代码健壮性
Browse files- ChuanhuChatbot.py +9 -9
- utils.py +31 -14
ChuanhuChatbot.py
CHANGED
|
@@ -69,26 +69,26 @@ with gr.Blocks(css=customCSS) as demo:
|
|
| 69 |
with gr.Column():
|
| 70 |
with gr.Row():
|
| 71 |
with gr.Column(scale=6):
|
| 72 |
-
templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True), multiselect=False)
|
| 73 |
with gr.Column(scale=1):
|
| 74 |
templateRefreshBtn = gr.Button("🔄 刷新")
|
| 75 |
templaeFileReadBtn = gr.Button("📂 读入模板")
|
| 76 |
with gr.Row():
|
| 77 |
with gr.Column(scale=6):
|
| 78 |
-
templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False)
|
| 79 |
with gr.Column(scale=1):
|
| 80 |
templateApplyBtn = gr.Button("⬇️ 应用")
|
| 81 |
-
with gr.Accordion(label="保存/加载对话历史记录
|
| 82 |
with gr.Column():
|
| 83 |
with gr.Row():
|
| 84 |
with gr.Column(scale=6):
|
| 85 |
saveFileName = gr.Textbox(
|
| 86 |
show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
|
| 87 |
with gr.Column(scale=1):
|
| 88 |
-
|
| 89 |
with gr.Row():
|
| 90 |
with gr.Column(scale=6):
|
| 91 |
-
historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False)
|
| 92 |
with gr.Column(scale=1):
|
| 93 |
historyRefreshBtn = gr.Button("🔄 刷新")
|
| 94 |
historyReadBtn = gr.Button("📂 读入对话")
|
|
@@ -116,14 +116,14 @@ with gr.Blocks(css=customCSS) as demo:
|
|
| 116 |
chatbot, history], show_progress=True)
|
| 117 |
reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
|
| 118 |
systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history, statusDisplay], show_progress=True)
|
| 119 |
-
|
| 120 |
saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
|
| 121 |
-
|
| 122 |
historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
|
| 123 |
-
historyReadBtn.click(load_chat_history, [historyFileSelectDropdown], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
|
| 124 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
| 125 |
templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
|
| 126 |
-
templateApplyBtn.click(
|
| 127 |
|
| 128 |
print("川虎的温馨提示:访问 http://localhost:7860 查看界面")
|
| 129 |
# 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
|
|
|
|
| 69 |
with gr.Column():
|
| 70 |
with gr.Row():
|
| 71 |
with gr.Column(scale=6):
|
| 72 |
+
templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)", choices=get_template_names(plain=True), multiselect=False, value=get_template_names(plain=True)[0])
|
| 73 |
with gr.Column(scale=1):
|
| 74 |
templateRefreshBtn = gr.Button("🔄 刷新")
|
| 75 |
templaeFileReadBtn = gr.Button("📂 读入模板")
|
| 76 |
with gr.Row():
|
| 77 |
with gr.Column(scale=6):
|
| 78 |
+
templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=load_template(get_template_names(plain=True)[0], mode=1), multiselect=False, value=load_template(get_template_names(plain=True)[0], mode=1)[0])
|
| 79 |
with gr.Column(scale=1):
|
| 80 |
templateApplyBtn = gr.Button("⬇️ 应用")
|
| 81 |
+
with gr.Accordion(label="保存/加载对话历史记录", open=False):
|
| 82 |
with gr.Column():
|
| 83 |
with gr.Row():
|
| 84 |
with gr.Column(scale=6):
|
| 85 |
saveFileName = gr.Textbox(
|
| 86 |
show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名", value="对话历史记录").style(container=True)
|
| 87 |
with gr.Column(scale=1):
|
| 88 |
+
saveHistoryBtn = gr.Button("💾 保存对话")
|
| 89 |
with gr.Row():
|
| 90 |
with gr.Column(scale=6):
|
| 91 |
+
historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话", choices=get_history_names(plain=True), multiselect=False, value=get_history_names(plain=True)[0])
|
| 92 |
with gr.Column(scale=1):
|
| 93 |
historyRefreshBtn = gr.Button("🔄 刷新")
|
| 94 |
historyReadBtn = gr.Button("📂 读入对话")
|
|
|
|
| 116 |
chatbot, history], show_progress=True)
|
| 117 |
reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
|
| 118 |
systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history, statusDisplay], show_progress=True)
|
| 119 |
+
saveHistoryBtn.click(save_chat_history, [
|
| 120 |
saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
|
| 121 |
+
saveHistoryBtn.click(get_history_names, None, [historyFileSelectDropdown])
|
| 122 |
historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
|
| 123 |
+
historyReadBtn.click(load_chat_history, [historyFileSelectDropdown, systemPromptTxt, history, chatbot], [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
|
| 124 |
templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
|
| 125 |
templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown], show_progress=True)
|
| 126 |
+
templateApplyBtn.click(get_template_content, [promptTemplates, templateSelectDropdown, systemPromptTxt], [systemPromptTxt], show_progress=True)
|
| 127 |
|
| 128 |
print("川虎的温馨提示:访问 http://localhost:7860 查看界面")
|
| 129 |
# 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
|
utils.py
CHANGED
|
@@ -210,15 +210,18 @@ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[],
|
|
| 210 |
|
| 211 |
|
| 212 |
def delete_last_conversation(chatbot, history):
|
| 213 |
-
|
| 214 |
-
chatbot
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
history.pop()
|
| 219 |
history.pop()
|
|
|
|
| 220 |
print(history)
|
| 221 |
-
|
|
|
|
|
|
|
| 222 |
|
| 223 |
def save_chat_history(filename, system, history, chatbot):
|
| 224 |
if filename == "":
|
|
@@ -232,11 +235,16 @@ def save_chat_history(filename, system, history, chatbot):
|
|
| 232 |
json.dump(json_s, f)
|
| 233 |
|
| 234 |
|
| 235 |
-
def load_chat_history(filename):
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
def sorted_by_pinyin(list):
|
| 242 |
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
|
@@ -250,6 +258,8 @@ def get_file_names(dir, plain=False, filetypes=[".json"]):
|
|
| 250 |
except FileNotFoundError:
|
| 251 |
files = []
|
| 252 |
files = sorted_by_pinyin(files)
|
|
|
|
|
|
|
| 253 |
if plain:
|
| 254 |
return files
|
| 255 |
else:
|
|
@@ -260,6 +270,7 @@ def get_history_names(plain=False):
|
|
| 260 |
|
| 261 |
def load_template(filename, mode=0):
|
| 262 |
lines = []
|
|
|
|
| 263 |
if filename.endswith(".json"):
|
| 264 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
| 265 |
lines = json.load(f)
|
|
@@ -270,19 +281,25 @@ def load_template(filename, mode=0):
|
|
| 270 |
lines = list(reader)
|
| 271 |
lines = lines[1:]
|
| 272 |
if mode == 1:
|
| 273 |
-
return
|
| 274 |
elif mode == 2:
|
| 275 |
return {row[0]:row[1] for row in lines}
|
| 276 |
else:
|
| 277 |
-
|
|
|
|
| 278 |
|
| 279 |
def get_template_names(plain=False):
|
| 280 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
def reset_state():
|
| 283 |
return [], []
|
| 284 |
|
| 285 |
-
|
| 286 |
def compose_system(system_prompt):
|
| 287 |
return {"role": "system", "content": system_prompt}
|
| 288 |
|
|
|
|
| 210 |
|
| 211 |
|
| 212 |
def delete_last_conversation(chatbot, history):
|
| 213 |
+
try:
|
| 214 |
+
if "☹️发生了错误" in chatbot[-1][1]:
|
| 215 |
+
chatbot.pop()
|
| 216 |
+
print(history)
|
| 217 |
+
return chatbot, history
|
| 218 |
history.pop()
|
| 219 |
history.pop()
|
| 220 |
+
chatbot.pop()
|
| 221 |
print(history)
|
| 222 |
+
return chatbot, history
|
| 223 |
+
except:
|
| 224 |
+
return chatbot, history
|
| 225 |
|
| 226 |
def save_chat_history(filename, system, history, chatbot):
|
| 227 |
if filename == "":
|
|
|
|
| 235 |
json.dump(json_s, f)
|
| 236 |
|
| 237 |
|
| 238 |
+
def load_chat_history(filename, system, history, chatbot):
|
| 239 |
+
try:
|
| 240 |
+
print("Loading from history...")
|
| 241 |
+
with open(os.path.join(HISTORY_DIR, filename), "r") as f:
|
| 242 |
+
json_s = json.load(f)
|
| 243 |
+
print(json_s)
|
| 244 |
+
return filename, json_s["system"], json_s["history"], json_s["chatbot"]
|
| 245 |
+
except FileNotFoundError:
|
| 246 |
+
print("File not found.")
|
| 247 |
+
return filename, system, history, chatbot
|
| 248 |
|
| 249 |
def sorted_by_pinyin(list):
|
| 250 |
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
|
|
|
| 258 |
except FileNotFoundError:
|
| 259 |
files = []
|
| 260 |
files = sorted_by_pinyin(files)
|
| 261 |
+
if files == []:
|
| 262 |
+
files = [""]
|
| 263 |
if plain:
|
| 264 |
return files
|
| 265 |
else:
|
|
|
|
| 270 |
|
| 271 |
def load_template(filename, mode=0):
|
| 272 |
lines = []
|
| 273 |
+
print("Loading template...")
|
| 274 |
if filename.endswith(".json"):
|
| 275 |
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
| 276 |
lines = json.load(f)
|
|
|
|
| 281 |
lines = list(reader)
|
| 282 |
lines = lines[1:]
|
| 283 |
if mode == 1:
|
| 284 |
+
return sorted_by_pinyin([row[0] for row in lines])
|
| 285 |
elif mode == 2:
|
| 286 |
return {row[0]:row[1] for row in lines}
|
| 287 |
else:
|
| 288 |
+
choices = sorted_by_pinyin([row[0] for row in lines])
|
| 289 |
+
return {row[0]:row[1] for row in lines}, gr.Dropdown.update(choices=choices, value=choices[0])
|
| 290 |
|
| 291 |
def get_template_names(plain=False):
|
| 292 |
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
| 293 |
|
| 294 |
+
def get_template_content(templates, selection, original_system_prompt):
|
| 295 |
+
try:
|
| 296 |
+
return templates[selection]
|
| 297 |
+
except:
|
| 298 |
+
return original_system_prompt
|
| 299 |
+
|
| 300 |
def reset_state():
|
| 301 |
return [], []
|
| 302 |
|
|
|
|
| 303 |
def compose_system(system_prompt):
|
| 304 |
return {"role": "system", "content": system_prompt}
|
| 305 |
|