Spaces: Runtime error
Upload 13 files
- app.py +63 -0
- audio_text.py +69 -0
- chat.py +27 -0
- general_chat.py +107 -0
- image_generation.py +192 -0
- model_config.py +77 -0
- note.py +9 -0
- process_image.py +28 -0
- reasoning_chat.py +209 -0
- requirements.txt +1 -0
- session_state.py +25 -0
- template.py +11 -0
- visual_chat.py +160 -0
app.py
ADDED
import streamlit as st

if "login_state" not in st.session_state:
    st.session_state.login_state = False
if "api" not in st.session_state:
    st.session_state.api = ""

def main():
    if not st.session_state.login_state:
        from note import regester, notes

        st.subheader("Interact with AI models through SiliconFlow API key", anchor=False)

        st.markdown(regester, unsafe_allow_html=True)

        api_key = st.text_input("API KEY", st.session_state.api, key="api_key", type="password", placeholder="sk-...")
        st.session_state.api = api_key
        submit_btn = st.button("Submit", key="submit_btn", type="primary", disabled=not api_key)

        st.markdown("---")

        with st.container(border=True, key="note_container"):
            st.markdown(notes, unsafe_allow_html=True)

        if submit_btn and st.session_state.api:
            st.session_state.login_state = True
            st.rerun()
        elif submit_btn and not st.session_state.api:
            st.error("Please enter your SiliconFlow API key!")
    else:
        siliconflow()

def siliconflow():
    function_list = ["General Chat", "Visual Chat", "Reasoning Chat", "Image Generation", "Audio to Text"]
    function_item = st.sidebar.selectbox("Function", function_list, index=0, key="func_")

    st.subheader(function_item, anchor=False)

    if function_item == "General Chat":
        from general_chat import generalChat
        generalChat(api_key=st.session_state.api)
    elif function_item == "Visual Chat":
        from visual_chat import visualChat
        visualChat(api_key=st.session_state.api)
    elif function_item == "Reasoning Chat":
        from reasoning_chat import reasoningChat
        reasoningChat(api_key=st.session_state.api)
    elif function_item == "Image Generation":
        from image_generation import imageGeneration
        imageGeneration(api_key=st.session_state.api)
    elif function_item == "Audio to Text":
        from audio_text import audioText
        audioText(api_key=st.session_state.api)

    st.sidebar.markdown("---")

    if st.sidebar.button("Log Out", key="logout_btn"):
        st.session_state.login_state = False
        st.session_state.api = ""
        st.rerun()

if __name__ == "__main__":
    main()

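To run this Space locally, install the dependencies from requirements.txt (openai, plus streamlit, requests and pillow, which the modules import directly) and launch the app with `streamlit run app.py`; the SiliconFlow API key is then entered on the login screen.
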
audio_text.py
ADDED
from openai import OpenAI
import streamlit as st

def audio_transcription(api_key, audio_file):
    # Transcribe an audio file through SiliconFlow's OpenAI-compatible endpoint.
    base_url = "https://api.siliconflow.cn/v1"
    client = OpenAI(api_key=api_key, base_url=base_url)

    transcription = client.audio.transcriptions.create(
        model="FunAudioLLM/SenseVoiceSmall",
        file=audio_file
    )

    return transcription

def audioText(api_key: str):
    if "uploaded_audio" not in st.session_state:
        st.session_state.uploaded_audio = None
    if "input_audio" not in st.session_state:
        st.session_state.input_audio = None

    # The button is enabled once either audio source is provided.
    disable = st.session_state.uploaded_audio is None and st.session_state.input_audio is None

    # Each input disables the other, so at most one source is active at a time.
    audio_uploader = st.file_uploader("Upload an Audio", type=["MP3", "WAV"], key="audio_uploader", disabled=st.session_state.input_audio is not None)
    st.session_state.uploaded_audio = audio_uploader
    audio_input = st.audio_input("Record an Audio", key="audio_input", disabled=st.session_state.uploaded_audio is not None)
    st.session_state.input_audio = audio_input

    transcript_btn = st.button("Transcribe", "transcript_btn", type="primary", disabled=disable)

    transcription_str = ""

    if transcript_btn:
        if st.session_state.uploaded_audio is not None and st.session_state.input_audio is None:
            try:
                with st.spinner("Processing..."):
                    transcription = audio_transcription(api_key, st.session_state.uploaded_audio)
                    if transcription:
                        transcription_str = transcription.text
            except Exception as e:
                st.error(f"Error occurred: {e}")
        elif st.session_state.uploaded_audio is None and st.session_state.input_audio is not None:
            try:
                with st.spinner("Processing..."):
                    transcription = audio_transcription(api_key, st.session_state.input_audio)
                    if transcription:
                        transcription_str = transcription.text
            except Exception as e:
                st.error(f"Error occurred: {e}")
        elif st.session_state.uploaded_audio is None and st.session_state.input_audio is None:
            st.info("Please upload an audio or record an audio!")

    if transcription_str:
        with st.container(border=True, key="trans_container"):
            st.markdown(transcription_str)

if __name__ == "__main__":
    audioText("sk-...")  # replace with your own SiliconFlow API key

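Outside Streamlit, the transcription helper accepts any binary file object; a minimal sketch (the file name is illustrative):

    # "sample.mp3" is an illustrative local file.
    with open("sample.mp3", "rb") as audio:
        print(audio_transcription("sk-...", audio).text)
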
chat.py
ADDED
from openai import OpenAI

def chat_completion(
        api_key: str,
        model: str,
        messages: list,
        tokens: int,
        temp: float,
        topp: float,
        freq: float,
        pres: float,
        stop: list):
    # Stream a chat completion from SiliconFlow's OpenAI-compatible endpoint.
    base_url = "https://api.siliconflow.cn/v1"
    client = OpenAI(api_key=api_key, base_url=base_url)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=tokens,
        temperature=temp,
        top_p=topp,
        frequency_penalty=freq,
        presence_penalty=pres,
        stop=stop,
        stream=True
    )
    return response

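For reference, a minimal sketch of consuming the stream this returns outside of Streamlit (the model id and key are illustrative):

    stream = chat_completion(
        api_key="sk-...", model="Qwen/Qwen2.5-7B-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
        tokens=512, temp=0.7, topp=0.7, freq=0.0, pres=0.0, stop=None)
    for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="", flush=True)
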
general_chat.py
ADDED
import streamlit as st

from session_state import set_session_state
from chat import chat_completion
from template import general_default_prompt
from model_config import text_model

def generalChat(api_key: str):
    set_session_state("general", general_default_prompt, 4096, 0.70)

    # Sidebar settings are only editable while the conversation is empty.
    disable = st.session_state.general_msg == []

    with st.sidebar:
        clear_btn = st.button("Clear", "clear_", type="primary", use_container_width=True, disabled=disable)
        undo_btn = st.button("Undo", "undo_", use_container_width=True, disabled=disable)
        retry_btn = st.button("Retry", "retry_", use_container_width=True, disabled=disable)

        model_list = text_model(api_key)
        model = st.selectbox("Model", model_list, index=0, key="gen_model", disabled=not disable)

        system_prompt = st.text_area("System Prompt", st.session_state.general_sys, key="gen_sys", disabled=not disable)

        with st.expander("Advanced Settings"):
            tokens: int = st.slider("Max Tokens", 1, 4096, st.session_state.general_tokens, 1, key="gen_tokens", disabled=not disable)
            temp: float = st.slider("Temperature", 0.00, 2.00, st.session_state.general_temp, 0.01, key="gen_temp", disabled=not disable)
            topp: float = st.slider("Top P", 0.01, 1.00, st.session_state.general_topp, 0.01, key="gen_topp", disabled=not disable)
            freq: float = st.slider("Frequency Penalty", -2.00, 2.00, st.session_state.general_freq, 0.01, key="gen_freq", disabled=not disable)
            pres: float = st.slider("Presence Penalty", -2.00, 2.00, st.session_state.general_pres, 0.01, key="gen_pres", disabled=not disable)
            if st.toggle("Set stop", key="gen_stop_toggle", disabled=not disable):
                # Initialize the list only once so saved stop strings persist across reruns.
                if st.session_state.general_stop is None:
                    st.session_state.general_stop = []
                stop_str = st.text_input("Stop", st.session_state.general_stop_str, key="gen_stop_str", disabled=not disable)
                st.session_state.general_stop_str = stop_str
                submit_stop = st.button("Submit", "gen_submit_stop", disabled=not disable)
                if submit_stop and stop_str:
                    st.session_state.general_stop.append(st.session_state.general_stop_str)
                    st.session_state.general_stop_str = ""
                    st.rerun()
                if st.session_state.general_stop:
                    for stop_str in st.session_state.general_stop:
                        st.markdown(f"`{stop_str}`")

    st.session_state.general_sys = system_prompt
    st.session_state.general_tokens = tokens
    st.session_state.general_temp = temp
    st.session_state.general_topp = topp
    st.session_state.general_freq = freq
    st.session_state.general_pres = pres

    # Replay the cached conversation.
    for i in st.session_state.general_cache:
        with st.chat_message(i["role"]):
            st.markdown(i["content"])

    if query := st.chat_input("Say something...", key="gen_query", disabled=model==""):
        with st.chat_message("user"):
            st.markdown(query)
        st.session_state.general_msg.append({"role": "user", "content": query})
        messages = [{"role": "system", "content": system_prompt}] + st.session_state.general_msg
        with st.chat_message("assistant"):
            try:
                response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.general_stop)
                result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                st.session_state.general_msg.append({"role": "assistant", "content": result})
            except Exception as e:
                st.error(f"Error occurred: {e}")
        st.session_state.general_cache = st.session_state.general_msg
        st.rerun()

    if clear_btn:
        st.session_state.general_sys = general_default_prompt
        st.session_state.general_tokens = 4096
        st.session_state.general_temp = 0.70
        st.session_state.general_topp = 0.70
        st.session_state.general_freq = 0.00
        st.session_state.general_pres = 0.00
        st.session_state.general_stop = None
        st.session_state.general_msg = []
        st.session_state.general_cache = []
        st.rerun()

    if undo_btn:
        del st.session_state.general_msg[-1]
        del st.session_state.general_cache[-1]
        st.rerun()

    if retry_btn:
        st.session_state.general_msg.pop()
        st.session_state.general_cache = []
        st.session_state.general_retry = True
        st.rerun()
    if st.session_state.general_retry:
        for i in st.session_state.general_msg:
            with st.chat_message(i["role"]):
                st.markdown(i["content"])
        messages = [{"role": "system", "content": system_prompt}] + st.session_state.general_msg
        with st.chat_message("assistant"):
            try:
                response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.general_stop)
                result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                st.session_state.general_msg.append({"role": "assistant", "content": result})
            except Exception as e:
                st.error(f"Error occurred: {e}")
        st.session_state.general_cache = st.session_state.general_msg
        st.session_state.general_retry = False
        st.rerun()

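Note that each turn rebuilds the full request body rather than storing the system prompt in the history, so the payload handed to chat_completion has this shape (contents illustrative):

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "the newest question"},
    ]
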
image_generation.py
ADDED
import streamlit as st
import requests

from model_config import image_model_list

url = "https://api.siliconflow.cn/v1/images/generations"

flux_image_size = [
    "1024x1024",
    "960x1280",
    "768x1024",
    "720x1440",
    "720x1280",
    "others"
]

sd_image_size = [
    "1024x1024",
    "512x1024",
    "768x512",
    "768x1024",  # the original listed "768x512" twice; this is presumably the intended entry
    "1024x576",
    "576x1024"
]

def flux_image_generator(api_key: str, prompt: str, image_size: str, seed: int, step: int, prompt_enhancement: bool):
    payload = {
        "model": "black-forest-labs/FLUX.1-dev",
        "prompt": prompt,
        "image_size": image_size,
        "num_inference_steps": step,
        "prompt_enhancement": prompt_enhancement
    }
    if seed is not None:
        payload["seed"] = seed

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    response = requests.post(url, json=payload, headers=headers)
    response.raise_for_status()  # surface API errors instead of silently returning None
    return response.json()["images"][0]["url"]

def sd_image_generator(api_key: str, prompt: str, negative_prompt: str, image_size: str, seed: int, step: int, guidance_scale: float, prompt_enhancement: bool):
    payload = {
        "model": "stabilityai/stable-diffusion-3-5-large",
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image_size": image_size,
        "batch_size": 1,
        "num_inference_steps": step,
        "guidance_scale": guidance_scale,
        "prompt_enhancement": prompt_enhancement
    }
    if seed is not None:
        payload["seed"] = seed

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    response = requests.post(url, json=payload, headers=headers)
    response.raise_for_status()
    return response.json()["images"][0]["url"]

def imageGeneration(api_key: str):
    if "image_url" not in st.session_state:
        st.session_state.image_url = ""
    if "generate_state" not in st.session_state:
        st.session_state.generate_state = False

    with st.sidebar:
        reset_btn = st.button("Reset", "img_reset_btn", type="primary", use_container_width=True, disabled=st.session_state.image_url=="")
        model_list = image_model_list
        model = st.selectbox("Model", model_list, 0, key="img_model", disabled=st.session_state.image_url!="")

    if model == "black-forest-labs/FLUX.1-dev":
        with st.sidebar:
            image_size = st.selectbox("Image Size", flux_image_size, 0, key="flux_img_size")
            if image_size == "others":
                length = st.text_input("Length", "", key="length")
                width = st.text_input("Width", "", key="width")
                if length.isdigit() and width.isdigit() and not length.startswith("0") and not width.startswith("0"):
                    image_size = f"{length}x{width}"
                    st.session_state.generate_state = False
                    st.markdown(f"Custom Image Size: `{image_size}`")
                elif length and width:
                    st.session_state.generate_state = True
                    st.warning("Please input an integer!")

            step = st.slider("Inference Steps", 1, 50, 50, 1, key="flux_step")
            seed = None  # default before validation so a bad input cannot leave it unbound
            seed_input = st.text_input("Seed", "", key="flux_seed")
            if seed_input:
                try:
                    seed = int(seed_input)
                    st.session_state.generate_state = False
                except ValueError:
                    st.session_state.generate_state = True
                    st.warning("Please input an integer!")
            else:
                st.session_state.generate_state = False
            prompt_enhancement = st.toggle("Prompt Enhancement", False, key="flux_enhancement")

        prompt = st.text_area("Prompt", "", key="flux_prompt", disabled=st.session_state.generate_state)
        generate_btn = st.button("Generate", "flux_generate", type="primary", disabled=prompt=="")

        if generate_btn:
            try:
                with st.spinner("Generating..."):
                    st.session_state.image_url = flux_image_generator(api_key, prompt, image_size, seed, step, prompt_enhancement)
            except Exception as e:
                st.error(f"Error occurred: {e}")
            else:
                # Rerun outside the try block so Streamlit's control-flow exception is not swallowed.
                st.rerun()

        if st.session_state.image_url != "":
            st.image(st.session_state.image_url, output_format="PNG")

    elif model == "stabilityai/stable-diffusion-3-5-large":
        with st.sidebar:
            image_size = st.selectbox("Image Size", sd_image_size, 0, key="sd_img_size")
            step = st.slider("Inference Steps", 1, 50, 50, 1, key="sd_step")
            guidance_scale = st.slider("Guidance Scale", 0.0, 20.0, 4.5, 0.1, key="sd_guidance")
            seed = None
            seed_input = st.text_input("Seed", "", key="sd_seed")
            if seed_input:
                try:
                    seed = int(seed_input)
                    st.session_state.generate_state = False
                except ValueError:
                    st.session_state.generate_state = True
                    st.warning("Please input an integer!")
            else:
                st.session_state.generate_state = False
            prompt_enhancement = st.toggle("Prompt Enhancement", False, key="sd_enhancement")

        prompt = st.text_area("Prompt", "", key="sd_prompt", disabled=st.session_state.generate_state)
        negative_prompt = st.text_area("Negative Prompt", "", key="negative_prompt", disabled=st.session_state.generate_state)
        generate_btn = st.button("Generate", "sd_generate", type="primary", disabled=prompt=="")

        if generate_btn:
            try:
                with st.spinner("Generating..."):
                    st.session_state.image_url = sd_image_generator(api_key, prompt, negative_prompt, image_size, seed, step, guidance_scale, prompt_enhancement)
            except Exception as e:
                st.error(f"Error occurred: {e}")
            else:
                st.rerun()

        if st.session_state.image_url != "":
            st.image(st.session_state.image_url, output_format="PNG")

    if reset_btn:
        st.session_state.image_url = ""
        st.rerun()

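The generators can also be called directly; a standalone sketch (key and prompt illustrative) that returns the URL of the generated image:

    img_url = flux_image_generator(
        api_key="sk-...",
        prompt="a watercolor lighthouse at dusk",
        image_size="1024x1024",
        seed=None,
        step=50,
        prompt_enhancement=False,
    )
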
model_config.py
ADDED
import requests
import re

url = "https://api.siliconflow.cn/v1/models"

visual_model_list = [
    "Qwen/Qwen2-VL-72B-Instruct",
    "OpenGVLab/InternVL2-26B",
    "TeleAI/TeleMM",
    "Pro/Qwen/Qwen2-VL-7B-Instruct",
    "Pro/OpenGVLab/InternVL2-8B"
]

reasoning_model_list = [
    "Qwen/QwQ-32B-Preview",
    "Qwen/QVQ-72B-Preview",
    "AIDC-AI/Marco-o1"
]

excluded_models = [
    "deepseek-ai/deepseek-vl2",
    "01-ai/Yi-1.5-6B-Chat"
]

image_model_list = [
    "black-forest-labs/FLUX.1-dev",
    "stabilityai/stable-diffusion-3-5-large"
]

qwen_pattern = re.compile(r'^Qwen/')
meta_llama_pattern = re.compile(r'^meta-llama/')
deepseek_ai_pattern = re.compile(r'^deepseek-ai/')
pro_lora_pattern = re.compile(r'^(Pro|LoRA)/')

def extract_version_and_params(model):
    # Pull a sortable (version, parameter-count) pair out of a model id.
    version_match = re.search(r'(\d+(\.\d+)+)', model)
    if version_match:
        # Keep only "major.minor" so ids like "2.5.1" still parse as a float.
        parts = version_match.group(1).split(".")
        version = float(".".join(parts[:2]))
    else:
        version = 0.0

    params_match = re.search(r'(\d+(\.\d+)?)(B|b)', model)
    params = float(params_match.group(1)) if params_match else 0.0

    return version, params

def sort_models(model_list):
    return sorted(model_list, key=lambda x: extract_version_and_params(x), reverse=True)

def text_model(api_key: str) -> list:
    model_list = []

    querystring = {"type": "text", "sub_type": "chat"}
    headers = {"Authorization": f"Bearer {api_key}"}
    response = requests.get(url, params=querystring, headers=headers)

    if response.status_code == 200:
        response_object = response.json()
        response_data = response_object["data"]
        for i in response_data:
            if i["id"] not in visual_model_list and i["id"] not in reasoning_model_list and i["id"] not in excluded_models:
                model_list.append(i["id"])

    # Group by vendor prefix, sort each group by version and size, then concatenate.
    qwen_models = [model for model in model_list if qwen_pattern.search(model) and not pro_lora_pattern.search(model)]
    meta_llama_models = [model for model in model_list if meta_llama_pattern.search(model) and not pro_lora_pattern.search(model)]
    deepseek_ai_models = [model for model in model_list if deepseek_ai_pattern.search(model) and not pro_lora_pattern.search(model)]
    other_models = [model for model in model_list if not qwen_pattern.search(model) and not meta_llama_pattern.search(model) and not deepseek_ai_pattern.search(model) and not pro_lora_pattern.search(model)]
    pro_lora_models = [model for model in model_list if pro_lora_pattern.search(model)]

    return (sort_models(qwen_models) + sort_models(meta_llama_models)
            + sort_models(deepseek_ai_models) + sort_models(other_models)
            + sort_models(pro_lora_models))

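For reference, the sort key behaves like this (model ids illustrative):

    extract_version_and_params("Qwen/Qwen2.5-72B-Instruct")          # -> (2.5, 72.0)
    extract_version_and_params("meta-llama/Llama-3.3-70B-Instruct")  # -> (3.3, 70.0)
    extract_version_and_params("AIDC-AI/Marco-o1")                   # -> (0.0, 0.0): no dotted version or size in the id
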
note.py
ADDED
regester = '<span style="color: grey;">No API key yet? Register here: </span><a href="https://cloud.siliconflow.cn/i/b7XJpKVo" style="color: #4682B4;" target="_blank">SiliconCloud</a>'

notes = """<strong style="color: red;">Please note:</strong>

1. This site does not retain chat records. Please copy and save any important chat data before refreshing or leaving the site.
2. We recommend not sharing personal sensitive information, such as phone numbers, emails, or home addresses, during conversations.
3. This site only supports open-source models officially deployed by SiliconFlow.
4. This site is a Playground Demo and is not affiliated with SiliconFlow.
"""

process_image.py
ADDED
from PIL import Image
import io
import base64

def image_processor(uploaded_image):
    img = Image.open(uploaded_image)

    # Flatten transparency onto a white background.
    if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
        background = Image.new('RGB', img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[-1])
        img = background

    # Downscale so the longest side is at most 1024 px.
    max_size = 1024
    if max(img.size) > max_size:
        ratio = max_size / max(img.size)
        new_size = tuple(int(dim * ratio) for dim in img.size)
        img = img.resize(new_size, Image.Resampling.LANCZOS)

    # Encode as JPEG (matching the data:image/jpeg URLs built by the chat
    # modules; PIL ignores `quality` for PNG) and lower the quality until the
    # payload fits under ~1 MB, with a floor so the loop always terminates.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    quality = 95
    output_buffer = io.BytesIO()
    img.save(output_buffer, format='JPEG', quality=quality)
    while output_buffer.tell() > 1024 * 1024 and quality > 10:
        quality = max(10, int(quality * (1024 * 1024 / output_buffer.tell())))
        output_buffer = io.BytesIO()
        img.save(output_buffer, format='JPEG', quality=quality)

    base64_encoded = base64.b64encode(output_buffer.getvalue()).decode('utf-8')

    return base64_encoded

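A minimal sketch of how this helper's output is wired into a vision message, mirroring visual_chat.py (the file name is illustrative):

    with open("photo.png", "rb") as f:
        b64 = image_processor(f)
    content = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}", "detail": "high"}},
        {"type": "text", "text": "Describe this image."},
    ]
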
reasoning_chat.py
ADDED
import streamlit as st

from session_state import set_session_state
from chat import chat_completion
from template import qwen_reasoning_prompt, marco_reasoning_prompt
from model_config import reasoning_model_list

def reasoningChat(api_key: str):
    set_session_state("reasoning", "", 8192, 0.50)

    disable = st.session_state.reasoning_msg == []

    with st.sidebar:
        clear_btn = st.button("Clear", "re_clear", type="primary", use_container_width=True, disabled=disable)
        undo_btn = st.button("Undo", "re_undo", use_container_width=True, disabled=disable)
        retry_btn = st.button("Retry", "re_retry", use_container_width=True, disabled=disable)

        model_list = reasoning_model_list
        model = st.selectbox("Model", model_list, 0, key="reason_model", disabled=not disable)
        st.session_state.reasoning_model = model

        # Each reasoning model ships with its own fixed system prompt.
        if model == "AIDC-AI/Marco-o1":
            st.session_state.reasoning_sys = marco_reasoning_prompt
        else:
            st.session_state.reasoning_sys = qwen_reasoning_prompt

        with st.expander("Advanced Settings"):
            tokens = st.slider("Max Tokens", 1, 8192, st.session_state.reasoning_tokens, 1, key="re_tokens", disabled=not disable)
            temp = st.slider("Temperature", 0.00, 2.00, st.session_state.reasoning_temp, 0.01, key="re_temp", disabled=not disable)
            topp = st.slider("Top P", 0.01, 1.00, st.session_state.reasoning_topp, 0.01, key="re_topp", disabled=not disable)
            freq = st.slider("Frequency Penalty", -2.00, 2.00, st.session_state.reasoning_freq, 0.01, key="re_freq", disabled=not disable)
            pres = st.slider("Presence Penalty", -2.00, 2.00, st.session_state.reasoning_pres, 0.01, key="re_pres", disabled=not disable)
            if st.toggle("Set stop", key="re_stop_toggle", disabled=not disable):
                if st.session_state.reasoning_stop is None:
                    st.session_state.reasoning_stop = []
                stop_str = st.text_input("Stop", st.session_state.reasoning_stop_str, key="re_stop_str", disabled=not disable)
                st.session_state.reasoning_stop_str = stop_str
                submit_stop = st.button("Submit", "re_submit_stop", disabled=not disable)
                if submit_stop and stop_str:
                    st.session_state.reasoning_stop.append(st.session_state.reasoning_stop_str)
                    st.session_state.reasoning_stop_str = ""
                    st.rerun()
                if st.session_state.reasoning_stop:
                    for stop_str in st.session_state.reasoning_stop:
                        st.markdown(f"`{stop_str}`")

    st.session_state.reasoning_tokens = tokens
    st.session_state.reasoning_temp = temp
    st.session_state.reasoning_topp = topp
    st.session_state.reasoning_freq = freq
    st.session_state.reasoning_pres = pres

    if st.session_state.reasoning_model == "Qwen/QVQ-72B-Preview":
        # QVQ is a visual reasoning model, so an image uploader is shown.
        from process_image import image_processor
        image_type = ["PNG", "JPG", "JPEG"]
        uploaded_image: list = st.file_uploader("Upload an image", type=image_type, accept_multiple_files=True, key="re_uploaded_image")
        base64_image_list = []
        if uploaded_image is not None:
            with st.expander("Image"):
                for i in uploaded_image:
                    st.image(i, output_format="PNG")
                    base64_image_list.append(image_processor(i))

        for i in st.session_state.reasoning_cache:
            with st.chat_message(i["role"]):
                st.markdown(i["content"])

        if query := st.chat_input("Say something...", key="re_qvq_query", disabled=base64_image_list==[]):
            with st.chat_message("user"):
                st.markdown(query)

            st.session_state.reasoning_msg.append({"role": "user", "content": query})

            # Attach the uploaded images to the first user turn, then append the rest of the history.
            messages = [
                {"role": "system", "content": st.session_state.reasoning_sys},
                {"role": "user", "content": []}
            ]
            for base64_img in base64_image_list:
                messages[1]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_img}", "detail": "high"}})
            messages[1]["content"].append({"type": "text", "text": st.session_state.reasoning_msg[0]["content"]})
            messages += st.session_state.reasoning_msg[1:]

            with st.chat_message("assistant"):
                try:
                    response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.reasoning_stop)
                    result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                    st.session_state.reasoning_msg.append({"role": "assistant", "content": result})
                except Exception as e:
                    st.error(f"Error occurred: {e}")

            st.session_state.reasoning_cache = st.session_state.reasoning_msg
            st.rerun()

        if retry_btn:
            st.session_state.reasoning_msg.pop()
            st.session_state.reasoning_cache = []
            st.session_state.reasoning_retry = True
            st.rerun()
        if st.session_state.reasoning_retry:
            for i in st.session_state.reasoning_msg:
                with st.chat_message(i["role"]):
                    st.markdown(i["content"])
            messages = [
                {"role": "system", "content": st.session_state.reasoning_sys},
                {"role": "user", "content": []}
            ]
            for base64_img in base64_image_list:
                messages[1]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_img}", "detail": "high"}})
            messages[1]["content"].append({"type": "text", "text": st.session_state.reasoning_msg[0]["content"]})
            messages += st.session_state.reasoning_msg[1:]

            with st.chat_message("assistant"):
                try:
                    response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.reasoning_stop)
                    result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                    st.session_state.reasoning_msg.append({"role": "assistant", "content": result})
                except Exception as e:
                    st.error(f"Error occurred: {e}")

            st.session_state.reasoning_cache = st.session_state.reasoning_msg
            st.session_state.reasoning_retry = False
            st.rerun()
    else:
        for i in st.session_state.reasoning_cache:
            with st.chat_message(i["role"]):
                st.markdown(i["content"])

        if query := st.chat_input("Say something...", key="re_query", disabled=model==""):
            with st.chat_message("user"):
                st.markdown(query)

            st.session_state.reasoning_msg.append({"role": "user", "content": query})

            messages = [{"role": "system", "content": st.session_state.reasoning_sys}] + st.session_state.reasoning_msg

            with st.chat_message("assistant"):
                try:
                    response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.reasoning_stop)
                    result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                    st.session_state.reasoning_msg.append({"role": "assistant", "content": result})
                except Exception as e:
                    st.error(f"Error occurred: {e}")

            st.session_state.reasoning_cache = st.session_state.reasoning_msg
            st.rerun()

        if retry_btn:
            st.session_state.reasoning_msg.pop()
            st.session_state.reasoning_cache = []
            st.session_state.reasoning_retry = True
            st.rerun()
        if st.session_state.reasoning_retry:
            for i in st.session_state.reasoning_msg:
                with st.chat_message(i["role"]):
                    st.markdown(i["content"])

            messages = [{"role": "system", "content": st.session_state.reasoning_sys}] + st.session_state.reasoning_msg

            with st.chat_message("assistant"):
                try:
                    response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.reasoning_stop)
                    result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                    st.session_state.reasoning_msg.append({"role": "assistant", "content": result})
                except Exception as e:
                    st.error(f"Error occurred: {e}")

            st.session_state.reasoning_cache = st.session_state.reasoning_msg
            st.session_state.reasoning_retry = False
            st.rerun()

    if clear_btn:
        st.session_state.reasoning_tokens = 8192
        st.session_state.reasoning_temp = 0.50
        st.session_state.reasoning_topp = 0.70
        st.session_state.reasoning_freq = 0.00
        st.session_state.reasoning_pres = 0.00
        st.session_state.reasoning_msg = []
        st.session_state.reasoning_cache = []
        st.session_state.reasoning_stop = None
        st.rerun()

    if undo_btn:
        del st.session_state.reasoning_msg[-1]
        del st.session_state.reasoning_cache[-1]
        st.rerun()

requirements.txt
ADDED
openai
streamlit
requests
pillow

session_state.py
ADDED
import streamlit as st

def set_session_state(mode: str, sys: str, tokens: int, temp: float):
    # Initialize the per-mode session keys once, with the given defaults.
    if f"{mode}_sys" not in st.session_state:
        st.session_state[f"{mode}_sys"] = sys
    if f"{mode}_msg" not in st.session_state:
        st.session_state[f"{mode}_msg"] = []
    if f"{mode}_cache" not in st.session_state:
        st.session_state[f"{mode}_cache"] = []
    if f"{mode}_tokens" not in st.session_state:
        st.session_state[f"{mode}_tokens"] = tokens
    if f"{mode}_temp" not in st.session_state:
        st.session_state[f"{mode}_temp"] = temp
    if f"{mode}_topp" not in st.session_state:
        st.session_state[f"{mode}_topp"] = 0.70
    if f"{mode}_freq" not in st.session_state:
        st.session_state[f"{mode}_freq"] = 0.00
    if f"{mode}_pres" not in st.session_state:
        st.session_state[f"{mode}_pres"] = 0.00
    if f"{mode}_stop" not in st.session_state:
        st.session_state[f"{mode}_stop"] = None
    if f"{mode}_stop_str" not in st.session_state:
        st.session_state[f"{mode}_stop_str"] = ""
    if f"{mode}_retry" not in st.session_state:
        st.session_state[f"{mode}_retry"] = False

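For instance, the call made in general_chat.py creates the following keys on first run:

    set_session_state("general", "You are a helpful assistant.", 4096, 0.70)
    # creates: general_sys, general_msg, general_cache, general_tokens, general_temp,
    #          general_topp, general_freq, general_pres, general_stop,
    #          general_stop_str, general_retry
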
template.py
ADDED
general_default_prompt = "You are a helpful assistant."

visual_default_prompt = "Answer questions or perform tasks based on the image uploaded by the user."

qwen_reasoning_prompt = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."

# Marco-o1's system prompt is left in Chinese, presumably because the model
# expects it. Roughly translated: "You are a well-trained AI assistant named
# Marco-o1, created by AI Business of Alibaba International Digital Commerce.
# Important: do your thinking inside <Thought> and output your result inside
# <Output>. <Thought> should be in English as far as possible, with two
# exceptions: quotations from the original text, and mathematics, which
# should use Markdown. The output inside <Output> must follow the language
# of the user's input."
marco_reasoning_prompt = """你是一个经过良好训练的 AI 助手,你的名字是 Marco-o1,由阿里国际数字商业集团的 AI Business 创造。

## 重要!!!!!
当你回答问题时,你的思考应该在 <Thought> 内完成,<Output> 内输出你的结果。<Thought> 应该尽可能是英文,但是有 2 个特例,一个是对原文中的引用,另一个是数学应该使用 markdown 格式,<Output> 内的输出需要遵循用户输入的语言。
"""

visual_chat.py
ADDED
import streamlit as st

from session_state import set_session_state
from chat import chat_completion
from template import visual_default_prompt
from model_config import visual_model_list

def visualChat(api_key: str):
    set_session_state("visual", visual_default_prompt, 4096, 0.50)

    disable = st.session_state.visual_msg == []

    with st.sidebar:
        clear_btn = st.button("Clear", "vi_clear", type="primary", use_container_width=True, disabled=disable)
        undo_btn = st.button("Undo", "vi_undo", use_container_width=True, disabled=disable)
        retry_btn = st.button("Retry", "vi_retry", use_container_width=True, disabled=disable)

        model_list = visual_model_list
        model = st.selectbox("Model", model_list, 0, key="vi_model", disabled=not disable)

        system_prompt = st.text_area("System Prompt", st.session_state.visual_sys, key="vi_sys", disabled=not disable)

        with st.expander("Advanced Settings"):
            tokens = st.slider("Max Tokens", 1, 4096, st.session_state.visual_tokens, 1, key="vi_tokens", disabled=not disable)
            temp = st.slider("Temperature", 0.00, 2.00, st.session_state.visual_temp, 0.01, key="vi_temp", disabled=not disable)
            topp = st.slider("Top P", 0.01, 1.00, st.session_state.visual_topp, 0.01, key="vi_topp", disabled=not disable)
            freq = st.slider("Frequency Penalty", -2.00, 2.00, st.session_state.visual_freq, 0.01, key="vi_freq", disabled=not disable)
            pres = st.slider("Presence Penalty", -2.00, 2.00, st.session_state.visual_pres, 0.01, key="vi_pres", disabled=not disable)
            if st.toggle("Set stop", key="vi_stop_toggle", disabled=not disable):
                if st.session_state.visual_stop is None:
                    st.session_state.visual_stop = []
                stop_str = st.text_input("Stop", st.session_state.visual_stop_str, key="vi_stop_str", disabled=not disable)
                st.session_state.visual_stop_str = stop_str
                submit_stop = st.button("Submit", "vi_submit_stop", disabled=not disable)
                if submit_stop and stop_str:
                    st.session_state.visual_stop.append(st.session_state.visual_stop_str)
                    st.session_state.visual_stop_str = ""
                    st.rerun()
                if st.session_state.visual_stop:
                    for stop_str in st.session_state.visual_stop:
                        st.markdown(f"`{stop_str}`")

    st.session_state.visual_sys = system_prompt
    st.session_state.visual_tokens = tokens
    st.session_state.visual_temp = temp
    st.session_state.visual_topp = topp
    st.session_state.visual_freq = freq
    st.session_state.visual_pres = pres

    image_type = ["PNG", "JPG", "JPEG"]
    uploaded_image: list = st.file_uploader("Upload an image", type=image_type, accept_multiple_files=True, key="uploaded_image", disabled=not disable)
    base64_image_list = []
    if uploaded_image is not None:
        from process_image import image_processor
        with st.expander("Image"):
            for i in uploaded_image:
                st.image(i, output_format="PNG")
                base64_image_list.append(image_processor(i))

    for i in st.session_state.visual_cache:
        with st.chat_message(i["role"]):
            st.markdown(i["content"])

    if query := st.chat_input("Say something...", key="vi_query", disabled=base64_image_list==[]):
        with st.chat_message("user"):
            st.markdown(query)

        st.session_state.visual_msg.append({"role": "user", "content": query})

        # Attach the uploaded images to the first user turn, then append the rest of the history.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": []}
        ]
        for base64_img in base64_image_list:
            messages[1]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_img}", "detail": "high"}})
        messages[1]["content"].append({"type": "text", "text": st.session_state.visual_msg[0]["content"]})
        messages += st.session_state.visual_msg[1:]

        with st.chat_message("assistant"):
            try:
                response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.visual_stop)
                result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                st.session_state.visual_msg.append({"role": "assistant", "content": result})
            except Exception as e:
                st.error(f"Error occurred: {e}")

        st.session_state.visual_cache = st.session_state.visual_msg
        st.rerun()

    if clear_btn:
        st.session_state.visual_sys = visual_default_prompt
        st.session_state.visual_tokens = 4096
        st.session_state.visual_temp = 0.50
        st.session_state.visual_topp = 0.70
        st.session_state.visual_freq = 0.00
        st.session_state.visual_pres = 0.00
        st.session_state.visual_msg = []
        st.session_state.visual_cache = []
        st.session_state.visual_stop = None
        st.rerun()

    if undo_btn:
        del st.session_state.visual_msg[-1]
        del st.session_state.visual_cache[-1]
        st.rerun()

    if retry_btn:
        st.session_state.visual_msg.pop()
        st.session_state.visual_cache = []
        st.session_state.visual_retry = True
        st.rerun()
    if st.session_state.visual_retry:
        for i in st.session_state.visual_msg:
            with st.chat_message(i["role"]):
                st.markdown(i["content"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": []}
        ]
        for base64_img in base64_image_list:
            messages[1]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_img}", "detail": "high"}})
        messages[1]["content"].append({"type": "text", "text": st.session_state.visual_msg[0]["content"]})
        messages += st.session_state.visual_msg[1:]

        with st.chat_message("assistant"):
            try:
                response = chat_completion(api_key, model, messages, tokens, temp, topp, freq, pres, st.session_state.visual_stop)
                result = st.write_stream(chunk.choices[0].delta.content for chunk in response if chunk.choices[0].delta.content is not None)
                st.session_state.visual_msg.append({"role": "assistant", "content": result})
            except Exception as e:
                st.error(f"Error occurred: {e}")

        st.session_state.visual_cache = st.session_state.visual_msg
        st.session_state.visual_retry = False
        st.rerun()