MarshallCN committed on
Commit bd5ce6f
1 Parent(s): b943737
Files changed (6)
  1. .gitignore +15 -0
  2. Chat_RAG_vecDB.py +212 -0
  3. README.md +1 -1
  4. ggufv2.py +412 -0
  5. requirements.txt +4 -0
  6. utils.py +151 -0
.gitignore ADDED
@@ -0,0 +1,15 @@
1
+ # Jupyter
2
+ **/.ipynb_checkpoints/
3
+ # any hidden Jupyter aux files like .ipynb_foo (gitignore comments must start the line)
+ .ipynb_*
4
+
5
+ # Python cache/bytecode
6
+ **/__pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+ /old/
10
+ /old/*
11
+ models/
12
+ models/*
13
+ export/
14
+ msgs/
15
+ msgs/*
Chat_RAG_vecDB.py ADDED
@@ -0,0 +1,212 @@
1
+ import gradio as gr
2
+ import stat
3
+ import os, shutil, pickle, torch, json, hashlib
4
+ import faiss, numpy as np
5
+ from pathlib import Path
6
+ from FlagEmbedding import BGEM3FlagModel
7
+ from sentence_transformers import CrossEncoder
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
9
+ from langchain_community.document_loaders import TextLoader, PyPDFLoader
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain_huggingface import HuggingFaceEmbeddings
12
+ from utils import mk_msg_dir
13
+
14
+ # === Model loading ===
15
+ if gr.NO_RELOAD:
16
+ BASE_DIR = r"C:\Users\c1052689\hug_models\Qwen2.5-0.5B-Instruct"
17
+ tok = AutoTokenizer.from_pretrained(BASE_DIR, use_fast=False, local_files_only=True)
18
+ bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
19
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
20
+ model = AutoModelForCausalLM.from_pretrained(BASE_DIR, quantization_config=bnb, device_map="auto", local_files_only=True)
21
+ pipe = pipeline("text-generation", model=model, tokenizer=tok, max_new_tokens=512)
22
+ BGEM3 = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
23
+ reranker = CrossEncoder("BAAI/bge-reranker-large")
24
+
25
+ # === Vector store globals ===
26
+ corpus = []
27
+ index = None
28
+ current_db_dir = None
29
+
30
+ vec_dir_base = './vectorstore/bgem3/'
31
+ embedding_model_id = 'BAAI/bge-m3'
32
+
33
+ # === Document loading & vector building ===
34
+ def load_documents(folder: str):
35
+ docs = []
36
+ for path in Path(folder).rglob("*"):
37
+ if path.suffix == ".txt":
38
+ docs += TextLoader(str(path), encoding="utf-8").load()
39
+ elif path.suffix == ".pdf":
40
+ docs += PyPDFLoader(str(path)).load()
41
+ return docs
42
+
43
+ def split_docs(docs, chunk_size=512, chunk_overlap=64):
44
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
45
+ return splitter.split_documents(docs)
46
+
47
+ def on_rm_error(func, path, exc_info):
48
+ os.chmod(path, stat.S_IWRITE)
49
+ func(path)
50
+
51
+ def hash_text(text):
52
+ return hashlib.md5(text.encode('utf-8')).hexdigest()
53
+
54
+ def list_vector_dbs():
55
+ db_list = [f.name for f in Path(vec_dir_base).iterdir() if f.is_dir()]
56
+ return ["<New Vector DB>"] + db_list
57
+
58
+ def create_or_extend_index(docs, selected_db):
59
+ global corpus, index, current_db_dir
60
+
61
+ temp_dir = "temp_docs"
62
+ if os.path.exists(temp_dir):
63
+ shutil.rmtree(temp_dir, onerror=on_rm_error)
64
+ os.makedirs(temp_dir, exist_ok=True)
65
+
66
+ for file in docs:
67
+ src_path = file.name if hasattr(file, "name") else str(file)
68
+ dst_path = os.path.join(temp_dir, os.path.basename(src_path))
69
+ shutil.copy(src_path, dst_path)
70
+
71
+ raw_docs = load_documents(temp_dir)
72
+ chunks = split_docs(raw_docs)
73
+ new_corpus = [t.page_content for t in chunks]
74
+ new_hashes = [hash_text(t) for t in new_corpus]
75
+
76
+ if selected_db == "<New Vector DB>":
77
+ db_id = mk_msg_dir(Path(vec_dir_base))
78
+ current_db_dir = os.path.join(vec_dir_base, db_id)
79
+ os.makedirs(current_db_dir, exist_ok=True)
80
+ index = faiss.IndexFlatIP(BGEM3.encode(["test"])["dense_vecs"].shape[1])
81
+ corpus = []
82
+ existing_hashes = set()
83
+ else:
84
+ current_db_dir = os.path.join(vec_dir_base, selected_db)
85
+ index = faiss.read_index(os.path.join(current_db_dir, "index.faiss"))
86
+ with open(os.path.join(current_db_dir, "corpus.pkl"), "rb") as f:
87
+ corpus = pickle.load(f)
88
+ with open(os.path.join(current_db_dir, "meta.json"), "r", encoding="utf-8") as f:
89
+ meta = json.load(f)
90
+ existing_hashes = set(meta.get("hashes", []))
91
+
92
+ # 去重
93
+ filtered = [(c, h) for c, h in zip(new_corpus, new_hashes) if h not in existing_hashes]
94
+ if not filtered:
95
+ return "✅ No new (non-duplicate) chunks to add."
96
+
97
+ add_corpus, add_hashes = zip(*filtered)
98
+ dense = BGEM3.encode(add_corpus, batch_size=64)["dense_vecs"]
99
+ if isinstance(dense, torch.Tensor):
100
+ dense = dense.detach().cpu().numpy()
101
+ dense = np.ascontiguousarray(dense, dtype=np.float32)
102
+ faiss.normalize_L2(dense)
103
+
104
+ index.add(dense)
105
+ corpus.extend(add_corpus)
106
+ all_hashes = list(existing_hashes) + list(add_hashes)
107
+
108
+ faiss.write_index(index, os.path.join(current_db_dir, "index.faiss"))
109
+ with open(os.path.join(current_db_dir, "corpus.pkl"), "wb") as f:
110
+ pickle.dump(corpus, f)
111
+ meta = {
112
+ "model": embedding_model_id,
113
+ "dim": int(dense.shape[1]),
114
+ "total_chunks": len(corpus),
115
+ "raw_docs": len(raw_docs),
116
+ "hashes": all_hashes,
117
+ }
118
+ with open(os.path.join(current_db_dir, "meta.json"), "w", encoding="utf-8") as f:
119
+ json.dump(meta, f, indent=2)
120
+ db_stats = f"✅ Added {len(add_corpus)} new chunks to DB `{os.path.basename(current_db_dir)}`."
121
+ db_list_update = gr.update(choices=list_vector_dbs())
122
+ return db_stats, db_list_update
123
+
124
+ def build_prompt_corpus(top_docs, question):
125
+ context_text = "\n\n".join(top_docs)
126
+ user_prompt = f"""Answer the question based on the following context:
127
+
128
+ {context_text}
129
+
130
+ Question: {question}
131
+ Answer:"""
132
+
133
+ full_prompt = (
134
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
135
+ f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
136
+ "<|im_start|>assistant\n"
137
+ )
138
+ return full_prompt
139
+
140
+ def ask_question(query):
141
+ if not query.strip():
142
+ return "❌ Please enter your questions", [], ""
143
+ if index is None or len(corpus) == 0:
144
+ return "⚠️ Please upload your documents", [], ""
145
+
146
+ qv = np.array(BGEM3.encode([query])["dense_vecs"], dtype="float32")
147
+ faiss.normalize_L2(qv)
148
+ D, I = index.search(qv, 8)
149
+ results = [corpus[i] for i in I[0]]
150
+ pairs = [[query, c] for c in results]
151
+ scores = reranker.predict(pairs)
152
+ top_docs = [c for _, c in sorted(zip(scores, results), reverse=True)][:3]
153
+
154
+ prompt = build_prompt_corpus(top_docs, query)
155
+ out = pipe(
156
+ prompt,
157
+ max_new_tokens=1024,
158
+ eos_token_id=tok.eos_token_id,
159
+ pad_token_id=tok.eos_token_id,
160
+ return_full_text=False,
161
+ )
162
+ reply = out[0]["generated_text"]
163
+ context_display = "\n\n".join(
164
+ f"[{i+1}] {doc.strip()[:1000]}" for i, doc in enumerate(top_docs)
165
+ )
166
+ return reply.strip(), context_display
167
+
168
+ def show_db_stats(selected_db):
169
+ if selected_db == "<New Vector DB>":
170
+ return "🆕 New vector DB will be created on next upload."
171
+ try:
172
+ db_dir = os.path.join(vec_dir_base, selected_db)
173
+ with open(os.path.join(db_dir, "meta.json"), "r", encoding="utf-8") as f:
174
+ meta = json.load(f)
175
+ chunk_num = int(meta.get("total_chunks", 0))
176
+ docs_num = int(meta.get("raw_docs", 0))
177
+ return f"📊 DB `{selected_db}`: {docs_num} docs, {chunk_num} chunks"
178
+ except Exception as e:
179
+ return f"⚠️ Failed to load DB `{selected_db}`: {e}"
180
+
181
+ with gr.Blocks(title="Qwen2.5 RAG Chat") as demo:
182
+ gr.Markdown("## 🧠 Qwen2.5 BGEM3-RAG QA")
183
+
184
+ with gr.Row():
185
+ with gr.Column():
186
+ file_box = gr.File(label="Upload documents (PDF or TXT)", file_types=[".pdf", ".txt"], file_count="multiple")
187
+ db_selector = gr.Dropdown(label="Select or create vector DB", choices=list_vector_dbs(), value="<New Vector DB>")
188
+ upload_btn = gr.Button("📚 Add to Vector DB")
189
+ status = gr.Textbox(label="Status")
190
+
191
+ with gr.Column():
192
+ query = gr.Textbox(label="Enter your question")
193
+ ask_btn = gr.Button("Send")
194
+ answer = gr.Textbox(label="🧠 Answer", lines=5)
195
+ context = gr.Textbox(
196
+ label="📄 Top-3 Reference Contexts",
197
+ lines=10,
198
+ interactive=False,
199
+ show_copy_button=True,
200
+ max_lines=20
201
+ )
202
+ db_selector.change(fn=show_db_stats, inputs=db_selector, outputs=status)
203
+ upload_btn.click(fn=create_or_extend_index, inputs=[file_box, db_selector], outputs=[status, db_selector])
204
+ ask_btn.click(fn=ask_question, inputs=query, outputs=[answer, context])
205
+ demo.load(
206
+ fn=lambda: gr.update(choices=list_vector_dbs()),
207
+ inputs=None,
208
+ outputs=db_selector
209
+ )
210
+
211
+ if __name__ == '__main__':
212
+ demo.launch(debug=True)
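
For reference, a minimal sketch of how a vector DB written by create_or_extend_index could be queried outside the Gradio app. It relies only on the index.faiss / corpus.pkl layout persisted above and on the normalize-then-inner-product trick (L2-normalised vectors make inner product equal cosine similarity); the DB id in the path is hypothetical.

```python
import os, pickle
import faiss
import numpy as np
from FlagEmbedding import BGEM3FlagModel

db_dir = "./vectorstore/bgem3/20250101-000000-abc123"  # hypothetical DB id created by mk_msg_dir
index = faiss.read_index(os.path.join(db_dir, "index.faiss"))
with open(os.path.join(db_dir, "corpus.pkl"), "rb") as f:
    corpus = pickle.load(f)

model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)
qv = np.ascontiguousarray(model.encode(["example question"])["dense_vecs"], dtype="float32")
faiss.normalize_L2(qv)                     # cosine similarity via inner product
scores, ids = index.search(qv, 3)          # top-3 chunks
for rank, (s, i) in enumerate(zip(scores[0], ids[0]), start=1):
    print(f"[{rank}] score={s:.3f} {corpus[i][:120]}")
```
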
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  short_description: Qwen2.5-0.5B-Q4 RAG demo
 
5
  colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.49.1
8
+ app_file: ggufv2.py
9
  pinned: false
10
  license: mit
11
  short_description: Qwen2.5-0.5B-Q4 RAG demo
ggufv2.py ADDED
@@ -0,0 +1,412 @@
1
+ # ggufv2.py — Qwen GGUF chat with multi-session (load/save) via utils.py
2
+ import os
3
+ import json
4
+ from pathlib import Path
5
+ from datetime import datetime
6
+ from typing import List, Dict, Optional, Tuple
7
+ import shutil
+ import stat
8
+ import gradio as gr
9
+ from llama_cpp import Llama
10
+
11
+ # Multi-session helpers from utils.py
12
+ from utils import mk_msg_dir, _as_dir, persist_messages, trim_by_tokens
13
+ # ===================== Model =====================
14
+ # You can swap to another GGUF by changing repo_id/filename.
15
+ model = Llama.from_pretrained(
16
+ repo_id="bartowski/Qwen2.5-0.5B-Instruct-GGUF",
17
+ filename="Qwen2.5-0.5B-Instruct-Q4_K_M.gguf",
18
+ )
19
+
20
+ assistant_name = "Nova"
21
+ user_name = "Marshall"
22
+ persona = f"""Your name is {assistant_name}. Address the user as "{user_name}". Use Markdown; put code in fenced blocks with a language tag.""".strip()
23
+
24
+ # Where each conversation (session) persists its messages
25
+ BASE_MSG_DIR = Path("./msgs/msgs_QwenGGUF")
26
+ BASE_MSG_DIR.mkdir(parents=True, exist_ok=True)
27
+
28
+ # ---------- Qwen chat template (no tools) ----------
29
+ # def render_qwen(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
30
+ # """
31
+ # Convert OpenAI-style messages to Qwen2.5 Instruct format:
32
+ # <|im_start|>system ... <|im_end|>
33
+ # <|im_start|>user ... <|im_end|>
34
+ # <|im_start|>assistant (generation continues here)
35
+ # """
36
+ # # System prompt
37
+ # if messages and messages[0].get("role") == "system":
38
+ # sys_txt = messages[0]["content"]
39
+ # rest = messages[1:]
40
+ # else:
41
+ # sys_txt = persona
42
+ # rest = messages
43
+
44
+ # parts = [f"<|im_start|>system\n{sys_txt}<|im_end|>\n"]
45
+ # for m in rest:
46
+ # role = m.get("role")
47
+ # if role not in ("user", "assistant"):
48
+ # continue
49
+ # parts.append(f"<|im_start|>{role}\n{m['content']}<|im_end|>\n")
50
+
51
+ # if add_generation_prompt:
52
+ # parts.append("<|im_start|>assistant\n")
53
+ # return "".join(parts)
54
+
55
+ def render_qwen_trim(
56
+ messages: List[Dict[str, str]],
57
+ model, # llama_cpp.Llama instance (used for token counting)
58
+ n_ctx: Optional[int] = None, # defaults to model.n_ctx() when omitted
59
+ add_generation_prompt: bool = True,
60
+ persona: str = "",
61
+ reserve_new: int = 256, # desired budget (upper bound) for newly generated tokens
62
+ pad: int = 8, # safety margin to avoid overrunning the context
63
+ hard_user_tail_chars: int = 2000, # hard character cap for the last user message when trimming is still not enough
64
+ ) -> Tuple[str, int]:
65
+ """
66
+ - Keep only the system message plus the most recent turns, so that total_tokens + reserve_new + pad <= n_ctx.
67
+ - If that is still not enough, truncate the last user message.
68
+ - Return (prompt, safe_max_new); safe_max_new is guaranteed to stay within the context window.
69
+ """
70
+ def _tok_len(txt: str) -> int:
71
+ # keep token counting consistent with llama_cpp
72
+ return len(model.tokenize(txt.encode("utf-8"), add_bos=True))
73
+
74
+ if n_ctx is None:
75
+ n_ctx = getattr(model, "n_ctx")() if callable(getattr(model, "n_ctx", None)) else model.n_ctx
76
+
77
+ # 1) Split out the system message from the rest
78
+ if messages and messages[0].get("role") == "system":
79
+ sys_txt = messages[0]["content"]
80
+ rest = messages[1:]
81
+ else:
82
+ sys_txt = persona
83
+ rest = messages
84
+
85
+ # keep only user / assistant messages
86
+ rest = [m for m in rest if m.get("role") in ("user", "assistant")]
87
+
88
+ # 2) Render the system message plus the kept turns into a Qwen (ChatML) prompt
89
+ def _render(sys_text: str, turns: List[Dict[str, str]], add_gen: bool) -> str:
90
+ parts = [f"<|im_start|>system\n{sys_text}<|im_end|>\n"]
91
+ for m in turns:
92
+ parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n")
93
+ if add_gen:
94
+ parts.append("<|im_start|>assistant\n")
95
+ return "".join(parts)
96
+
97
+ # 3) Start with all turns, then drop the oldest until the prompt fits
98
+ kept = rest[:] # shallow copy of the message list
99
+ while True:
100
+ prompt = _render(sys_txt, kept, add_generation_prompt)
101
+ used = _tok_len(prompt)
102
+
103
+ # how many tokens can still be generated safely
104
+ safe_max_new = max(1, n_ctx - used - pad)
105
+ # we want reserve_new tokens, but cannot exceed safe_max_new
106
+ if used + reserve_new + pad <= n_ctx:
107
+ # enough headroom: cap generation at reserve_new
108
+ return prompt, min(reserve_new, safe_max_new)
109
+
110
+ # No headroom: trim history. If nothing is left to drop, fall through to hard truncation.
111
+ if len(kept) <= 1:
112
+ break # only the last message remains; fall through to hard truncation
113
+
114
+ # Drop from the oldest message; drop two at a time (user+assistant) to avoid splitting a pair,
115
+ # but if the head is not a pair, drop just one.
116
+ drop_count = 2 if len(kept) >= 2 else 1
117
+ # always keep at least one message (the last user turn) for context
118
+ while drop_count > 0 and len(kept) > 1:
119
+ kept.pop(0)
120
+ drop_count -= 1
121
+
122
+ # 4) Still too long: hard-truncate the tail of the last user message
124
+ # goal: keep the most recent content while immediately freeing token space
124
+ if kept and kept[-1]["role"] == "user":
125
+ kept[-1] = {
126
+ "role": "user",
127
+ "content": kept[-1]["content"][-hard_user_tail_chars:]
128
+ }
129
+ elif kept:
130
+ # the last message is not a user turn (usually assistant); truncate it instead
131
+ kept[-1] = {
132
+ "role": kept[-1]["role"],
133
+ "content": kept[-1]["content"][-hard_user_tail_chars:]
134
+ }
135
+
136
+ # re-render and compute the final safe max_new
137
+ prompt = _render(sys_txt, kept, add_generation_prompt)
138
+ used = _tok_len(prompt)
139
+ safe_max_new = max(1, n_ctx - used - pad)
140
+
141
+ # if the prompt still overflows (e.g. an extremely long system prompt), truncate the system text as well
142
+ if used + pad > n_ctx:
143
+ trimmed_sys = sys_txt[-hard_user_tail_chars:]
144
+ prompt = _render(trimmed_sys, kept, add_generation_prompt)
145
+ used = _tok_len(prompt)
146
+ safe_max_new = max(1, n_ctx - used - pad)
147
+
148
+ # never return zero or a negative generation budget
149
+ return prompt, max(1, safe_max_new)
150
+
151
+
152
+ STOP_TOKENS = ["<|im_end|>", "<|endoftext|>"]
153
+
154
+ # ---------- Helpers for system + display ----------
155
+ def ensure_system(messages: Optional[List[Dict[str, str]]], sys_prompt: str) -> List[Dict[str, str]]:
156
+ """Guarantee a system message at index 0 and keep it in sync with the UI textbox."""
157
+ sys_prompt = (sys_prompt or persona).strip()
158
+ if not messages or messages[0].get("role") != "system":
159
+ return [{"role": "system", "content": sys_prompt}]
160
+ messages = list(messages)
161
+ messages[0] = {"role": "system", "content": sys_prompt}
162
+ return messages
163
+
164
+
165
+ def visible_chat(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
166
+ """Hide system from chat display for gr.Chatbot(type='messages')."""
167
+ return [m for m in (messages or []) if m.get("role") in ("user", "assistant")]
168
+
169
+
170
+ # ---------- Session I/O ----------
171
+ def _load_latest(msg_id: str) -> List[Dict[str, str]]:
172
+ p = Path(_as_dir(BASE_MSG_DIR, msg_id), "trimmed.json")
173
+ if p.exists():
174
+ try:
175
+ return json.loads(p.read_text(encoding="utf-8"))
176
+ except Exception:
177
+ return []
178
+ return []
179
+
180
+
181
+ def _init_sessions():
182
+ sessions = [p.name for p in BASE_MSG_DIR.iterdir() if p.is_dir()] if BASE_MSG_DIR.exists() else []
183
+ if len(sessions) == 0:
184
+ # No history
185
+ return gr.update(choices=[], value=None), [], "", [], []
186
+ sessions.sort(reverse=True)
187
+ msg_id = sessions[0]
188
+ messages = _load_latest(msg_id)
189
+ chat_hist = visible_chat(messages)
190
+ return gr.update(choices=sessions, value=msg_id), sessions, msg_id, messages, chat_hist
191
+
192
+
193
+ def load_session(session_list, sessions):
194
+ msg_id = session_list
195
+ messages = _load_latest(msg_id)
196
+ chat_hist = visible_chat(messages)
197
+ return msg_id, messages, chat_hist, gr.update(choices=sessions, value=msg_id)
198
+
199
+
200
+ def start_new_session(sessions):
201
+ msg_id = mk_msg_dir(BASE_MSG_DIR)
202
+ sessions = list(sessions or []) + [msg_id]
203
+ return [], [], "", msg_id, gr.update(choices=sessions, value=msg_id), sessions
204
+
205
+
206
+ def _on_rm_error(func, path, exc_info):
207
+ try:
208
+ if os.name == "nt": # Windows
209
+ os.chmod(path, stat.S_IWRITE) # clear the read-only flag
210
+ else: # Linux / macOS
211
+ mode = os.stat(path).st_mode
212
+ os.chmod(
213
+ path,
214
+ mode | stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR # grant the owner r, w, x
215
+ )
216
+ func(path) # retry the original operation, e.g. os.remove or os.rmdir
217
+ except Exception:
218
+ pass
219
+
220
+ def delete_session(msg_id, sessions):
221
+ """Delete the currently selected session directory and refresh the list."""
222
+ # Remove directory for current session
223
+ if msg_id:
224
+ try:
225
+ shutil.rmtree(_as_dir(BASE_MSG_DIR, msg_id), onerror=_on_rm_error)
226
+ except Exception:
227
+ shutil.rmtree(_as_dir(BASE_MSG_DIR, msg_id), ignore_errors=True)
228
+ # Re-scan sessions on disk
229
+ if BASE_MSG_DIR.exists():
230
+ sess = [p.name for p in BASE_MSG_DIR.iterdir() if p.is_dir()]
231
+ else:
232
+ sess = []
233
+ sess.sort(reverse=True)
234
+ if sess:
235
+ new_id = sess[0]
236
+ msgs = _load_latest(new_id)
237
+ chat_hist = visible_chat(msgs)
238
+ return msgs, chat_hist, "", new_id, gr.update(choices=sess, value=new_id), sess
239
+ else:
240
+ return [], [], "", "", gr.update(choices=[], value=None), []
241
+
242
+ def export_messages_to_json(messages, msg_id):
243
+ base = Path("/data/exports") if Path("/data").exists() else Path("./exports")
244
+ base.mkdir(parents=True, exist_ok=True)
245
+ stamp = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
246
+ fname = f"chat_{stamp}.json"
247
+ path = base / fname
248
+ path.write_text(json.dumps(messages or [], ensure_ascii=False, indent=2), encoding="utf-8")
249
+ return str(path)
250
+
251
+ def on_click_download(messages, msg_id):
252
+ path = export_messages_to_json(messages, msg_id)
253
+ return gr.update(value=path, visible=True)
254
+
255
+
256
+ # ---------- Generation callback ----------
257
+ def on_send(user_text: str,
258
+ messages: List[Dict[str, str]],
259
+ msg_id: str,
260
+ sessions: List[str],
261
+ sys_prompt: str,
262
+ temperature: float,
263
+ top_p: float,
264
+ max_new_tokens: int,
265
+ repetition_penalty: float):
266
+ user_text = (user_text or "").strip()
267
+ if not user_text:
268
+ return gr.update(), messages, visible_chat(messages), msg_id, gr.update(choices=sessions, value=(msg_id or None)), sessions
269
+
270
+ # 1) ensure system
271
+ messages = ensure_system(messages, sys_prompt)
272
+
273
+ # 2) session bookkeeping
274
+ new_session = (len(messages) <= 1) # only system exists
275
+ if new_session and not msg_id:
276
+ msg_id = mk_msg_dir(BASE_MSG_DIR)
277
+ sessions = list(sessions or []) + [msg_id]
278
+ if msg_id and msg_id not in (sessions or []):
279
+ sessions = list(sessions or []) + [msg_id]
280
+ sessions_update = gr.update(choices=sessions, value=msg_id)
281
+
282
+ # 3) append user, render, generate
283
+ messages = messages + [{"role": "user", "content": user_text}]
284
+ # prompt = render_qwen(messages, add_generation_prompt=True)
285
+ prompt, max_new = render_qwen_trim(
286
+ messages=messages,
287
+ model=model, # llama_cpp.Llama instance
288
+ n_ctx=None, # fall back to model.n_ctx()
289
+ add_generation_prompt=True,
290
+ persona=persona, # the persona string defined above
291
+ reserve_new=max_new_tokens, # desired generation length
292
+ pad=16
293
+ )
294
+
295
+
296
+ try:
297
+ result = model.create_completion(
298
+ prompt=prompt,
299
+ temperature=float(temperature),
300
+ top_p=float(top_p),
301
+ max_tokens=int(max_new),
302
+ repeat_penalty=float(repetition_penalty),
303
+ stop=STOP_TOKENS,
304
+ )
305
+ reply = result['choices'][0]['text'].strip()
306
+ except Exception:
307
+ _out = model(
308
+ prompt,
309
+ temperature=float(temperature),
310
+ top_p=float(top_p),
311
+ max_tokens=int(max_new),
312
+ repeat_penalty=float(repetition_penalty),
313
+ stop=STOP_TOKENS,
314
+ )
315
+ if isinstance(_out, dict):
316
+ reply = _out.get('choices', [{}])[0].get('text', '').strip()
317
+ else:
318
+ reply = str(_out).strip()
319
+
320
+ # 4) append assistant + persist
321
+ messages = messages + [{"role": "assistant", "content": reply}]
322
+
323
+ if msg_id:
324
+ msg_dir = _as_dir(BASE_MSG_DIR, msg_id)
325
+ persist_messages(messages, msg_dir, archive_last_turn=True)
326
+
327
+ return "", messages, visible_chat(messages), msg_id, sessions_update, sessions
328
+
329
+
330
+ # ===================== UI =====================
331
+ with gr.Blocks(title="Qwen GGUF — multi-session") as demo:
332
+ gr.Markdown("## 🧠 Qwen Chat")
333
+
334
+ with gr.Row():
335
+ with gr.Column(scale=3):
336
+ sys_prompt = gr.Textbox(
337
+ label="System prompt",
338
+ value=persona,
339
+ lines=6,
340
+ show_label=True,
341
+ )
342
+ with gr.Accordion("Generation settings", open=False):
343
+ temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
344
+ top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
345
+ max_new_tokens = gr.Slider(16, 512, value=256, step=16, label="max_new_tokens")
346
+ repetition_penalty = gr.Slider(1.0, 2.0, value=1.07, step=0.01, label="repetition_penalty")
347
+
348
+ session_list = gr.Radio(choices=[], value=None, label="Conversations", interactive=True)
349
+ new_btn = gr.Button("New session", variant="secondary")
350
+ del_btn = gr.Button("Delete session", variant="stop")
351
+ dl_btn = gr.Button("Download JSON", variant="secondary")
352
+ dl_file = gr.File(label="", interactive=False, visible=False)
353
+
354
+ with gr.Column(scale=9):
355
+ chat = gr.Chatbot(
356
+ label="Chat",
357
+ height=560,
358
+ render_markdown=True,
359
+ type="messages",
360
+ )
361
+ user_box = gr.Textbox(
362
+ label="Your message",
363
+ placeholder="Type and press Enter…",
364
+ autofocus=True,
365
+ )
366
+ send = gr.Button("Send", variant="primary")
367
+
368
+ # States
369
+ messages = gr.State([]) # includes system
370
+ msg_id = gr.State("")
371
+ sessions = gr.State([])
372
+
373
+ # Events
374
+ user_box.submit(
375
+ on_send,
376
+ inputs=[user_box, messages, msg_id, sessions, sys_prompt, temperature, top_p, max_new_tokens, repetition_penalty],
377
+ outputs=[user_box, messages, chat, msg_id, session_list, sessions],
378
+ )
379
+ send.click(
380
+ on_send,
381
+ inputs=[user_box, messages, msg_id, sessions, sys_prompt, temperature, top_p, max_new_tokens, repetition_penalty],
382
+ outputs=[user_box, messages, chat, msg_id, session_list, sessions],
383
+ )
384
+
385
+ new_btn.click(
386
+ start_new_session,
387
+ inputs=[sessions],
388
+ outputs=[messages, chat, user_box, msg_id, session_list, sessions],
389
+ )
390
+
391
+ del_btn.click(
392
+ delete_session,
393
+ inputs=[msg_id, sessions],
394
+ outputs=[messages, chat, user_box, msg_id, session_list, sessions],
395
+ )
396
+
397
+ session_list.change(
398
+ load_session,
399
+ inputs=[session_list, sessions],
400
+ outputs=[msg_id, messages, chat, session_list],
401
+ )
402
+
403
+ dl_btn.click(
404
+ on_click_download,
405
+ inputs=[messages, msg_id],
406
+ outputs=[dl_file],
407
+ )
408
+
409
+ demo.load(_init_sessions, None, outputs=[session_list, sessions, msg_id, messages, chat])
410
+
411
+ if __name__ == "__main__":
412
+ demo.queue().launch()
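
The heart of ggufv2.py is render_qwen_trim: it renders messages into Qwen's ChatML layout and keeps total_tokens + reserve_new + pad within the model's context window. A minimal, self-contained sketch of that layout and budget arithmetic follows; DummyModel is a hypothetical stand-in so the snippet runs without downloading a GGUF file, but its tokenize()/n_ctx() mirror the llama_cpp.Llama calls used above.

```python
from typing import Dict, List

class DummyModel:
    """Hypothetical stand-in for llama_cpp.Llama (token counting only)."""
    def tokenize(self, data: bytes, add_bos: bool = True) -> List[int]:
        return list(range(len(data.split())))  # crude whitespace "tokens"
    def n_ctx(self) -> int:
        return 512

def render_chatml(messages: List[Dict[str, str]]) -> str:
    # Same ChatML shape render_qwen_trim produces: role blocks, ending with an
    # open assistant block for generation.
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    return "".join(parts) + "<|im_start|>assistant\n"

model = DummyModel()
msgs = [{"role": "system", "content": "You are Nova."},
        {"role": "user", "content": "Summarise FAISS inner-product search in one line."}]
prompt = render_chatml(msgs)
used = len(model.tokenize(prompt.encode("utf-8"), add_bos=True))
pad, reserve_new = 16, 256
safe_max_new = max(1, model.n_ctx() - used - pad)   # same bound as render_qwen_trim
print(f"prompt tokens={used}, generation cap={min(reserve_new, safe_max_new)}")
```
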
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ gradio==5.49.1
2
+ huggingface_hub>=0.23
3
+ orjson
4
+ llama-cpp-python==0.2.90
utils.py ADDED
@@ -0,0 +1,151 @@
1
+ # from __future__ import annotations
2
+ from pathlib import Path
3
+ import uuid
4
+ from datetime import datetime, timezone
5
+ import json, os
6
+ from typing import List, Dict, Tuple, Optional
7
+
8
+ # ============ Utility functions ============
9
+ def mk_msg_dir(BASE_MSG_DIR) -> str:
10
+ m_id = datetime.now().strftime("%Y%m%d-%H%M%S-") + uuid.uuid4().hex[:6]
11
+ Path(BASE_MSG_DIR, m_id).mkdir(parents=True, exist_ok=True)
12
+ return m_id # return only the ID
13
+
14
+ def _as_dir(BASE_MSG_DIR, m_id: str) -> Path:
15
+ # normalize the given id into a path under BASE_MSG_DIR (e.g. ./msgs/<ID>)
16
+ return Path(BASE_MSG_DIR, m_id)
17
+
18
+ def msg2hist(persona, msg):
19
+ chat_history = []
20
+ if msg is not None:
21
+ if len(msg) > 0:
22
+ chat_history = msg.copy() # shallow-copy the outer list
23
+ chat_history[0] = msg[0].copy() # copy the first dict separately
24
+ chat_history[0]['content'] = chat_history[0]['content'][len(persona):]
25
+ return chat_history
26
+
27
+ def render(tok, messages: List[Dict[str, str]]) -> str:
28
+ """按 chat_template 渲染成最终提示词文本(不分词)。"""
29
+ return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
30
+
31
+ def _ensure_alternating(messages):
32
+ if not messages:
33
+ return
34
+ if messages[0]["role"] != "user":
35
+ raise ValueError("messages[0] 必须是 'user'(你的模板要求从 user 开始)")
36
+ for i, m in enumerate(messages):
37
+ expect_user = (i % 2 == 0)
38
+ if (m["role"] == "user") != expect_user:
39
+ raise ValueError(f"对话必须严格交替 user/assistant,在索引 {i} 处发现 {m['role']}")
40
+
41
+ def trim_by_tokens(tok, messages, prompt_budget):
42
+ """
43
+ Keep only messages[0] (the persona user message) plus a suffix starting at an odd index,
44
+ found by binary search as the longest suffix that fits. This keeps the alternation intact.
45
+ """
46
+ if not messages:
47
+ return []
48
+
49
+ # _ensure_alternating(messages)
50
+
51
+ # only the persona message exists: return it unchanged
52
+ if len(messages) == 1:
53
+ return messages
54
+
55
+ # Allowed suffix start points: odd indices (index 1, 3, 5, ... are assistant messages),
56
+ # so that appending them after index 0 (user) keeps the roles alternating.
57
+ cand_idx = [k for k in range(1, len(messages)) if k % 2 == 1]
58
+
59
+ # if nothing fits, keep only the persona message
60
+ best = [messages[0]]
61
+
62
+ # Binary search: an earlier start point keeps more messages and costs more tokens (monotonic)
63
+ lo, hi = 0, len(cand_idx) - 1
64
+ while lo <= hi:
65
+ mid = (lo + hi) // 2
66
+ k = cand_idx[mid]
67
+ candidate = [messages[0]] + messages[k:]
68
+ toks = len(tok(tok.apply_chat_template(candidate, tokenize=False),
69
+ add_special_tokens=False).input_ids)
70
+ if toks <= prompt_budget:
71
+ best = candidate # fits: try to keep more (move left)
72
+ hi = mid - 1
73
+ else:
74
+ lo = mid + 1 # does not fit: drop more old messages (move right)
75
+
76
+ return best
77
+
78
+ # ============ Atomic write (may conflict with OneDrive sync) ============
79
+ # def atomic_write_json(path: Path, data) -> None:
80
+ # tmp = path.with_suffix(path.suffix + ".tmp")
81
+ # with open(tmp, "w", encoding="utf-8") as f:
82
+ # json.dump(data, f, ensure_ascii=False, indent=2)
83
+ # f.flush()
84
+ # os.fsync(f.fileno())
85
+ # os.replace(tmp, path) # atomic replace within the same directory
86
+
87
+ # plain overwrite
88
+ def write_json_overwrite(path: Path, data) -> None:
89
+ with open(path, "w", encoding="utf-8", newline="\n") as f:
90
+ json.dump(data, f, ensure_ascii=False, indent=2)
91
+
92
+ # ============ 存储层 ============
93
+ class MsgStore:
94
+ def __init__(self, base_dir: str | Path = "./msgs"):
95
+ self.base = Path(base_dir)
96
+ self.base.mkdir(parents=True, exist_ok=True)
97
+ self.archive = self.base / "archive.jsonl" # append-only log
98
+ self.trimmed = self.base / "trimmed.json" # current context
99
+ if not self.archive.exists():
100
+ self.archive.write_text("", encoding="utf-8")
101
+ if not self.trimmed.exists():
102
+ self.trimmed.write_text("[]", encoding="utf-8")
103
+
104
+ def load_trimmed(self) -> List[Dict[str, str]]:
105
+ try:
106
+ return json.loads(self.trimmed.read_text(encoding="utf-8"))
107
+ except Exception:
108
+ return []
109
+
110
+ def save_trimmed(self, messages: List[Dict[str, str]]) -> None:
111
+ write_json_overwrite(self.trimmed, messages)
112
+
113
+ def append_archive(self, role: str, content: str, meta: dict | None = None) -> None:
114
+ rec = {"ts": datetime.now(timezone.utc).isoformat(), "role": role, "content": content}
115
+ if meta: rec["meta"] = meta
116
+ with open(self.archive, "a", encoding="utf-8") as f:
117
+ f.write(json.dumps(rec, ensure_ascii=False) + "\n")
118
+ f.flush(); os.fsync(f.fileno())
119
+
120
+ # ============ Explicit save (only persists when called manually) ============
121
+ def persist_messages(
122
+ messages: List[Dict[str, str]],
123
+ store_dir: str | Path = "./msgs",
124
+ archive_last_turn: bool = True,
125
+ ) -> None:
126
+ store = MsgStore(store_dir)
127
+ # _ensure_alternating(messages)
128
+
129
+ # 1) Overwrite trimmed.json (plain overwrite; the atomic variant above is commented out)
130
+ store.save_trimmed(messages)
131
+
132
+ # 2) Append the most recent turn to archive.jsonl (optional)
133
+ if not archive_last_turn:
134
+ return
135
+
136
+ # scan backwards for the most recent (user, assistant) pair
137
+ pair = None
138
+ for i in range(len(messages) - 2, -1, -1):
139
+ if (
140
+ messages[i]["role"] == "user"
141
+ and i + 1 < len(messages)
142
+ and messages[i + 1]["role"] == "assistant"
143
+ ):
144
+ pair = (messages[i]["content"], messages[i + 1]["content"])
145
+ break
146
+
147
+ if pair:
148
+ u, a = pair
149
+ store.append_archive("user", u)
150
+ store.append_archive("assistant", a)
151
+ # If no pair is found (e.g. persist was called before generation), only trimmed.json is written and nothing is archived.
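
Session persistence above amounts to one directory per conversation holding trimmed.json (the current context) and archive.jsonl (an append-only turn log). A minimal usage sketch, with an illustrative base directory:

```python
from pathlib import Path
from utils import MsgStore, _as_dir, mk_msg_dir, persist_messages

BASE = Path("./msgs/demo")                     # illustrative base dir
msg_id = mk_msg_dir(BASE)                      # creates ./msgs/demo/<timestamp-uuid>/
msgs = [
    {"role": "system", "content": "You are Nova."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi Marshall, how can I help?"},
]
# Overwrites trimmed.json and appends the latest (user, assistant) pair to archive.jsonl
persist_messages(msgs, _as_dir(BASE, msg_id))

store = MsgStore(_as_dir(BASE, msg_id))
print(store.load_trimmed())                    # round-trips the saved context
```
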