Tuchuanhuhuhu committed · commit d708c00 · 1 parent: 86018c8

Support ChatGLM (支持ChatGLM)
Browse files:
- .gitignore +3 -0
- modules/models.py +53 -0
- modules/presets.py +3 -0
- requirements.txt +2 -0
.gitignore
CHANGED
@@ -133,7 +133,10 @@ dmypy.json
 # Mac system file
 **/.DS_Store
 
+# config files / model files
 api_key.txt
 config.json
 auth.json
+models/
+lora/
 .idea
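The newly ignored `models/` and `lora/` directories let locally downloaded checkpoints and LoRA weights sit inside the working tree without being committed. A minimal sketch of the lookup this enables, mirroring the `models/{model_name}` convention used in modules/models.py below (the model name is just an example):

```python
import os

# The loader in modules/models.py prefers "models/<model_name>" when it
# exists and otherwise falls back to the "THUDM/<model_name>" Hub repo.
model_name = "chatglm-6b"
local_dir = os.path.join("models", model_name)

if os.path.isdir(local_dir):
    print(f"using local checkpoint: {local_dir}")
else:
    print(f"would download from the Hub: THUDM/{model_name}")
```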
modules/models.py
CHANGED
@@ -8,6 +8,7 @@ import os
 import sys
 import requests
 import urllib3
+import platform
 
 from tqdm import tqdm
 import colorama
@@ -191,6 +192,55 @@ class OpenAIClient(BaseLLMModel):
                 # logging.error(f"Error: {e}")
                 continue
 
+class ChatGLM_Client(BaseLLMModel):
+    def __init__(
+        self,
+        model_name,
+        model_path=None
+    ) -> None:
+        super().__init__(
+            model_name=model_name
+        )
+        from transformers import AutoTokenizer, AutoModel
+        import torch
+        system_name = platform.system()
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_path = f"models/{model_name}"
+        if model_path is not None:
+            model_source = model_path
+        else:
+            model_source = f"THUDM/{model_name}"
+        self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
+        if torch.cuda.is_available():
+            # run on CUDA
+            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).half().cuda()
+        elif system_name == "Darwin" and model_path is not None:
+            # running on macOS and model already downloaded
+            model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().to('mps')
+        else:
+            # run on CPU
+            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).float()
+        model = model.eval()
+        self.model = model
+
+    def _get_glm_style_input(self):
+        history = [x["content"] for x in self.history]
+        query = history.pop()
+        return history, query
+
+    def get_answer_at_once(self):
+        history, query = self._get_glm_style_input()
+        response, _ = self.model.chat(self.tokenizer, query, history=history)
+        return response
+
+    def get_answer_stream_iter(self):
+        history, query = self._get_glm_style_input()
+        for response, history in self.model.stream_chat(self.tokenizer, query, history, max_length=self.token_upper_limit,
+                                                        top_p=self.top_p, temperature=self.temperature):
+            yield response
+
 
 def get_model(
     model_name, access_key=None, temperature=None, top_p=None, system_prompt=None
@@ -198,6 +248,7 @@ def get_model(
     msg = f"模型设置为了: {model_name}"
     logging.info(msg)
     model_type = ModelType.get_type(model_name)
+    del model
     if model_type == ModelType.OpenAI:
         model = OpenAIClient(
             model_name=model_name,
@@ -206,6 +257,8 @@ def get_model(
             temperature=temperature,
             top_p=top_p,
         )
+    elif model_type == ModelType.ChatGLM:
+        model = ChatGLM_Client(model_name)
     return model, msg
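To make the new client's data flow concrete, here is what `_get_glm_style_input` returns for a short conversation, as a standalone sketch. The shape of `messages` (OpenAI-style role/content dicts) is inferred from the `x["content"]` access above, and the pairing helper at the end is hypothetical: it reflects that THUDM's chatglm-6b checkpoints document the `history` argument of `chat()`/`stream_chat()` as a list of (query, response) tuples rather than a flat list.

```python
# Standalone demo of the (history, query) split used by ChatGLM_Client.
def get_glm_style_input(messages):
    # Flatten every message to its text; the last one becomes the query.
    history = [x["content"] for x in messages]
    query = history.pop()
    return history, query

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "Tell me about ChatGLM-6B"},
]

history, query = get_glm_style_input(messages)
print(history)  # ['Hello', 'Hi! How can I help?']
print(query)    # Tell me about ChatGLM-6B

# Hypothetical helper: fold the flat list into (query, response) pairs,
# the shape THUDM's checkpoints document for the history argument.
paired = list(zip(history[0::2], history[1::2]))
print(paired)   # [('Hello', 'Hi! How can I help?')]
```

The streaming path uses the same split: `stream_chat` simply yields successive partial responses for one (history, query) pair, which `get_answer_stream_iter` passes through unchanged.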
modules/presets.py
CHANGED
@@ -57,6 +57,9 @@ MODELS = [
     "gpt-4-0314",
     "gpt-4-32k",
     "gpt-4-32k-0314",
+    "chatglm-6b",
+    "chatglm-6b-int4",
+    "chatglm-6b-int4-qe"
 ] # selectable models
 
 MODEL_TOKEN_LIMIT = {
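`get_model` only routes these new names to `ChatGLM_Client` if `ModelType.get_type` recognizes them. That mapping is outside this diff; the sketch below shows one plausible name-based rule, where the enum members and substring checks are assumptions for illustration, not the repo's actual code:

```python
from enum import Enum

class ModelType(Enum):
    # Members assumed for illustration; only OpenAI and ChatGLM
    # appear in the diff above.
    Unknown = -1
    OpenAI = 0
    ChatGLM = 1

    @classmethod
    def get_type(cls, model_name: str):
        # Sketch of name-based detection; the real rule is not shown here.
        name = model_name.lower()
        if "gpt" in name:
            return cls.OpenAI
        if "chatglm" in name:
            return cls.ChatGLM
        return cls.Unknown

print(ModelType.get_type("chatglm-6b-int4"))  # ModelType.ChatGLM
```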
requirements.txt
CHANGED
@@ -13,3 +13,5 @@ markdown
 PyPDF2
 pdfplumber
 pandas
+transformers
+torch
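`transformers` and `torch` are added unpinned, so whichever torch build gets installed determines which branch `ChatGLM_Client.__init__` takes. A small check mirroring those branches; note the MPS probe is an illustration of capability detection, while the diff itself gates MPS on the platform plus a local checkpoint instead:

```python
import platform

import torch

# Mirror the device-selection order in ChatGLM_Client.__init__:
# CUDA first, then Apple MPS on macOS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda (fp16)"
elif platform.system() == "Darwin" and torch.backends.mps.is_available():
    device = "mps (fp16, local checkpoint only)"
else:
    device = "cpu (fp32)"

print(f"ChatGLM would load on: {device}")
```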