MarshallCN committed
Commit: bd5ce6f
Parent(s): b943737

init

Files changed:
- .gitignore         +15 -0
- Chat_RAG_vecDB.py   +212 -0
- README.md           +1 -1
- ggufv2.py           +412 -0
- requirements.txt    +4 -0
- utils.py            +151 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
# Jupyter
**/.ipynb_checkpoints/
# any hidden Jupyter aux files like .ipynb_foo
.ipynb_*

# Python cache/bytecode
**/__pycache__/
*.py[cod]
*$py.class
/old/
/old/*
models/
models/*
export/
msgs/
msgs/*
Chat_RAG_vecDB.py
ADDED
@@ -0,0 +1,212 @@
import gradio as gr
import stat
import os, shutil, pickle, torch, json, hashlib
import faiss, numpy as np
from pathlib import Path
from FlagEmbedding import BGEM3FlagModel
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from utils import mk_msg_dir

# === Model loading ===
if gr.NO_RELOAD:
    BASE_DIR = r"C:\Users\c1052689\hug_models\Qwen2.5-0.5B-Instruct"
    tok = AutoTokenizer.from_pretrained(BASE_DIR, use_fast=False, local_files_only=True)
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(BASE_DIR, quantization_config=bnb, device_map="auto", local_files_only=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tok, max_new_tokens=512)
    BGEM3 = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
    reranker = CrossEncoder("BAAI/bge-reranker-large")

# === Vector-store globals ===
corpus = []
index = None
current_db_dir = None

vec_dir_base = './vectorstore/bgem3/'
embedding_model_id = 'BAAI/bge-m3'

# === Document loading & index building ===
def load_documents(folder: str):
    docs = []
    for path in Path(folder).rglob("*"):
        if path.suffix == ".txt":
            docs += TextLoader(str(path), encoding="utf-8").load()
        elif path.suffix == ".pdf":
            docs += PyPDFLoader(str(path)).load()
    return docs

def split_docs(docs, chunk_size=512, chunk_overlap=64):
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(docs)

def on_rm_error(func, path, exc_info):
    os.chmod(path, stat.S_IWRITE)
    func(path)

def hash_text(text):
    return hashlib.md5(text.encode('utf-8')).hexdigest()

def list_vector_dbs():
    # Ensure the base dir exists on first run so iterdir() does not fail
    Path(vec_dir_base).mkdir(parents=True, exist_ok=True)
    db_list = [f.name for f in Path(vec_dir_base).iterdir() if f.is_dir()]
    return ["<New Vector DB>"] + db_list

def create_or_extend_index(docs, selected_db):
    global corpus, index, current_db_dir

    temp_dir = "temp_docs"
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir, onerror=on_rm_error)
    os.makedirs(temp_dir, exist_ok=True)

    for file in docs:
        src_path = file.name if hasattr(file, "name") else str(file)
        dst_path = os.path.join(temp_dir, os.path.basename(src_path))
        shutil.copy(src_path, dst_path)

    raw_docs = load_documents(temp_dir)
    chunks = split_docs(raw_docs)
    new_corpus = [t.page_content for t in chunks]
    new_hashes = [hash_text(t) for t in new_corpus]

    if selected_db == "<New Vector DB>":
        db_id = mk_msg_dir(Path(vec_dir_base))
        current_db_dir = os.path.join(vec_dir_base, db_id)
        os.makedirs(current_db_dir, exist_ok=True)
        index = faiss.IndexFlatIP(BGEM3.encode(["test"])["dense_vecs"].shape[1])
        corpus = []
        existing_hashes = set()
    else:
        current_db_dir = os.path.join(vec_dir_base, selected_db)
        index = faiss.read_index(os.path.join(current_db_dir, "index.faiss"))
        with open(os.path.join(current_db_dir, "corpus.pkl"), "rb") as f:
            corpus = pickle.load(f)
        with open(os.path.join(current_db_dir, "meta.json"), "r", encoding="utf-8") as f:
            meta = json.load(f)
        existing_hashes = set(meta.get("hashes", []))

    # Deduplicate against chunks already in the DB
    filtered = [(c, h) for c, h in zip(new_corpus, new_hashes) if h not in existing_hashes]
    if not filtered:
        # Must return two values: the click handler is wired to [status, db_selector]
        return "✅ No new (non-duplicate) chunks to add.", gr.update(choices=list_vector_dbs())

    add_corpus, add_hashes = zip(*filtered)
    dense = BGEM3.encode(add_corpus, batch_size=64)["dense_vecs"]
    if isinstance(dense, torch.Tensor):
        dense = dense.detach().cpu().numpy()
    dense = np.ascontiguousarray(dense, dtype=np.float32)
    faiss.normalize_L2(dense)

    index.add(dense)
    corpus.extend(add_corpus)
    all_hashes = list(existing_hashes) + list(add_hashes)

    faiss.write_index(index, os.path.join(current_db_dir, "index.faiss"))
    with open(os.path.join(current_db_dir, "corpus.pkl"), "wb") as f:
        pickle.dump(corpus, f)
    meta = {
        "model": embedding_model_id,
        "dim": int(dense.shape[1]),
        "total_chunks": len(corpus),
        "raw_docs": len(raw_docs),
        "hashes": all_hashes,
    }
    with open(os.path.join(current_db_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    db_stats = f"✅ Added {len(add_corpus)} new chunks to DB `{os.path.basename(current_db_dir)}`."
    db_list_update = gr.update(choices=list_vector_dbs())
    return db_stats, db_list_update

def build_prompt_corpus(top_docs, question):
    context_text = "\n\n".join(top_docs)
    user_prompt = f"""Answer the question based on the following context:

{context_text}

Question: {question}
Answer:"""

    full_prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return full_prompt

def ask_question(query):
    # The click handler is wired to two outputs, so each branch returns two values
    if not query.strip():
        return "❌ Please enter your question", ""
    if index is None or len(corpus) == 0:
        return "⚠️ Please upload your documents first", ""

    qv = np.array(BGEM3.encode([query])["dense_vecs"], dtype="float32")
    faiss.normalize_L2(qv)
    D, I = index.search(qv, 8)
    results = [corpus[i] for i in I[0]]
    pairs = [[query, c] for c in results]
    scores = reranker.predict(pairs)
    top_docs = [c for _, c in sorted(zip(scores, results), reverse=True)][:3]

    prompt = build_prompt_corpus(top_docs, query)
    out = pipe(
        prompt,
        max_new_tokens=1024,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,
        return_full_text=False,
    )
    reply = out[0]["generated_text"]
    context_display = "\n\n".join(
        f"[{i+1}] {doc.strip()[:1000]}" for i, doc in enumerate(top_docs)
    )
    return reply.strip(), context_display

def show_db_stats(selected_db):
    if selected_db == "<New Vector DB>":
        return "🆕 New vector DB will be created on next upload."
    try:
        db_dir = os.path.join(vec_dir_base, selected_db)
        with open(os.path.join(db_dir, "meta.json"), "r", encoding="utf-8") as f:
            meta = json.load(f)
        chunk_num = int(meta.get("total_chunks", 0))
        docs_num = int(meta.get("raw_docs", 0))
        return f"📊 DB `{selected_db}`: {docs_num} docs, {chunk_num} chunks"
    except Exception as e:
        return f"⚠️ Failed to load DB `{selected_db}`: {e}"

with gr.Blocks(title="Qwen2.5 RAG Chat") as demo:
    gr.Markdown("## 🧠 Qwen2.5 BGEM3-RAG QA")

    with gr.Row():
        with gr.Column():
            file_box = gr.File(label="Upload documents (PDF or TXT)", file_types=[".pdf", ".txt"], file_count="multiple")
            db_selector = gr.Dropdown(label="Select or create vector DB", choices=list_vector_dbs(), value="<New Vector DB>")
            upload_btn = gr.Button("📚 Add to Vector DB")
            status = gr.Textbox(label="Status")

        with gr.Column():
            query = gr.Textbox(label="Enter your question")
            ask_btn = gr.Button("Send")
            answer = gr.Textbox(label="🧠 Answer", lines=5)
            context = gr.Textbox(
                label="📄 Top-3 Reference Contexts",
                lines=10,
                interactive=False,
                show_copy_button=True,
                max_lines=20
            )
    db_selector.change(fn=show_db_stats, inputs=db_selector, outputs=status)
    upload_btn.click(fn=create_or_extend_index, inputs=[file_box, db_selector], outputs=[status, db_selector])
    ask_btn.click(fn=ask_question, inputs=query, outputs=[answer, context])
    demo.load(
        fn=lambda: gr.update(choices=list_vector_dbs()),
        inputs=None,
        outputs=db_selector
    )

if __name__ == '__main__':
    demo.launch(debug=True)
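The retrieval core of Chat_RAG_vecDB.py is cosine similarity implemented as L2-normalised vectors in a faiss.IndexFlatIP, with top-8 candidates later reranked by the cross-encoder. A minimal sketch of that pattern, using random vectors as stand-ins for the BGE-M3 embeddings (the dimension and data below are illustrative only, not part of the app):

import faiss
import numpy as np

dim = 1024                                  # BGE-M3 dense vectors are 1024-d; any dim works for the sketch
rng = np.random.default_rng(0)

# Stand-in "document embeddings"; in the app these come from BGEM3.encode(corpus)["dense_vecs"]
doc_vecs = np.ascontiguousarray(rng.normal(size=(100, dim)), dtype=np.float32)
faiss.normalize_L2(doc_vecs)                # normalise so inner product equals cosine similarity

index = faiss.IndexFlatIP(dim)              # exact inner-product index, as in create_or_extend_index
index.add(doc_vecs)

# Stand-in "query embedding"
query_vec = np.ascontiguousarray(rng.normal(size=(1, dim)), dtype=np.float32)
faiss.normalize_L2(query_vec)

scores, ids = index.search(query_vec, 8)    # top-8 candidates, as in ask_question
print(ids[0], scores[0])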
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
 colorTo: blue
 sdk: gradio
 sdk_version: 5.49.1
-app_file:
+app_file: ggufv2.py
 pinned: false
 license: mit
 short_description: Qwen2.5-0.5B-Q4 RAG demo
ggufv2.py
ADDED
@@ -0,0 +1,412 @@
# ggufv2.py — Qwen GGUF chat with multi-session (load/save) via utils.py
import os
import stat
import json
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Tuple
import shutil
import gradio as gr
from llama_cpp import Llama

# Multi-session helpers from utils.py
from utils import mk_msg_dir, _as_dir, persist_messages, trim_by_tokens
# ===================== Model =====================
# You can swap to another GGUF by changing repo_id/filename.
model = Llama.from_pretrained(
    repo_id="bartowski/Qwen2.5-0.5B-Instruct-GGUF",
    filename="Qwen2.5-0.5B-Instruct-Q4_K_M.gguf",
)

assistant_name = "Nova"
user_name = "Marshall"
persona = f"""Your name is {assistant_name}. Address the user as "{user_name}". Use Markdown; put code in fenced blocks with a language tag.""".strip()

# Where each conversation (session) persists its messages
BASE_MSG_DIR = Path("./msgs/msgs_QwenGGUF")
BASE_MSG_DIR.mkdir(parents=True, exist_ok=True)

# ---------- Qwen chat template (no tools) ----------
# def render_qwen(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
#     """
#     Convert OpenAI-style messages to Qwen2.5 Instruct format:
#     <|im_start|>system ... <|im_end|>
#     <|im_start|>user ... <|im_end|>
#     <|im_start|>assistant (generation continues here)
#     """
#     # System prompt
#     if messages and messages[0].get("role") == "system":
#         sys_txt = messages[0]["content"]
#         rest = messages[1:]
#     else:
#         sys_txt = persona
#         rest = messages
#
#     parts = [f"<|im_start|>system\n{sys_txt}<|im_end|>\n"]
#     for m in rest:
#         role = m.get("role")
#         if role not in ("user", "assistant"):
#             continue
#         parts.append(f"<|im_start|>{role}\n{m['content']}<|im_end|>\n")
#
#     if add_generation_prompt:
#         parts.append("<|im_start|>assistant\n")
#     return "".join(parts)

def render_qwen_trim(
    messages: List[Dict[str, str]],
    model,                                 # llama_cpp.Llama instance (used for token counting)
    n_ctx: Optional[int] = None,           # defaults to model.n_ctx()
    add_generation_prompt: bool = True,
    persona: str = "",
    reserve_new: int = 256,                # budget (upper bound) for newly generated tokens
    pad: int = 8,                          # safety margin to avoid overflowing the context
    hard_user_tail_chars: int = 2000,      # if still too long, hard-truncate the last user message to this many chars
) -> Tuple[str, int]:
    """
    - Keep system + only the most recent turns so that total_tokens + reserve_new + pad <= n_ctx.
    - If that is still not enough, truncate the last user message.
    - Returns (prompt, safe_max_new); safe_max_new is guaranteed not to overflow the context.
    """
    def _tok_len(txt: str) -> int:
        # Keep token counting consistent with llama_cpp
        return len(model.tokenize(txt.encode("utf-8"), add_bos=True))

    if n_ctx is None:
        n_ctx = getattr(model, "n_ctx")() if callable(getattr(model, "n_ctx", None)) else model.n_ctx

    # 1) Split off the system message from the rest
    if messages and messages[0].get("role") == "system":
        sys_txt = messages[0]["content"]
        rest = messages[1:]
    else:
        sys_txt = persona
        rest = messages

    # Keep only user / assistant turns
    rest = [m for m in rest if m.get("role") in ("user", "assistant")]

    # 2) Render system + a list of turns into a Qwen prompt
    def _render(sys_text: str, turns: List[Dict[str, str]], add_gen: bool) -> str:
        parts = [f"<|im_start|>system\n{sys_text}<|im_end|>\n"]
        for m in turns:
            parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n")
        if add_gen:
            parts.append("<|im_start|>assistant\n")
        return "".join(parts)

    # 3) Try to keep all turns, dropping from the oldest until the prompt fits
    kept = rest[:]  # shallow copy
    while True:
        prompt = _render(sys_txt, kept, add_generation_prompt)
        used = _tok_len(prompt)

        # How many tokens can still be generated safely
        safe_max_new = max(1, n_ctx - used - pad)
        # We want reserve_new tokens, but never more than safe_max_new
        if used + reserve_new + pad <= n_ctx:
            # Enough headroom: return the requested generation budget
            return prompt, min(reserve_new, safe_max_new)

        # No headroom — trim history. If only one message remains, fall through to hard truncation.
        if len(kept) <= 1:
            break  # only the last message is left; hard-truncate it below

        # Drop from the oldest; drop two at a time (user + assistant) to keep pairs intact,
        # otherwise drop a single message.
        drop_count = 2 if len(kept) >= 2 else 1
        # Always keep at least one message (the last user turn) for context
        while drop_count > 0 and len(kept) > 1:
            kept.pop(0)
            drop_count -= 1

    # 4) Still too long: hard-truncate the tail of the last user message.
    #    Goal: keep the most recent semantics while immediately freeing token space.
    if kept and kept[-1]["role"] == "user":
        kept[-1] = {
            "role": "user",
            "content": kept[-1]["content"][-hard_user_tail_chars:]
        }
    elif kept:
        # Last message is not a user turn (usually assistant); truncate it instead
        kept[-1] = {
            "role": kept[-1]["role"],
            "content": kept[-1]["content"][-hard_user_tail_chars:]
        }

    # Re-render and compute the final safe generation budget
    prompt = _render(sys_txt, kept, add_generation_prompt)
    used = _tok_len(prompt)
    safe_max_new = max(1, n_ctx - used - pad)

    # If it still overflows (extremely long system prompt), truncate the system text too
    if used + pad > n_ctx:
        trimmed_sys = sys_txt[-hard_user_tail_chars:]
        prompt = _render(trimmed_sys, kept, add_generation_prompt)
        used = _tok_len(prompt)
        safe_max_new = max(1, n_ctx - used - pad)

    # Never return zero or a negative budget
    return prompt, max(1, safe_max_new)


STOP_TOKENS = ["<|im_end|>", "<|endoftext|>"]

# ---------- Helpers for system + display ----------
def ensure_system(messages: Optional[List[Dict[str, str]]], sys_prompt: str) -> List[Dict[str, str]]:
    """Guarantee a system message at index 0 and keep it in sync with the UI textbox."""
    sys_prompt = (sys_prompt or persona).strip()
    if not messages or messages[0].get("role") != "system":
        return [{"role": "system", "content": sys_prompt}]
    messages = list(messages)
    messages[0] = {"role": "system", "content": sys_prompt}
    return messages


def visible_chat(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Hide system from chat display for gr.Chatbot(type='messages')."""
    return [m for m in (messages or []) if m.get("role") in ("user", "assistant")]


# ---------- Session I/O ----------
def _load_latest(msg_id: str) -> List[Dict[str, str]]:
    p = Path(_as_dir(BASE_MSG_DIR, msg_id), "trimmed.json")
    if p.exists():
        try:
            return json.loads(p.read_text(encoding="utf-8"))
        except Exception:
            return []
    return []


def _init_sessions():
    sessions = [p.name for p in BASE_MSG_DIR.iterdir() if p.is_dir()] if BASE_MSG_DIR.exists() else []
    if len(sessions) == 0:
        # No history
        return gr.update(choices=[], value=None), [], "", [], []
    sessions.sort(reverse=True)
    msg_id = sessions[0]
    messages = _load_latest(msg_id)
    chat_hist = visible_chat(messages)
    return gr.update(choices=sessions, value=msg_id), sessions, msg_id, messages, chat_hist


def load_session(session_list, sessions):
    msg_id = session_list
    messages = _load_latest(msg_id)
    chat_hist = visible_chat(messages)
    return msg_id, messages, chat_hist, gr.update(choices=sessions, value=msg_id)


def start_new_session(sessions):
    msg_id = mk_msg_dir(BASE_MSG_DIR)
    sessions = list(sessions or []) + [msg_id]
    return [], [], "", msg_id, gr.update(choices=sessions, value=msg_id), sessions


def _on_rm_error(func, path, exc_info):
    try:
        if os.name == "nt":  # Windows
            os.chmod(path, stat.S_IWRITE)  # clear the read-only flag
        else:  # Linux / macOS
            mode = os.stat(path).st_mode
            os.chmod(
                path,
                mode | stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR  # give the owner r, w, x
            )
        func(path)  # retry the original operation, e.g. os.remove or os.rmdir
    except Exception:
        pass

def delete_session(msg_id, sessions):
    """Delete the currently selected session directory and refresh the list."""
    # Remove directory for current session
    if msg_id:
        try:
            shutil.rmtree(_as_dir(BASE_MSG_DIR, msg_id), onerror=_on_rm_error)
        except Exception:
            shutil.rmtree(_as_dir(BASE_MSG_DIR, msg_id), ignore_errors=True)
    # Re-scan sessions on disk
    if BASE_MSG_DIR.exists():
        sess = [p.name for p in BASE_MSG_DIR.iterdir() if p.is_dir()]
    else:
        sess = []
    sess.sort(reverse=True)
    if sess:
        new_id = sess[0]
        msgs = _load_latest(new_id)
        chat_hist = visible_chat(msgs)
        return msgs, chat_hist, "", new_id, gr.update(choices=sess, value=new_id), sess
    else:
        return [], [], "", "", gr.update(choices=[], value=None), []

def export_messages_to_json(messages, msg_id):
    base = Path("/data/exports") if Path("/data").exists() else Path("./exports")
    base.mkdir(parents=True, exist_ok=True)
    stamp = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    fname = f"chat_{stamp}.json"
    path = base / fname
    path.write_text(json.dumps(messages or [], ensure_ascii=False, indent=2), encoding="utf-8")
    return str(path)

def on_click_download(messages, msg_id):
    path = export_messages_to_json(messages, msg_id)
    return gr.update(value=path, visible=True)


# ---------- Generation callback ----------
def on_send(user_text: str,
            messages: List[Dict[str, str]],
            msg_id: str,
            sessions: List[str],
            sys_prompt: str,
            temperature: float,
            top_p: float,
            max_new_tokens: int,
            repetition_penalty: float):
    user_text = (user_text or "").strip()
    if not user_text:
        return gr.update(), messages, visible_chat(messages), msg_id, gr.update(choices=sessions, value=(msg_id or None)), sessions

    # 1) ensure system
    messages = ensure_system(messages, sys_prompt)

    # 2) session bookkeeping
    new_session = (len(messages) <= 1)  # only system exists
    if new_session and not msg_id:
        msg_id = mk_msg_dir(BASE_MSG_DIR)
        sessions = list(sessions or []) + [msg_id]
    if msg_id and msg_id not in (sessions or []):
        sessions = list(sessions or []) + [msg_id]
    sessions_update = gr.update(choices=sessions, value=msg_id)

    # 3) append user, render, generate
    messages = messages + [{"role": "user", "content": user_text}]
    # prompt = render_qwen(messages, add_generation_prompt=True)
    prompt, max_new = render_qwen_trim(
        messages=messages,
        model=model,                   # llama_cpp.Llama instance
        n_ctx=None,                    # None -> use model.n_ctx()
        add_generation_prompt=True,
        persona=persona,               # fallback system text defined above
        reserve_new=max_new_tokens,    # desired generation length
        pad=16
    )

    try:
        result = model.create_completion(
            prompt=prompt,
            temperature=float(temperature),
            top_p=float(top_p),
            max_tokens=int(max_new),
            repeat_penalty=float(repetition_penalty),
            stop=STOP_TOKENS,
        )
        reply = result['choices'][0]['text'].strip()
    except Exception:
        _out = model(
            prompt,
            temperature=float(temperature),
            top_p=float(top_p),
            max_tokens=int(max_new),
            repeat_penalty=float(repetition_penalty),
            stop=STOP_TOKENS,
        )
        if isinstance(_out, dict):
            reply = _out.get('choices', [{}])[0].get('text', '').strip()
        else:
            reply = str(_out).strip()

    # 4) append assistant + persist
    messages = messages + [{"role": "assistant", "content": reply}]

    if msg_id:
        msg_dir = _as_dir(BASE_MSG_DIR, msg_id)
        persist_messages(messages, msg_dir, archive_last_turn=True)

    return "", messages, visible_chat(messages), msg_id, sessions_update, sessions


# ===================== UI =====================
with gr.Blocks(title="Qwen GGUF — multi-session") as demo:
    gr.Markdown("## 🧠 Qwen Chat")

    with gr.Row():
        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="System prompt",
                value=persona,
                lines=6,
                show_label=True,
            )
            with gr.Accordion("Generation settings", open=False):
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
                max_new_tokens = gr.Slider(16, 512, value=256, step=16, label="max_new_tokens")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.07, step=0.01, label="repetition_penalty")

            session_list = gr.Radio(choices=[], value=None, label="Conversations", interactive=True)
            new_btn = gr.Button("New session", variant="secondary")
            del_btn = gr.Button("Delete session", variant="stop")
            dl_btn = gr.Button("Download JSON", variant="secondary")
            dl_file = gr.File(label="", interactive=False, visible=False)

        with gr.Column(scale=9):
            chat = gr.Chatbot(
                label="Chat",
                height=560,
                render_markdown=True,
                type="messages",
            )
            user_box = gr.Textbox(
                label="Your message",
                placeholder="Type and press Enter…",
                autofocus=True,
            )
            send = gr.Button("Send", variant="primary")

    # States
    messages = gr.State([])  # includes system
    msg_id = gr.State("")
    sessions = gr.State([])

    # Events
    user_box.submit(
        on_send,
        inputs=[user_box, messages, msg_id, sessions, sys_prompt, temperature, top_p, max_new_tokens, repetition_penalty],
        outputs=[user_box, messages, chat, msg_id, session_list, sessions],
    )
    send.click(
        on_send,
        inputs=[user_box, messages, msg_id, sessions, sys_prompt, temperature, top_p, max_new_tokens, repetition_penalty],
        outputs=[user_box, messages, chat, msg_id, session_list, sessions],
    )

    new_btn.click(
        start_new_session,
        inputs=[sessions],
        outputs=[messages, chat, user_box, msg_id, session_list, sessions],
    )

    del_btn.click(
        delete_session,
        inputs=[msg_id, sessions],
        outputs=[messages, chat, user_box, msg_id, session_list, sessions],
    )

    session_list.change(
        load_session,
        inputs=[session_list, sessions],
        outputs=[msg_id, messages, chat, session_list],
    )

    dl_btn.click(
        on_click_download,
        inputs=[messages, msg_id],
        outputs=[dl_file],
    )

    demo.load(_init_sessions, None, outputs=[session_list, sessions, msg_id, messages, chat])

if __name__ == "__main__":
    demo.queue().launch()
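render_qwen_trim is the part of ggufv2.py that keeps the rendered prompt inside the model's context window. A minimal sketch of how its trimming behaves, assuming the function has been copied into a scratch script (importing ggufv2 directly would trigger the module-level Llama.from_pretrained download); the FakeLlama stub below is purely illustrative and only provides the two methods the function relies on:

class FakeLlama:
    # Illustrative stand-in for llama_cpp.Llama: render_qwen_trim only needs tokenize() and n_ctx()
    def tokenize(self, data: bytes, add_bos: bool = True):
        return data.split()          # crude "tokens": whitespace-split bytes

    def n_ctx(self):
        return 128                   # deliberately tiny context to force trimming

messages = [{"role": "system", "content": "You are terse."}]
for i in range(20):
    messages.append({"role": "user", "content": f"question {i} " * 10})
    messages.append({"role": "assistant", "content": f"answer {i} " * 10})

# Old turns are dropped pairwise until the prompt plus the generation budget fits
prompt, max_new = render_qwen_trim(messages, FakeLlama(), reserve_new=32)
print(len(prompt.split()), max_new)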
requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio==5.49.1
huggingface_hub>=0.23
orjson
llama-cpp-python==0.2.90
utils.py
ADDED
@@ -0,0 +1,151 @@
# from __future__ import annotations
from pathlib import Path
import uuid
from datetime import datetime, timezone
import json, os
from typing import List, Dict, Tuple, Optional

# ============ Helper functions ============
def mk_msg_dir(BASE_MSG_DIR) -> str:
    m_id = datetime.now().strftime("%Y%m%d-%H%M%S-") + uuid.uuid4().hex[:6]
    Path(BASE_MSG_DIR, m_id).mkdir(parents=True, exist_ok=True)
    return m_id  # return only the ID

def _as_dir(BASE_MSG_DIR, m_id: str) -> Path:
    # Normalise the given ID into a path under BASE_MSG_DIR, e.g. ./msgs/<ID>
    return Path(BASE_MSG_DIR, m_id)

def msg2hist(persona, msg):
    chat_history = []
    if msg is not None:
        if len(msg) > 0:
            chat_history = msg.copy()            # shallow-copy the outer list
            chat_history[0] = msg[0].copy()      # copy this dict separately
            chat_history[0]['content'] = chat_history[0]['content'][len(persona):]
    return chat_history

def render(tok, messages: List[Dict[str, str]]) -> str:
    """Render messages into the final prompt text via the chat_template (no tokenization)."""
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def _ensure_alternating(messages):
    if not messages:
        return
    if messages[0]["role"] != "user":
        raise ValueError("messages[0] must be 'user' (this template requires the conversation to start with a user turn)")
    for i, m in enumerate(messages):
        expect_user = (i % 2 == 0)
        if (m["role"] == "user") != expect_user:
            raise ValueError(f"Conversation must strictly alternate user/assistant; found {m['role']} at index {i}")

def trim_by_tokens(tok, messages, prompt_budget):
    """
    Keep messages[0] (the persona user message) plus a suffix that starts at an odd index,
    using binary search to find the longest suffix that fits the budget.
    This guarantees the user/assistant alternation is preserved.
    """
    if not messages:
        return []

    # _ensure_alternating(messages)

    # Only the persona message: return it directly
    if len(messages) == 1:
        return messages

    # Allowed suffix start points: odd indices (1, 3, 5, ... are assistant turns),
    # so that concatenating after index 0 (user) keeps the alternation.
    cand_idx = [k for k in range(1, len(messages)) if k % 2 == 1]

    # If nothing fits, keep only the persona
    best = [messages[0]]

    # Binary search: an earlier start point -> more messages kept -> more tokens (monotone)
    lo, hi = 0, len(cand_idx) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        k = cand_idx[mid]
        candidate = [messages[0]] + messages[k:]
        toks = len(tok(tok.apply_chat_template(candidate, tokenize=False),
                       add_special_tokens=False).input_ids)
        if toks <= prompt_budget:
            best = candidate   # fits: try to keep more (move left)
            hi = mid - 1
        else:
            lo = mid + 1       # does not fit: drop more old messages (move right)

    return best

# ============ Atomic write (may conflict with OneDrive sync) ============
# def atomic_write_json(path: Path, data) -> None:
#     tmp = path.with_suffix(path.suffix + ".tmp")
#     with open(tmp, "w", encoding="utf-8") as f:
#         json.dump(data, f, ensure_ascii=False, indent=2)
#         f.flush()
#         os.fsync(f.fileno())
#     os.replace(tmp, path)  # atomic replace within the same directory

# Plain overwrite
def write_json_overwrite(path: Path, data) -> None:
    with open(path, "w", encoding="utf-8", newline="\n") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

# ============ Storage layer ============
class MsgStore:
    def __init__(self, base_dir: str | Path = "./msgs"):
        self.base = Path(base_dir)
        self.base.mkdir(parents=True, exist_ok=True)
        self.archive = self.base / "archive.jsonl"   # append-only
        self.trimmed = self.base / "trimmed.json"    # current context
        if not self.archive.exists():
            self.archive.write_text("", encoding="utf-8")
        if not self.trimmed.exists():
            self.trimmed.write_text("[]", encoding="utf-8")

    def load_trimmed(self) -> List[Dict[str, str]]:
        try:
            return json.loads(self.trimmed.read_text(encoding="utf-8"))
        except Exception:
            return []

    def save_trimmed(self, messages: List[Dict[str, str]]) -> None:
        write_json_overwrite(self.trimmed, messages)

    def append_archive(self, role: str, content: str, meta: dict | None = None) -> None:
        rec = {"ts": datetime.now(timezone.utc).isoformat(), "role": role, "content": content}
        if meta: rec["meta"] = meta
        with open(self.archive, "a", encoding="utf-8") as f:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            f.flush(); os.fsync(f.fileno())

# ============ Explicit save (written to disk only when called) ============
def persist_messages(
    messages: List[Dict[str, str]],
    store_dir: str | Path = "./msgs",
    archive_last_turn: bool = True,
) -> None:
    store = MsgStore(store_dir)
    # _ensure_alternating(messages)

    # 1) Overwrite trimmed.json
    store.save_trimmed(messages)

    # 2) Optionally append the latest turn to archive.jsonl
    if not archive_last_turn:
        return

    # Scan backwards for the most recent (user, assistant) pair
    pair = None
    for i in range(len(messages) - 2, -1, -1):
        if (
            messages[i]["role"] == "user"
            and i + 1 < len(messages)
            and messages[i + 1]["role"] == "assistant"
        ):
            pair = (messages[i]["content"], messages[i + 1]["content"])
            break

    if pair:
        u, a = pair
        store.append_archive("user", u)
        store.append_archive("assistant", a)
    # If no pair is found (e.g. persist was called before generation), only trimmed.json is written and nothing is archived
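utils.py keeps two files per session: an append-only archive.jsonl and a trimmed.json that holds the current context; persist_messages writes both. A minimal usage sketch, assuming utils.py is on the import path (the ./msgs/demo path below is illustrative only):

from pathlib import Path
from utils import mk_msg_dir, _as_dir, persist_messages, MsgStore

base = Path("./msgs/demo")
msg_id = mk_msg_dir(base)                    # creates ./msgs/demo/<timestamp>-<hex>/
msg_dir = _as_dir(base, msg_id)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi Marshall!"},
]

# Writes trimmed.json (full current context) and appends the last
# user/assistant pair to archive.jsonl
persist_messages(messages, msg_dir, archive_last_turn=True)

# Reload what ggufv2.py's _load_latest would read back for this session
print(MsgStore(msg_dir).load_trimmed())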