Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
432eb42
1
Parent(s):
b5ddb7e
feat: 支持添加训练好的模型到配置文件里
Browse files- ChuanhuChatbot.py +1 -1
- modules/config.py +4 -0
- modules/train_func.py +13 -1
ChuanhuChatbot.py
CHANGED
|
@@ -15,8 +15,8 @@ from modules.presets import *
|
|
| 15 |
from modules.overwrites import *
|
| 16 |
from modules.webui import *
|
| 17 |
from modules.repo import *
|
|
|
|
| 18 |
from modules.models.models import get_model
|
| 19 |
-
from modules.train_func import handle_dataset_selection, handle_dataset_clear, upload_to_openai, start_training, get_training_status, add_to_models, cancel_all_jobs
|
| 20 |
|
| 21 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 22 |
|
|
|
|
| 15 |
from modules.overwrites import *
|
| 16 |
from modules.webui import *
|
| 17 |
from modules.repo import *
|
| 18 |
+
from modules.train_func import *
|
| 19 |
from modules.models.models import get_model
|
|
|
|
| 20 |
|
| 21 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 22 |
|
modules/config.py
CHANGED
|
@@ -96,6 +96,10 @@ else:
|
|
| 96 |
sensitive_id = config.get("sensitive_id", "")
|
| 97 |
sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
google_palm_api_key = config.get("google_palm_api_key", "")
|
| 101 |
google_palm_api_key = os.environ.get(
|
|
|
|
| 96 |
sensitive_id = config.get("sensitive_id", "")
|
| 97 |
sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
|
| 98 |
|
| 99 |
+
# 模型配置
|
| 100 |
+
if "extra_models" in config:
|
| 101 |
+
presets.MODELS.extend(config["extra_models"])
|
| 102 |
+
logging.info(f"已添加额外的模型:{config['extra_models']}")
|
| 103 |
|
| 104 |
google_palm_api_key = config.get("google_palm_api_key", "")
|
| 105 |
google_palm_api_key = os.environ.get(
|
modules/train_func.py
CHANGED
|
@@ -5,6 +5,7 @@ import traceback
|
|
| 5 |
import openai
|
| 6 |
import gradio as gr
|
| 7 |
import ujson as json
|
|
|
|
| 8 |
|
| 9 |
import modules.presets as presets
|
| 10 |
from modules.utils import get_file_hash
|
|
@@ -112,7 +113,18 @@ def handle_dataset_clear():
|
|
| 112 |
def add_to_models():
|
| 113 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 114 |
succeeded_jobs = [job for job in openai.FineTuningJob.list()["data"] if job["status"] == "succeeded"]
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
|
| 117 |
|
| 118 |
def cancel_all_jobs():
|
|
|
|
| 5 |
import openai
|
| 6 |
import gradio as gr
|
| 7 |
import ujson as json
|
| 8 |
+
import commentjson
|
| 9 |
|
| 10 |
import modules.presets as presets
|
| 11 |
from modules.utils import get_file_hash
|
|
|
|
| 113 |
def add_to_models():
|
| 114 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 115 |
succeeded_jobs = [job for job in openai.FineTuningJob.list()["data"] if job["status"] == "succeeded"]
|
| 116 |
+
extra_models = [job["fine_tuned_model"] for job in succeeded_jobs]
|
| 117 |
+
presets.MODELS.extend(extra_models)
|
| 118 |
+
|
| 119 |
+
with open('config.json', 'r') as f:
|
| 120 |
+
data = commentjson.load(f)
|
| 121 |
+
if 'extra_models' in data:
|
| 122 |
+
data['extra_models'].extend(extra_models)
|
| 123 |
+
else:
|
| 124 |
+
data['extra_models'] = extra_models
|
| 125 |
+
with open('config.json', 'w') as f:
|
| 126 |
+
commentjson.dump(data, f, indent=4)
|
| 127 |
+
|
| 128 |
return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
|
| 129 |
|
| 130 |
def cancel_all_jobs():
|