FaYo
commited on
Commit
·
9f68218
1
Parent(s):
5661e58
model
Browse files
finetune_configs/internlm_chat_7b_qlora_alpace_e3.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
from mmengine.dataset import DefaultSampler
|
| 5 |
+
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
| 6 |
+
LoggerHook, ParamSchedulerHook)
|
| 7 |
+
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
|
| 8 |
+
from peft import LoraConfig
|
| 9 |
+
from torch.optim import AdamW
|
| 10 |
+
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
| 11 |
+
BitsAndBytesConfig)
|
| 12 |
+
|
| 13 |
+
from xtuner.dataset import process_hf_dataset
|
| 14 |
+
from xtuner.dataset.collate_fns import default_collate_fn
|
| 15 |
+
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
|
| 16 |
+
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
|
| 17 |
+
VarlenAttnArgsToMessageHubHook)
|
| 18 |
+
from xtuner.engine.runner import TrainLoop
|
| 19 |
+
from xtuner.model import SupervisedFinetune
|
| 20 |
+
from xtuner.parallel.sequence import SequenceParallelSampler
|
| 21 |
+
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
|
| 22 |
+
|
| 23 |
+
#######################################################################
|
| 24 |
+
# PART 1 Settings #
|
| 25 |
+
#######################################################################
|
| 26 |
+
# Model
|
| 27 |
+
pretrained_model_name_or_path = '/group_share/lntelligent-Medical-Guidance-Large-Model/InternLM/XTuner/model/internlm2-chat-7b'
|
| 28 |
+
use_varlen_attn = False
|
| 29 |
+
|
| 30 |
+
# Data
|
| 31 |
+
alpaca_en_path = 'dataset/gen_dataset/train_dataset/90_train.jsonl'
|
| 32 |
+
prompt_template = PROMPT_TEMPLATE.internlm2_chat
|
| 33 |
+
max_length = 2048
|
| 34 |
+
pack_to_max_length = True
|
| 35 |
+
|
| 36 |
+
# parallel
|
| 37 |
+
sequence_parallel_size = 1
|
| 38 |
+
|
| 39 |
+
# Scheduler & Optimizer
|
| 40 |
+
batch_size = 1 # per_device
|
| 41 |
+
accumulative_counts = 16
|
| 42 |
+
accumulative_counts *= sequence_parallel_size
|
| 43 |
+
dataloader_num_workers = 0
|
| 44 |
+
max_epochs = 3
|
| 45 |
+
optim_type = AdamW
|
| 46 |
+
lr = 2e-4
|
| 47 |
+
betas = (0.9, 0.999)
|
| 48 |
+
weight_decay = 0
|
| 49 |
+
max_norm = 1 # grad clip
|
| 50 |
+
warmup_ratio = 0.03
|
| 51 |
+
|
| 52 |
+
# Save
|
| 53 |
+
save_steps = 500
|
| 54 |
+
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
|
| 55 |
+
|
| 56 |
+
# Evaluate the generation performance during the training
|
| 57 |
+
evaluation_freq = 500
|
| 58 |
+
SYSTEM = SYSTEM_TEMPLATE.alpaca
|
| 59 |
+
evaluation_inputs = [
|
| 60 |
+
'请介绍一下你自己', 'Please introduce yourself'
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
#######################################################################
|
| 64 |
+
# PART 2 Model & Tokenizer #
|
| 65 |
+
#######################################################################
|
| 66 |
+
tokenizer = dict(
|
| 67 |
+
type=AutoTokenizer.from_pretrained,
|
| 68 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 69 |
+
trust_remote_code=True,
|
| 70 |
+
padding_side='right')
|
| 71 |
+
|
| 72 |
+
model = dict(
|
| 73 |
+
type=SupervisedFinetune,
|
| 74 |
+
use_varlen_attn=use_varlen_attn,
|
| 75 |
+
llm=dict(
|
| 76 |
+
type=AutoModelForCausalLM.from_pretrained,
|
| 77 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
| 78 |
+
trust_remote_code=True,
|
| 79 |
+
torch_dtype=torch.float16,
|
| 80 |
+
quantization_config=dict(
|
| 81 |
+
type=BitsAndBytesConfig,
|
| 82 |
+
load_in_4bit=True,
|
| 83 |
+
load_in_8bit=False,
|
| 84 |
+
llm_int8_threshold=6.0,
|
| 85 |
+
llm_int8_has_fp16_weight=False,
|
| 86 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 87 |
+
bnb_4bit_use_double_quant=True,
|
| 88 |
+
bnb_4bit_quant_type='nf4')),
|
| 89 |
+
lora=dict(
|
| 90 |
+
type=LoraConfig,
|
| 91 |
+
r=64,
|
| 92 |
+
lora_alpha=16,
|
| 93 |
+
lora_dropout=0.1,
|
| 94 |
+
bias='none',
|
| 95 |
+
task_type='CAUSAL_LM'))
|
| 96 |
+
|
| 97 |
+
#######################################################################
|
| 98 |
+
# PART 3 Dataset & Dataloader #
|
| 99 |
+
#######################################################################
|
| 100 |
+
alpaca_en = dict(
|
| 101 |
+
type=process_hf_dataset,
|
| 102 |
+
dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
|
| 103 |
+
tokenizer=tokenizer,
|
| 104 |
+
max_length=max_length,
|
| 105 |
+
dataset_map_fn=None,
|
| 106 |
+
template_map_fn=dict(
|
| 107 |
+
type=template_map_fn_factory, template=prompt_template),
|
| 108 |
+
remove_unused_columns=True,
|
| 109 |
+
shuffle_before_pack=True,
|
| 110 |
+
pack_to_max_length=pack_to_max_length,
|
| 111 |
+
use_varlen_attn=use_varlen_attn)
|
| 112 |
+
|
| 113 |
+
sampler = SequenceParallelSampler \
|
| 114 |
+
if sequence_parallel_size > 1 else DefaultSampler
|
| 115 |
+
train_dataloader = dict(
|
| 116 |
+
batch_size=batch_size,
|
| 117 |
+
num_workers=dataloader_num_workers,
|
| 118 |
+
dataset=alpaca_en,
|
| 119 |
+
sampler=dict(type=sampler, shuffle=True),
|
| 120 |
+
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
|
| 121 |
+
|
| 122 |
+
#######################################################################
|
| 123 |
+
# PART 4 Scheduler & Optimizer #
|
| 124 |
+
#######################################################################
|
| 125 |
+
# optimizer
|
| 126 |
+
optim_wrapper = dict(
|
| 127 |
+
type=AmpOptimWrapper,
|
| 128 |
+
optimizer=dict(
|
| 129 |
+
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
|
| 130 |
+
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
|
| 131 |
+
accumulative_counts=accumulative_counts,
|
| 132 |
+
loss_scale='dynamic',
|
| 133 |
+
dtype='float16')
|
| 134 |
+
|
| 135 |
+
# learning policy
|
| 136 |
+
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
|
| 137 |
+
param_scheduler = [
|
| 138 |
+
dict(
|
| 139 |
+
type=LinearLR,
|
| 140 |
+
start_factor=1e-5,
|
| 141 |
+
by_epoch=True,
|
| 142 |
+
begin=0,
|
| 143 |
+
end=warmup_ratio * max_epochs,
|
| 144 |
+
convert_to_iter_based=True),
|
| 145 |
+
dict(
|
| 146 |
+
type=CosineAnnealingLR,
|
| 147 |
+
eta_min=0.0,
|
| 148 |
+
by_epoch=True,
|
| 149 |
+
begin=warmup_ratio * max_epochs,
|
| 150 |
+
end=max_epochs,
|
| 151 |
+
convert_to_iter_based=True)
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
# train, val, test setting
|
| 155 |
+
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
|
| 156 |
+
|
| 157 |
+
#######################################################################
|
| 158 |
+
# PART 5 Runtime #
|
| 159 |
+
#######################################################################
|
| 160 |
+
# Log the dialogue periodically during the training process, optional
|
| 161 |
+
custom_hooks = [
|
| 162 |
+
dict(type=DatasetInfoHook, tokenizer=tokenizer),
|
| 163 |
+
dict(
|
| 164 |
+
type=EvaluateChatHook,
|
| 165 |
+
tokenizer=tokenizer,
|
| 166 |
+
every_n_iters=evaluation_freq,
|
| 167 |
+
evaluation_inputs=evaluation_inputs,
|
| 168 |
+
system=SYSTEM,
|
| 169 |
+
prompt_template=prompt_template)
|
| 170 |
+
]
|
| 171 |
+
|
| 172 |
+
if use_varlen_attn:
|
| 173 |
+
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
|
| 174 |
+
|
| 175 |
+
# configure default hooks
|
| 176 |
+
default_hooks = dict(
|
| 177 |
+
# record the time of every iteration.
|
| 178 |
+
timer=dict(type=IterTimerHook),
|
| 179 |
+
# print log every 10 iterations.
|
| 180 |
+
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
|
| 181 |
+
# enable the parameter scheduler.
|
| 182 |
+
param_scheduler=dict(type=ParamSchedulerHook),
|
| 183 |
+
# save checkpoint per `save_steps`.
|
| 184 |
+
checkpoint=dict(
|
| 185 |
+
type=CheckpointHook,
|
| 186 |
+
by_epoch=False,
|
| 187 |
+
interval=save_steps,
|
| 188 |
+
max_keep_ckpts=save_total_limit),
|
| 189 |
+
# set sampler seed in distributed evrionment.
|
| 190 |
+
sampler_seed=dict(type=DistSamplerSeedHook),
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# configure environment
|
| 194 |
+
env_cfg = dict(
|
| 195 |
+
# whether to enable cudnn benchmark
|
| 196 |
+
cudnn_benchmark=False,
|
| 197 |
+
# set multi process parameters
|
| 198 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
| 199 |
+
# set distributed parameters
|
| 200 |
+
dist_cfg=dict(backend='nccl'),
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# set visualizer
|
| 204 |
+
visualizer = None
|
| 205 |
+
|
| 206 |
+
# set log level
|
| 207 |
+
log_level = 'INFO'
|
| 208 |
+
|
| 209 |
+
# load from which checkpoint
|
| 210 |
+
load_from = None
|
| 211 |
+
|
| 212 |
+
# whether to resume training from the loaded checkpoint
|
| 213 |
+
resume = False
|
| 214 |
+
|
| 215 |
+
# Defaults to use random seed and disable `deterministic`
|
| 216 |
+
randomness = dict(seed=None, deterministic=False)
|
| 217 |
+
|
| 218 |
+
# set log processor
|
| 219 |
+
log_processor = dict(by_epoch=False)
|
pages/selling_page.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# @Time : 2024.4.16
|
| 4 |
+
# @Author : HinGwenWong
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import streamlit as st
|
| 11 |
+
|
| 12 |
+
from utils.web_configs import WEB_CONFIGS
|
| 13 |
+
|
| 14 |
+
# 设置页面配置,包括标题、图标、布局和菜单项
|
| 15 |
+
st.set_page_config(
|
| 16 |
+
page_title="智能医导大模型",
|
| 17 |
+
page_icon="🛒",
|
| 18 |
+
layout="wide",
|
| 19 |
+
initial_sidebar_state="expanded",
|
| 20 |
+
menu_items={
|
| 21 |
+
"Get Help": "https://github.com/nhbdgtgefr/Intelligent-Medical-Guidance-Large-Model/tree/main",
|
| 22 |
+
"About": "# 智能医导大模型",
|
| 23 |
+
},
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from audiorecorder import audiorecorder
|
| 27 |
+
|
| 28 |
+
from utils.asr.asr_worker import process_asr
|
| 29 |
+
from utils.digital_human.digital_human_worker import show_video
|
| 30 |
+
from utils.infer.lmdeploy_infer import get_turbomind_response
|
| 31 |
+
from utils.model_loader import ASR_HANDLER, LLM_MODEL, RAG_RETRIEVER
|
| 32 |
+
from utils.tools import resize_image
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def on_btn_click(*args, **kwargs):
|
| 36 |
+
"""
|
| 37 |
+
处理按钮点击事件的函数。
|
| 38 |
+
"""
|
| 39 |
+
if kwargs["info"] == "清除对话历史":
|
| 40 |
+
st.session_state.messages = []
|
| 41 |
+
elif kwargs["info"] == "返回科室页":
|
| 42 |
+
st.session_state.page_switch = "app.py"
|
| 43 |
+
else:
|
| 44 |
+
st.session_state.button_msg = kwargs["info"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def init_sidebar():
|
| 48 |
+
"""
|
| 49 |
+
初始化侧边栏界面,展示商品信息,并提供操作按钮。
|
| 50 |
+
"""
|
| 51 |
+
asr_text = ""
|
| 52 |
+
with st.sidebar:
|
| 53 |
+
# 标题
|
| 54 |
+
st.markdown("## 智能医导大模型")
|
| 55 |
+
st.markdown("[智能医导大模型](https://github.com/nhbdgtgefr/Intelligent-Medical-Guidance-Large-Model)")
|
| 56 |
+
st.subheader("功能点:", divider="grey")
|
| 57 |
+
# st.markdown(
|
| 58 |
+
# "1. 📜 **主播文案一键生成**\n2. 🚀 KV cache + Turbomind **推理加速**\n3. 📚 RAG **检索增强生成**\n4. 🔊 TTS **文字转语音**\n5. 🦸 **数字人生成**\n6. 🌐 **Agent 网络查询**\n7. 🎙️ **ASR 语音转文字**"
|
| 59 |
+
# )
|
| 60 |
+
|
| 61 |
+
st.subheader("目前讲解")
|
| 62 |
+
with st.container(height=400, border=True):
|
| 63 |
+
st.subheader(st.session_state.product_name)
|
| 64 |
+
|
| 65 |
+
image = resize_image(st.session_state.image_path, max_height=100)
|
| 66 |
+
st.image(image, channels="bgr")
|
| 67 |
+
|
| 68 |
+
st.subheader("科室特点", divider="grey")
|
| 69 |
+
st.markdown(st.session_state.hightlight)
|
| 70 |
+
|
| 71 |
+
want_to_buy_list = [
|
| 72 |
+
"我打算买了。",
|
| 73 |
+
"我准备入手了。",
|
| 74 |
+
"我决定要买了。",
|
| 75 |
+
"我准备下单了。",
|
| 76 |
+
"我将要购买这款产品。",
|
| 77 |
+
"我准备买下来了。",
|
| 78 |
+
"我准备将这个买下。",
|
| 79 |
+
"我准备要购买了。",
|
| 80 |
+
"我决定买下它。",
|
| 81 |
+
"我准备将其买下。",
|
| 82 |
+
]
|
| 83 |
+
buy_flag = st.button("加入信息🛒", on_click=on_btn_click, kwargs={"info": random.choice(want_to_buy_list)})
|
| 84 |
+
|
| 85 |
+
# TODO 加入卖货信息
|
| 86 |
+
# 卖出 xxx 个
|
| 87 |
+
# 成交额
|
| 88 |
+
|
| 89 |
+
if WEB_CONFIGS.ENABLE_ASR:
|
| 90 |
+
Path(WEB_CONFIGS.ASR_WAV_SAVE_PATH).mkdir(parents=True, exist_ok=True)
|
| 91 |
+
|
| 92 |
+
st.subheader(f"语音输入", divider="grey")
|
| 93 |
+
audio = audiorecorder(
|
| 94 |
+
start_prompt="开始录音", stop_prompt="停止录音", pause_prompt="", show_visualizer=True, key=None
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
if len(audio) > 0:
|
| 98 |
+
|
| 99 |
+
# 将录音保存 wav 文件
|
| 100 |
+
save_tag = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".wav"
|
| 101 |
+
wav_path = str(Path(WEB_CONFIGS.ASR_WAV_SAVE_PATH).joinpath(save_tag).absolute())
|
| 102 |
+
|
| 103 |
+
# st.audio(audio.export().read()) # 前端显示
|
| 104 |
+
audio.export(wav_path, format="wav") # 使用 pydub 保存到 wav 文件
|
| 105 |
+
|
| 106 |
+
# To get audio properties, use pydub AudioSegment properties:
|
| 107 |
+
# st.write(
|
| 108 |
+
# f"Frame rate: {audio.frame_rate}, Frame width: {audio.frame_width}, Duration: {audio.duration_seconds} seconds"
|
| 109 |
+
# )
|
| 110 |
+
|
| 111 |
+
# 语音识别
|
| 112 |
+
asr_text = process_asr(ASR_HANDLER, wav_path)
|
| 113 |
+
|
| 114 |
+
# 删除过程文件
|
| 115 |
+
# Path(wav_path).unlink()
|
| 116 |
+
|
| 117 |
+
# 是否生成 TTS
|
| 118 |
+
if WEB_CONFIGS.ENABLE_TTS:
|
| 119 |
+
st.subheader("TTS 配置", divider="grey")
|
| 120 |
+
st.session_state.gen_tts_checkbox = st.toggle("生成语音", value=st.session_state.gen_tts_checkbox)
|
| 121 |
+
|
| 122 |
+
if WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
|
| 123 |
+
# 是否生成 数字人
|
| 124 |
+
st.subheader(f"数字人 配置", divider="grey")
|
| 125 |
+
st.session_state.gen_digital_human_checkbox = st.toggle(
|
| 126 |
+
"生成数字人视频", value=st.session_state.gen_digital_human_checkbox
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if WEB_CONFIGS.ENABLE_AGENT:
|
| 130 |
+
# 是否使用 agent
|
| 131 |
+
st.subheader(f"Agent 配置", divider="grey")
|
| 132 |
+
with st.container(border=True):
|
| 133 |
+
st.markdown("**插件列表**")
|
| 134 |
+
st.button("结合天气查询到货时间", type="primary")
|
| 135 |
+
st.session_state.enable_agent_checkbox = st.toggle("使用 Agent 能力", value=st.session_state.enable_agent_checkbox)
|
| 136 |
+
|
| 137 |
+
st.subheader("页面切换", divider="grey")
|
| 138 |
+
st.button("返回科室页", on_click=on_btn_click, kwargs={"info": "返回科室页"})
|
| 139 |
+
|
| 140 |
+
st.subheader("对话设置", divider="grey")
|
| 141 |
+
st.button("清除对话历史", on_click=on_btn_click, kwargs={"info": "清除对话历史"})
|
| 142 |
+
|
| 143 |
+
# 模型配置
|
| 144 |
+
# st.markdown("## 模型配置")
|
| 145 |
+
# max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
|
| 146 |
+
# top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
|
| 147 |
+
# temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
|
| 148 |
+
|
| 149 |
+
return asr_text
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def init_message_block(meta_instruction, user_avator, robot_avator):
|
| 153 |
+
|
| 154 |
+
# 在应用重新运行时显示聊天历史消息
|
| 155 |
+
for message in st.session_state.messages:
|
| 156 |
+
with st.chat_message(message["role"], avatar=message.get("avatar")):
|
| 157 |
+
st.markdown(message["content"])
|
| 158 |
+
|
| 159 |
+
if message.get("wav") is not None:
|
| 160 |
+
# 展示语音
|
| 161 |
+
print(f"Load wav {message['wav']}")
|
| 162 |
+
with open(message["wav"], "rb") as f_wav:
|
| 163 |
+
audio_bytes = f_wav.read()
|
| 164 |
+
st.audio(audio_bytes, format="audio/wav")
|
| 165 |
+
|
| 166 |
+
# 如果聊天历史为空,则显示产品介绍
|
| 167 |
+
if len(st.session_state.messages) == 0:
|
| 168 |
+
# 直接产品介绍
|
| 169 |
+
get_turbomind_response(
|
| 170 |
+
st.session_state.first_input,
|
| 171 |
+
meta_instruction,
|
| 172 |
+
user_avator,
|
| 173 |
+
robot_avator,
|
| 174 |
+
LLM_MODEL,
|
| 175 |
+
session_messages=st.session_state.messages,
|
| 176 |
+
add_session_msg=False,
|
| 177 |
+
first_input_str="",
|
| 178 |
+
enable_agent=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# 初始化按钮消息状态
|
| 182 |
+
if "button_msg" not in st.session_state:
|
| 183 |
+
st.session_state.button_msg = "x-x"
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def process_message(user_avator, prompt, meta_instruction, robot_avator):
|
| 187 |
+
# Display user message in chat message container
|
| 188 |
+
with st.chat_message("user", avatar=user_avator):
|
| 189 |
+
st.markdown(prompt)
|
| 190 |
+
|
| 191 |
+
get_turbomind_response(
|
| 192 |
+
prompt,
|
| 193 |
+
meta_instruction,
|
| 194 |
+
user_avator,
|
| 195 |
+
robot_avator,
|
| 196 |
+
LLM_MODEL,
|
| 197 |
+
session_messages=st.session_state.messages,
|
| 198 |
+
add_session_msg=True,
|
| 199 |
+
first_input_str=st.session_state.first_input,
|
| 200 |
+
rag_retriever=RAG_RETRIEVER,
|
| 201 |
+
product_name=st.session_state.product_name,
|
| 202 |
+
enable_agent=st.session_state.enable_agent_checkbox,
|
| 203 |
+
# departure_place=st.session_state.departure_place,
|
| 204 |
+
# delivery_company_name=st.session_state.delivery_company_name,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main(meta_instruction):
|
| 209 |
+
|
| 210 |
+
# 检查页面切换状态并进行切换
|
| 211 |
+
if st.session_state.page_switch != st.session_state.current_page:
|
| 212 |
+
st.switch_page(st.session_state.page_switch)
|
| 213 |
+
|
| 214 |
+
# 页面标题
|
| 215 |
+
st.title("智能医导大模型")
|
| 216 |
+
|
| 217 |
+
# 说明
|
| 218 |
+
st.info(
|
| 219 |
+
"本项目是基于人工智能的文字、语音、视频生成领域搭建的智能医导大模型。用户被授予使用此工具创建文字、语音、视频的自由,但用户在使用过程中应该遵守当地法律,并负责任地使用。开发人员不对用户可能的不当使用承担任何责任。",
|
| 220 |
+
icon="❗",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# 初始化侧边栏
|
| 224 |
+
asr_text = init_sidebar()
|
| 225 |
+
|
| 226 |
+
# 初始化聊天历史记录
|
| 227 |
+
if "messages" not in st.session_state:
|
| 228 |
+
st.session_state.messages = []
|
| 229 |
+
|
| 230 |
+
message_col = None
|
| 231 |
+
if st.session_state.gen_digital_human_checkbox and WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
|
| 232 |
+
|
| 233 |
+
with st.container():
|
| 234 |
+
message_col, video_col = st.columns([0.6, 0.4])
|
| 235 |
+
|
| 236 |
+
with video_col:
|
| 237 |
+
# 创建 empty 控件
|
| 238 |
+
st.session_state.video_placeholder = st.empty()
|
| 239 |
+
with st.session_state.video_placeholder.container():
|
| 240 |
+
show_video(st.session_state.digital_human_video_path, autoplay=True, loop=True, muted=True)
|
| 241 |
+
|
| 242 |
+
with message_col:
|
| 243 |
+
init_message_block(meta_instruction, WEB_CONFIGS.USER_AVATOR, WEB_CONFIGS.ROBOT_AVATOR)
|
| 244 |
+
else:
|
| 245 |
+
init_message_block(meta_instruction, WEB_CONFIGS.USER_AVATOR, WEB_CONFIGS.ROBOT_AVATOR)
|
| 246 |
+
|
| 247 |
+
# 输入框显示提示信息
|
| 248 |
+
hint_msg = "你好,你可以向我提出任何关于就诊的问题,我将竭诚为您服务"
|
| 249 |
+
if st.session_state.button_msg != "x-x":
|
| 250 |
+
prompt = st.session_state.button_msg
|
| 251 |
+
st.session_state.button_msg = "x-x"
|
| 252 |
+
st.chat_input(hint_msg)
|
| 253 |
+
elif asr_text != "" and st.session_state.asr_text_cache != asr_text:
|
| 254 |
+
prompt = asr_text
|
| 255 |
+
st.chat_input(hint_msg)
|
| 256 |
+
st.session_state.asr_text_cache = asr_text
|
| 257 |
+
else:
|
| 258 |
+
prompt = st.chat_input(hint_msg)
|
| 259 |
+
|
| 260 |
+
# 接收用户输入
|
| 261 |
+
if prompt:
|
| 262 |
+
|
| 263 |
+
if message_col is None:
|
| 264 |
+
process_message(WEB_CONFIGS.USER_AVATOR, prompt, meta_instruction, WEB_CONFIGS.ROBOT_AVATOR)
|
| 265 |
+
else:
|
| 266 |
+
# 数字人启动,页面会分块,放入信息块中
|
| 267 |
+
with message_col:
|
| 268 |
+
process_message(WEB_CONFIGS.USER_AVATOR, prompt, meta_instruction, WEB_CONFIGS.ROBOT_AVATOR)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# st.sidebar.page_link("app.py", label="商品页")
|
| 272 |
+
# st.sidebar.page_link("./pages/selling_page.py", label="主播卖货", disabled=True)
|
| 273 |
+
|
| 274 |
+
# META_INSTRUCTION
|
| 275 |
+
print("into sales page")
|
| 276 |
+
st.session_state.current_page = "pages/selling_page.py"
|
| 277 |
+
|
| 278 |
+
if "sales_info" not in st.session_state or st.session_state.sales_info == "":
|
| 279 |
+
st.session_state.page_switch = "app.py"
|
| 280 |
+
st.switch_page("app.py")
|
| 281 |
+
|
| 282 |
+
main((st.session_state.sales_info))
|