Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
committed on
Commit
·
8c04739
1
Parent(s):
4b9ef74
feat: Azure OpenAI API 支持 embedding
Browse files
- config_example.json +5 -2
- modules/config.py +39 -25
- modules/index_func.py +12 -4
- modules/models/azure.py +1 -1
config_example.json
CHANGED
|
@@ -9,10 +9,13 @@
|
|
| 9 |
"minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型
|
| 10 |
|
| 11 |
//== Azure ==
|
|
|
|
| 12 |
"azure_openai_api_key": "", // 你的 Azure OpenAI API Key,用于 Azure OpenAI 对话模型
|
| 13 |
-
"
|
| 14 |
"azure_openai_api_version": "2023-05-15", // 你的 Azure OpenAI API 版本
|
| 15 |
-
"azure_deployment_name": "", // 你的 Azure
|
|
|
|
|
|
|
| 16 |
|
| 17 |
//== 基础配置 ==
|
| 18 |
"language": "auto", // 界面语言,可选"auto", "zh-CN", "en-US", "ja-JP", "ko-KR"
|
|
|
|
| 9 |
"minimax_group_id": "", // 你的 MiniMax Group ID,用于 MiniMax 对话模型
|
| 10 |
|
| 11 |
//== Azure ==
|
| 12 |
+
"openai_api_type": "openai", // 可选项:azure, openai
|
| 13 |
"azure_openai_api_key": "", // 你的 Azure OpenAI API Key,用于 Azure OpenAI 对话模型
|
| 14 |
+
"azure_openai_api_base_url": "", // 你的 Azure Base URL
|
| 15 |
"azure_openai_api_version": "2023-05-15", // 你的 Azure OpenAI API 版本
|
| 16 |
+
"azure_deployment_name": "", // 你的 Azure OpenAI Chat 模型 Deployment 名称
|
| 17 |
+
"azure_embedding_deployment_name": "", // 你的 Azure OpenAI Embedding 模型 Deployment 名称
|
| 18 |
+
"azure_embedding_model_name": "text-embedding-ada-002", // 你的 Azure OpenAI Embedding 模型名称
|
| 19 |
|
| 20 |
//== 基础配置 ==
|
| 21 |
"language": "auto", // 界面语言,可选"auto", "zh-CN", "en-US", "ja-JP", "ko-KR"
|
modules/config.py
CHANGED
|
@@ -39,19 +39,22 @@ if os.path.exists("config.json"):
|
|
| 39 |
else:
|
| 40 |
config = {}
|
| 41 |
|
|
|
|
| 42 |
def load_config_to_environ(key_list):
|
| 43 |
global config
|
| 44 |
for key in key_list:
|
| 45 |
if key in config:
|
| 46 |
os.environ[key.upper()] = os.environ.get(key.upper(), config[key])
|
| 47 |
|
|
|
|
| 48 |
sensitive_id = config.get("sensitive_id", "")
|
| 49 |
sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
|
| 50 |
|
| 51 |
lang_config = config.get("language", "auto")
|
| 52 |
language = os.environ.get("LANGUAGE", lang_config)
|
| 53 |
|
| 54 |
-
hide_history_when_not_logged_in = config.get(
|
|
|
|
| 55 |
check_update = config.get("check_update", True)
|
| 56 |
show_api_billing = config.get("show_api_billing", False)
|
| 57 |
show_api_billing = bool(os.environ.get("SHOW_API_BILLING", show_api_billing))
|
|
@@ -68,31 +71,32 @@ if os.path.exists("auth.json"):
|
|
| 68 |
logging.info("检测到auth.json文件,正在进行迁移...")
|
| 69 |
auth_list = []
|
| 70 |
with open("auth.json", "r", encoding='utf-8') as f:
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
config["users"] = auth_list
|
| 79 |
os.rename("auth.json", "auth(deprecated).json")
|
| 80 |
with open("config.json", "w", encoding='utf-8') as f:
|
| 81 |
json.dump(config, f, indent=4, ensure_ascii=False)
|
| 82 |
|
| 83 |
-
|
| 84 |
dockerflag = config.get("dockerflag", False)
|
| 85 |
if os.environ.get("dockerrun") == "yes":
|
| 86 |
dockerflag = True
|
| 87 |
|
| 88 |
-
|
| 89 |
my_api_key = config.get("openai_api_key", "")
|
| 90 |
my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
|
| 91 |
os.environ["OPENAI_API_KEY"] = my_api_key
|
| 92 |
os.environ["OPENAI_EMBEDDING_API_KEY"] = my_api_key
|
| 93 |
|
| 94 |
google_palm_api_key = config.get("google_palm_api_key", "")
|
| 95 |
-
google_palm_api_key = os.environ.get(
|
|
|
|
| 96 |
os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key
|
| 97 |
|
| 98 |
xmchat_api_key = config.get("xmchat_api_key", "")
|
|
@@ -103,13 +107,14 @@ os.environ["MINIMAX_API_KEY"] = minimax_api_key
|
|
| 103 |
minimax_group_id = config.get("minimax_group_id", "")
|
| 104 |
os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
|
| 105 |
|
| 106 |
-
load_config_to_environ(["
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))
|
| 110 |
|
| 111 |
-
|
| 112 |
-
multi_api_key = config.get("multi_api_key", False)
|
| 113 |
if multi_api_key:
|
| 114 |
api_key_list = config.get("api_key_list", [])
|
| 115 |
if len(api_key_list) == 0:
|
|
@@ -117,23 +122,26 @@ if multi_api_key:
|
|
| 117 |
sys.exit(1)
|
| 118 |
shared.state.set_api_key_queue(api_key_list)
|
| 119 |
|
| 120 |
-
auth_list = config.get("users", [])
|
| 121 |
authflag = len(auth_list) > 0 # 是否开启认证的状态值,改为判断auth_list长度
|
| 122 |
|
| 123 |
# 处理自定义的api_host,优先读环境变量的配置,如果存在则自动装配
|
| 124 |
-
api_host = os.environ.get(
|
|
|
|
| 125 |
if api_host is not None:
|
| 126 |
shared.state.set_api_host(api_host)
|
| 127 |
os.environ["OPENAI_API_BASE"] = f"{api_host}/v1"
|
| 128 |
logging.info(f"OpenAI API Base set to: {os.environ['OPENAI_API_BASE']}")
|
| 129 |
|
| 130 |
-
default_chuanhu_assistant_model = config.get(
|
|
|
|
| 131 |
for x in ["GOOGLE_CSE_ID", "GOOGLE_API_KEY", "WOLFRAM_ALPHA_APPID", "SERPAPI_API_KEY"]:
|
| 132 |
if config.get(x, None) is not None:
|
| 133 |
os.environ[x] = config[x]
|
| 134 |
|
|
|
|
| 135 |
@contextmanager
|
| 136 |
-
def retrieve_openai_api(api_key
|
| 137 |
old_api_key = os.environ.get("OPENAI_API_KEY", "")
|
| 138 |
if api_key is None:
|
| 139 |
os.environ["OPENAI_API_KEY"] = my_api_key
|
|
@@ -143,14 +151,15 @@ def retrieve_openai_api(api_key = None):
|
|
| 143 |
yield api_key
|
| 144 |
os.environ["OPENAI_API_KEY"] = old_api_key
|
| 145 |
|
| 146 |
-
|
|
|
|
| 147 |
log_level = config.get("log_level", "INFO")
|
| 148 |
logging.basicConfig(
|
| 149 |
level=log_level,
|
| 150 |
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
| 151 |
)
|
| 152 |
|
| 153 |
-
|
| 154 |
http_proxy = os.environ.get("HTTP_PROXY", "")
|
| 155 |
https_proxy = os.environ.get("HTTPS_PROXY", "")
|
| 156 |
http_proxy = config.get("http_proxy", http_proxy)
|
|
@@ -160,7 +169,8 @@ https_proxy = config.get("https_proxy", https_proxy)
|
|
| 160 |
os.environ["HTTP_PROXY"] = ""
|
| 161 |
os.environ["HTTPS_PROXY"] = ""
|
| 162 |
|
| 163 |
-
local_embedding = config.get("local_embedding", False)
|
|
|
|
| 164 |
|
| 165 |
@contextmanager
|
| 166 |
def retrieve_proxy(proxy=None):
|
|
@@ -177,12 +187,13 @@ def retrieve_proxy(proxy=None):
|
|
| 177 |
old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
|
| 178 |
os.environ["HTTP_PROXY"] = http_proxy
|
| 179 |
os.environ["HTTPS_PROXY"] = https_proxy
|
| 180 |
-
yield http_proxy, https_proxy
|
| 181 |
|
| 182 |
# return old proxy
|
| 183 |
os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
|
| 184 |
|
| 185 |
-
|
|
|
|
| 186 |
user_latex_option = config.get("latex_option", "default")
|
| 187 |
if user_latex_option == "default":
|
| 188 |
latex_delimiters_set = [
|
|
@@ -219,16 +230,19 @@ else:
|
|
| 219 |
{"left": "\\[", "right": "\\]", "display": True},
|
| 220 |
]
|
| 221 |
|
| 222 |
-
|
| 223 |
advance_docs = defaultdict(lambda: defaultdict(dict))
|
| 224 |
advance_docs.update(config.get("advance_docs", {}))
|
|
|
|
|
|
|
| 225 |
def update_doc_config(two_column_pdf):
|
| 226 |
global advance_docs
|
| 227 |
advance_docs["pdf"]["two_column"] = two_column_pdf
|
| 228 |
|
| 229 |
logging.info(f"更新后的文件参数为:{advance_docs}")
|
| 230 |
|
| 231 |
-
|
|
|
|
| 232 |
server_name = config.get("server_name", None)
|
| 233 |
server_port = config.get("server_port", None)
|
| 234 |
if server_name is None:
|
|
|
|
| 39 |
else:
|
| 40 |
config = {}
|
| 41 |
|
| 42 |
+
|
| 43 |
def load_config_to_environ(key_list):
|
| 44 |
global config
|
| 45 |
for key in key_list:
|
| 46 |
if key in config:
|
| 47 |
os.environ[key.upper()] = os.environ.get(key.upper(), config[key])
|
| 48 |
|
| 49 |
+
|
| 50 |
sensitive_id = config.get("sensitive_id", "")
|
| 51 |
sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
|
| 52 |
|
| 53 |
lang_config = config.get("language", "auto")
|
| 54 |
language = os.environ.get("LANGUAGE", lang_config)
|
| 55 |
|
| 56 |
+
hide_history_when_not_logged_in = config.get(
|
| 57 |
+
"hide_history_when_not_logged_in", False)
|
| 58 |
check_update = config.get("check_update", True)
|
| 59 |
show_api_billing = config.get("show_api_billing", False)
|
| 60 |
show_api_billing = bool(os.environ.get("SHOW_API_BILLING", show_api_billing))
|
|
|
|
| 71 |
logging.info("检测到auth.json文件,正在进行迁移...")
|
| 72 |
auth_list = []
|
| 73 |
with open("auth.json", "r", encoding='utf-8') as f:
|
| 74 |
+
auth = json.load(f)
|
| 75 |
+
for _ in auth:
|
| 76 |
+
if auth[_]["username"] and auth[_]["password"]:
|
| 77 |
+
auth_list.append((auth[_]["username"], auth[_]["password"]))
|
| 78 |
+
else:
|
| 79 |
+
logging.error("请检查auth.json文件中的用户名和密码!")
|
| 80 |
+
sys.exit(1)
|
| 81 |
config["users"] = auth_list
|
| 82 |
os.rename("auth.json", "auth(deprecated).json")
|
| 83 |
with open("config.json", "w", encoding='utf-8') as f:
|
| 84 |
json.dump(config, f, indent=4, ensure_ascii=False)
|
| 85 |
|
| 86 |
+
# 处理docker if we are running in Docker
|
| 87 |
dockerflag = config.get("dockerflag", False)
|
| 88 |
if os.environ.get("dockerrun") == "yes":
|
| 89 |
dockerflag = True
|
| 90 |
|
| 91 |
+
# 处理 api-key 以及 允许的用户列表
|
| 92 |
my_api_key = config.get("openai_api_key", "")
|
| 93 |
my_api_key = os.environ.get("OPENAI_API_KEY", my_api_key)
|
| 94 |
os.environ["OPENAI_API_KEY"] = my_api_key
|
| 95 |
os.environ["OPENAI_EMBEDDING_API_KEY"] = my_api_key
|
| 96 |
|
| 97 |
google_palm_api_key = config.get("google_palm_api_key", "")
|
| 98 |
+
google_palm_api_key = os.environ.get(
|
| 99 |
+
"GOOGLE_PALM_API_KEY", google_palm_api_key)
|
| 100 |
os.environ["GOOGLE_PALM_API_KEY"] = google_palm_api_key
|
| 101 |
|
| 102 |
xmchat_api_key = config.get("xmchat_api_key", "")
|
|
|
|
| 107 |
minimax_group_id = config.get("minimax_group_id", "")
|
| 108 |
os.environ["MINIMAX_GROUP_ID"] = minimax_group_id
|
| 109 |
|
| 110 |
+
load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
|
| 111 |
+
"azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
|
| 112 |
|
| 113 |
|
| 114 |
usage_limit = os.environ.get("USAGE_LIMIT", config.get("usage_limit", 120))
|
| 115 |
|
| 116 |
+
# 多账户机制
|
| 117 |
+
multi_api_key = config.get("multi_api_key", False) # 是否开启多账户机制
|
| 118 |
if multi_api_key:
|
| 119 |
api_key_list = config.get("api_key_list", [])
|
| 120 |
if len(api_key_list) == 0:
|
|
|
|
| 122 |
sys.exit(1)
|
| 123 |
shared.state.set_api_key_queue(api_key_list)
|
| 124 |
|
| 125 |
+
auth_list = config.get("users", []) # 实际上是使用者的列表
|
| 126 |
authflag = len(auth_list) > 0 # 是否开启认证的状态值,改为判断auth_list长度
|
| 127 |
|
| 128 |
# 处理自定义的api_host,优先读环境变量的配置,如果存在则自动装配
|
| 129 |
+
api_host = os.environ.get(
|
| 130 |
+
"OPENAI_API_BASE", config.get("openai_api_base", None))
|
| 131 |
if api_host is not None:
|
| 132 |
shared.state.set_api_host(api_host)
|
| 133 |
os.environ["OPENAI_API_BASE"] = f"{api_host}/v1"
|
| 134 |
logging.info(f"OpenAI API Base set to: {os.environ['OPENAI_API_BASE']}")
|
| 135 |
|
| 136 |
+
default_chuanhu_assistant_model = config.get(
|
| 137 |
+
"default_chuanhu_assistant_model", "gpt-3.5-turbo")
|
| 138 |
for x in ["GOOGLE_CSE_ID", "GOOGLE_API_KEY", "WOLFRAM_ALPHA_APPID", "SERPAPI_API_KEY"]:
|
| 139 |
if config.get(x, None) is not None:
|
| 140 |
os.environ[x] = config[x]
|
| 141 |
|
| 142 |
+
|
| 143 |
@contextmanager
|
| 144 |
+
def retrieve_openai_api(api_key=None):
|
| 145 |
old_api_key = os.environ.get("OPENAI_API_KEY", "")
|
| 146 |
if api_key is None:
|
| 147 |
os.environ["OPENAI_API_KEY"] = my_api_key
|
|
|
|
| 151 |
yield api_key
|
| 152 |
os.environ["OPENAI_API_KEY"] = old_api_key
|
| 153 |
|
| 154 |
+
|
| 155 |
+
# 处理log
|
| 156 |
log_level = config.get("log_level", "INFO")
|
| 157 |
logging.basicConfig(
|
| 158 |
level=log_level,
|
| 159 |
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
| 160 |
)
|
| 161 |
|
| 162 |
+
# 处理代理:
|
| 163 |
http_proxy = os.environ.get("HTTP_PROXY", "")
|
| 164 |
https_proxy = os.environ.get("HTTPS_PROXY", "")
|
| 165 |
http_proxy = config.get("http_proxy", http_proxy)
|
|
|
|
| 169 |
os.environ["HTTP_PROXY"] = ""
|
| 170 |
os.environ["HTTPS_PROXY"] = ""
|
| 171 |
|
| 172 |
+
local_embedding = config.get("local_embedding", False) # 是否使用本地embedding
|
| 173 |
+
|
| 174 |
|
| 175 |
@contextmanager
|
| 176 |
def retrieve_proxy(proxy=None):
|
|
|
|
| 187 |
old_var = os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"]
|
| 188 |
os.environ["HTTP_PROXY"] = http_proxy
|
| 189 |
os.environ["HTTPS_PROXY"] = https_proxy
|
| 190 |
+
yield http_proxy, https_proxy # return new proxy
|
| 191 |
|
| 192 |
# return old proxy
|
| 193 |
os.environ["HTTP_PROXY"], os.environ["HTTPS_PROXY"] = old_var
|
| 194 |
|
| 195 |
+
|
| 196 |
+
# 处理latex options
|
| 197 |
user_latex_option = config.get("latex_option", "default")
|
| 198 |
if user_latex_option == "default":
|
| 199 |
latex_delimiters_set = [
|
|
|
|
| 230 |
{"left": "\\[", "right": "\\]", "display": True},
|
| 231 |
]
|
| 232 |
|
| 233 |
+
# 处理advance docs
|
| 234 |
advance_docs = defaultdict(lambda: defaultdict(dict))
|
| 235 |
advance_docs.update(config.get("advance_docs", {}))
|
| 236 |
+
|
| 237 |
+
|
| 238 |
def update_doc_config(two_column_pdf):
|
| 239 |
global advance_docs
|
| 240 |
advance_docs["pdf"]["two_column"] = two_column_pdf
|
| 241 |
|
| 242 |
logging.info(f"更新后的文件参数为:{advance_docs}")
|
| 243 |
|
| 244 |
+
|
| 245 |
+
# 处理gradio.launch参数
|
| 246 |
server_name = config.get("server_name", None)
|
| 247 |
server_port = config.get("server_port", None)
|
| 248 |
if server_name is None:
|
modules/index_func.py
CHANGED
|
@@ -51,7 +51,8 @@ def get_documents(file_src):
|
|
| 51 |
pdfReader = PyPDF2.PdfReader(pdfFileObj)
|
| 52 |
for page in tqdm(pdfReader.pages):
|
| 53 |
pdftext += page.extract_text()
|
| 54 |
-
texts = [Document(page_content=pdftext,
|
|
|
|
| 55 |
elif file_type == ".docx":
|
| 56 |
logging.debug("Loading Word...")
|
| 57 |
from langchain.document_loaders import UnstructuredWordDocumentLoader
|
|
@@ -72,7 +73,8 @@ def get_documents(file_src):
|
|
| 72 |
text_list = excel_to_string(filepath)
|
| 73 |
texts = []
|
| 74 |
for elem in text_list:
|
| 75 |
-
texts.append(Document(page_content=elem,
|
|
|
|
| 76 |
else:
|
| 77 |
logging.debug("Loading text file...")
|
| 78 |
from langchain.document_loaders import TextLoader
|
|
@@ -115,10 +117,16 @@ def construct_index(
|
|
| 115 |
index_path = f"./index/{index_name}"
|
| 116 |
if local_embedding:
|
| 117 |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
| 118 |
-
embeddings = HuggingFaceEmbeddings(
|
|
|
|
| 119 |
else:
|
| 120 |
from langchain.embeddings import OpenAIEmbeddings
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
if os.path.exists(index_path):
|
| 123 |
logging.info("找到了缓存的索引文件,加载中……")
|
| 124 |
return FAISS.load_local(index_path, embeddings)
|
|
|
|
| 51 |
pdfReader = PyPDF2.PdfReader(pdfFileObj)
|
| 52 |
for page in tqdm(pdfReader.pages):
|
| 53 |
pdftext += page.extract_text()
|
| 54 |
+
texts = [Document(page_content=pdftext,
|
| 55 |
+
metadata={"source": filepath})]
|
| 56 |
elif file_type == ".docx":
|
| 57 |
logging.debug("Loading Word...")
|
| 58 |
from langchain.document_loaders import UnstructuredWordDocumentLoader
|
|
|
|
| 73 |
text_list = excel_to_string(filepath)
|
| 74 |
texts = []
|
| 75 |
for elem in text_list:
|
| 76 |
+
texts.append(Document(page_content=elem,
|
| 77 |
+
metadata={"source": filepath}))
|
| 78 |
else:
|
| 79 |
logging.debug("Loading text file...")
|
| 80 |
from langchain.document_loaders import TextLoader
|
|
|
|
| 117 |
index_path = f"./index/{index_name}"
|
| 118 |
if local_embedding:
|
| 119 |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
| 120 |
+
embeddings = HuggingFaceEmbeddings(
|
| 121 |
+
model_name="sentence-transformers/distiluse-base-multilingual-cased-v2")
|
| 122 |
else:
|
| 123 |
from langchain.embeddings import OpenAIEmbeddings
|
| 124 |
+
if os.environ.get("OPENAI_API_TYPE", "openai") == "openai":
|
| 125 |
+
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get(
|
| 126 |
+
"OPENAI_API_BASE", None), openai_api_key=os.environ.get("OPENAI_EMBEDDING_API_KEY", api_key))
|
| 127 |
+
else:
|
| 128 |
+
embeddings = OpenAIEmbeddings(deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"], openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
|
| 129 |
+
model=os.environ["AZURE_EMBEDDING_MODEL_NAME"], openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"], openai_api_type="azure")
|
| 130 |
if os.path.exists(index_path):
|
| 131 |
logging.info("找到了缓存的索引文件,加载中……")
|
| 132 |
return FAISS.load_local(index_path, embeddings)
|
modules/models/azure.py
CHANGED
|
@@ -9,7 +9,7 @@ class Azure_OpenAI_Client(Base_Chat_Langchain_Client):
|
|
| 9 |
def setup_model(self):
|
| 10 |
# implement this to set up the model then return it
|
| 11 |
return AzureChatOpenAI(
|
| 12 |
-
openai_api_base=os.environ["
|
| 13 |
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
|
| 14 |
deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
|
| 15 |
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
|
|
|
|
| 9 |
def setup_model(self):
|
| 10 |
# implement this to set up the model then return it
|
| 11 |
return AzureChatOpenAI(
|
| 12 |
+
openai_api_base=os.environ["AZURE_OPENAI_API_BASE_URL"],
|
| 13 |
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
|
| 14 |
deployment_name=os.environ["AZURE_DEPLOYMENT_NAME"],
|
| 15 |
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
|