The full Streamlit app behind the Space loads the fine-tuned checkpoint, keeps the chat history in `st.session_state`, and renders each extraction as JSON:

```python
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

st.header("🤗 Instruction Tuned SmolLM 360M")

# Load the fine-tuned checkpoint once at startup.
model_path = "Sharathhebbar24/smollm_sft_360M_instruct_tuned_v2"
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Persist the chat history across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the history: assistant turns are structured JSON, user turns plain text.
for message in st.session_state.messages:
    if message["role"] != "system":
        with st.chat_message(message["role"]):
            if message["role"] == "assistant":
                st.json(message["content"])
            else:
                st.markdown(message["content"])

if user_input := st.chat_input("Your answer.", max_chars=1000):
    st.session_state.messages.append({
        "role": "user",
        "content": user_input
    })
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        # Alpaca-style prompt matching the fine-tuning format.
        prompt = (
            "### Instruction:\n"
            "Extract action, date, time, attendees, location, duration, "
            "recurrence, and notes from the dataset.\n\n"
            f"### Input: \n{user_input}\n\n"
            "### Response:"
        )
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=100,
                do_sample=False,  # greedy decoding for deterministic extraction
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
        decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the completion, truncated at the first closing brace.
        generated_response = decoded_output.split("### Response:")[-1].strip()
        generated_response = generated_response[:generated_response.find("}") + 1]
        # Patch Python-style literals into valid JSON.
        generated_response = generated_response.replace("None", "null")
        generated_response = generated_response.replace("'", '"')

        st.json(generated_response)
        st.session_state.messages.append({
            "role": "assistant",
            "content": generated_response
        })
```
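The string replacements at the end assume the model emits a flat Python-style dict with no apostrophes or the substring `None` inside field values; an answer like `{'notes': "Bob's office"}` would break them. A more defensive variant (a sketch, not part of the original app; `parse_model_response` is a hypothetical helper) tries strict JSON first and falls back to `ast.literal_eval`, which accepts single quotes and `None` natively:

```python
import ast
import json

def parse_model_response(raw: str) -> dict:
    """Best-effort parse of the text after '### Response:'.

    Tries strict JSON first, then Python literal syntax (single quotes,
    None), and raises ValueError if neither parse succeeds.
    """
    # Keep only the first {...} span, mirroring the truncation above.
    start, end = raw.find("{"), raw.find("}") + 1
    if start == -1 or end == 0:
        raise ValueError(f"No object literal found in: {raw!r}")
    snippet = raw[start:end]
    try:
        return json.loads(snippet)
    except json.JSONDecodeError:
        pass
    try:
        # Handles Python-style output, e.g. {'action': 'meeting', 'notes': None}
        return ast.literal_eval(snippet)
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"Unparseable model output: {snippet!r}") from exc
```

Since `st.json` also accepts a dict, the parsed object could be passed to it directly instead of the patched string, and a parse failure could be surfaced with `st.error` rather than rendering malformed JSON.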