phxdev commited on
Commit
30370bb
Β·
verified Β·
1 Parent(s): 659646c
Files changed (1) hide show
  1. src/streamlit_app.py +321 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,323 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from peft import PeftModel
4
+ import torch
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import io
7
+ import base64
8
+ import textwrap
9
+ import urllib.parse
10
 
11
+ @st.cache_resource
12
+ def load_model():
13
+ """Load and cache the model"""
14
+ base_model = "mistralai/Mistral-7B-Instruct-v0.2"
15
+ adapter_model = "phxdev/corporate-synergy-bot-7b"
16
+
17
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ base_model,
20
+ torch_dtype=torch.float16,
21
+ device_map="auto"
22
+ )
23
+ model = PeftModel.from_pretrained(model, adapter_model)
24
+ return tokenizer, model
25
+
26
+ tokenizer, model = load_model()
27
+
28
+ def transform_text(text, mode="To Corporate", domain="general", seniority="mid"):
29
+ """Transform text between casual and corporate speak"""
30
+
31
+ if mode == "To Corporate":
32
+ instruction = f"Transform to {domain} corporate speak (seniority: {seniority})"
33
+ else:
34
+ instruction = "Translate corporate speak to plain English"
35
+
36
+ prompt = f"""### Instruction: {instruction}
37
+ ### Input: {text}
38
+ ### Response:"""
39
+
40
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
41
+
42
+ with torch.no_grad():
43
+ outputs = model.generate(
44
+ **inputs,
45
+ max_new_tokens=150,
46
+ temperature=0.7,
47
+ top_p=0.9,
48
+ do_sample=True,
49
+ pad_token_id=tokenizer.eos_token_id
50
+ )
51
+
52
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
+ return response.split("### Response:")[-1].strip()
54
+
55
+ def create_shareable_card(input_text, output_text, mode, domain, seniority):
56
+ """Create a shareable image card with the transformation"""
57
+ # Card dimensions
58
+ width, height = 800, 600
59
+
60
+ # Colors
61
+ bg_color = "#f8f9fa"
62
+ primary_color = "#2c3e50"
63
+ secondary_color = "#3498db"
64
+ text_color = "#2c3e50"
65
+
66
+ # Create image
67
+ img = Image.new('RGB', (width, height), bg_color)
68
+ draw = ImageDraw.Draw(img)
69
+
70
+ try:
71
+ # Try to load a nicer font
72
+ title_font = ImageFont.truetype("arial.ttf", 28)
73
+ subtitle_font = ImageFont.truetype("arial.ttf", 18)
74
+ text_font = ImageFont.truetype("arial.ttf", 16)
75
+ small_font = ImageFont.truetype("arial.ttf", 14)
76
+ except:
77
+ # Fallback to default font
78
+ title_font = ImageFont.load_default()
79
+ subtitle_font = ImageFont.load_default()
80
+ text_font = ImageFont.load_default()
81
+ small_font = ImageFont.load_default()
82
+
83
+ # Title
84
+ title = "🏒 Corporate Synergy Bot"
85
+ title_bbox = draw.textbbox((0, 0), title, font=title_font)
86
+ title_width = title_bbox[2] - title_bbox[0]
87
+ draw.text(((width - title_width) // 2, 30), title, fill=primary_color, font=title_font)
88
+
89
+ # Mode indicator
90
+ mode_text = f"{mode} β€’ {domain.title()} β€’ {seniority.title()}"
91
+ mode_bbox = draw.textbbox((0, 0), mode_text, font=subtitle_font)
92
+ mode_width = mode_bbox[2] - mode_bbox[0]
93
+ draw.text(((width - mode_width) // 2, 80), mode_text, fill=secondary_color, font=subtitle_font)
94
+
95
+ # Input section
96
+ draw.text((50, 140), "INPUT:", fill=primary_color, font=subtitle_font)
97
+
98
+ # Wrap input text
99
+ input_lines = textwrap.wrap(input_text, width=70)
100
+ y_pos = 170
101
+ for line in input_lines[:4]: # Max 4 lines
102
+ draw.text((50, y_pos), line, fill=text_color, font=text_font)
103
+ y_pos += 25
104
+
105
+ # Arrow
106
+ arrow = "⬇️"
107
+ arrow_bbox = draw.textbbox((0, 0), arrow, font=title_font)
108
+ arrow_width = arrow_bbox[2] - arrow_bbox[0]
109
+ draw.text(((width - arrow_width) // 2, y_pos + 20), arrow, fill=secondary_color, font=title_font)
110
+
111
+ # Output section
112
+ y_pos += 80
113
+ draw.text((50, y_pos), "OUTPUT:", fill=primary_color, font=subtitle_font)
114
+
115
+ # Wrap output text
116
+ output_lines = textwrap.wrap(output_text, width=70)
117
+ y_pos += 30
118
+ for line in output_lines[:4]: # Max 4 lines
119
+ draw.text((50, y_pos), line, fill=text_color, font=text_font)
120
+ y_pos += 25
121
+
122
+ # Footer
123
+ footer_text = "Generated with Corporate Synergy Bot 7B"
124
+ footer_bbox = draw.textbbox((0, 0), footer_text, font=small_font)
125
+ footer_width = footer_bbox[2] - footer_bbox[0]
126
+ draw.text(((width - footer_width) // 2, height - 40), footer_text, fill=secondary_color, font=small_font)
127
+
128
+ return img
129
+
130
+ def generate_linkedin_post_text(input_text, output_text, mode, domain):
131
+ """Generate LinkedIn post text"""
132
+ if mode == "To Corporate":
133
+ post = f"""🏒 Transform your communication with AI!
134
+
135
+ From casual: "{input_text}"
136
+ To professional: "{output_text}"
137
+
138
+ #{domain}Career #ProfessionalCommunication #CorporateSpeak #AI #WorkplaceCommunication #ProfessionalDevelopment
139
+
140
+ Try it yourself with Corporate Synergy Bot 7B! πŸš€"""
141
+ else:
142
+ post = f"""πŸ’¬ Decode corporate jargon with AI!
143
+
144
+ From corporate: "{input_text}"
145
+ To plain English: "{output_text}"
146
+
147
+ #ClearCommunication #CorporateJargon #WorkplaceCommunication #AI #ProfessionalDevelopment #PlainEnglish
148
+
149
+ Cut through the corporate speak with Corporate Synergy Bot 7B! ✨"""
150
+
151
+ return post
152
+
153
+ # Streamlit interface
154
+ st.set_page_config(
155
+ page_title="Corporate Synergy Bot 7B",
156
+ page_icon="🏒",
157
+ layout="wide"
158
+ )
159
+
160
+ st.title("🏒 Corporate Synergy Bot 7B")
161
+ st.markdown("""
162
+ Transform casual language into professional corporate communication or decode corporate jargon back to plain English.
163
+
164
+ Powered by fine-tuned Mistral-7B with LoRA.
165
+ """)
166
+
167
+ col1, col2 = st.columns(2)
168
+
169
+ with col1:
170
+ st.subheader("Input")
171
+ input_text = st.text_area(
172
+ "Input Text",
173
+ placeholder="Enter text to transform...",
174
+ height=100,
175
+ key="input_text"
176
+ )
177
+
178
+ mode = st.radio(
179
+ "Transformation Mode",
180
+ ["To Corporate", "To Plain English"],
181
+ index=0,
182
+ key="mode"
183
+ )
184
+
185
+ col1a, col1b = st.columns(2)
186
+ with col1a:
187
+ domain = st.selectbox(
188
+ "Domain (for corporate mode)",
189
+ ["general", "tech", "finance", "consulting", "healthcare", "retail"],
190
+ index=0,
191
+ key="domain"
192
+ )
193
+
194
+ with col1b:
195
+ seniority = st.selectbox(
196
+ "Seniority Level",
197
+ ["junior", "mid", "senior", "executive"],
198
+ index=1,
199
+ key="seniority"
200
+ )
201
+
202
+ transform_btn = st.button("Transform", type="primary", use_container_width=True)
203
+
204
+ with col2:
205
+ st.subheader("Output")
206
+
207
+ if transform_btn and input_text:
208
+ with st.spinner("Transforming..."):
209
+ output_text = transform_text(input_text, mode, domain, seniority)
210
+ st.text_area(
211
+ "Transformed Text",
212
+ value=output_text,
213
+ height=100,
214
+ key="output_text"
215
+ )
216
+
217
+ # Store the output for card generation
218
+ st.session_state.last_output = output_text
219
+ st.session_state.last_input = input_text
220
+ st.session_state.last_mode = mode
221
+ st.session_state.last_domain = domain
222
+ st.session_state.last_seniority = seniority
223
+ elif "last_output" in st.session_state:
224
+ st.text_area(
225
+ "Transformed Text",
226
+ value=st.session_state.last_output,
227
+ height=100,
228
+ key="output_display"
229
+ )
230
+
231
+ # Social sharing section
232
+ if "last_output" in st.session_state and "last_input" in st.session_state:
233
+ st.subheader("πŸ“± Share Your Transformation")
234
+
235
+ col_share1, col_share2 = st.columns(2)
236
+
237
+ with col_share1:
238
+ if st.button("🎨 Generate Shareable Card", use_container_width=True):
239
+ with st.spinner("Creating your card..."):
240
+ card_img = create_shareable_card(
241
+ st.session_state.last_input,
242
+ st.session_state.last_output,
243
+ st.session_state.last_mode,
244
+ st.session_state.last_domain,
245
+ st.session_state.last_seniority
246
+ )
247
+
248
+ # Convert to bytes for download
249
+ img_buffer = io.BytesIO()
250
+ card_img.save(img_buffer, format='PNG')
251
+ img_buffer.seek(0)
252
+
253
+ # Display the card
254
+ st.image(card_img, caption="Your shareable card", use_column_width=True)
255
+
256
+ # Download button
257
+ st.download_button(
258
+ label="πŸ’Ύ Download Card",
259
+ data=img_buffer.getvalue(),
260
+ file_name="corporate_synergy_transformation.png",
261
+ mime="image/png",
262
+ use_container_width=True
263
+ )
264
+
265
+ with col_share2:
266
+ if st.button("πŸ“ Generate LinkedIn Post", use_container_width=True):
267
+ linkedin_text = generate_linkedin_post_text(
268
+ st.session_state.last_input,
269
+ st.session_state.last_output,
270
+ st.session_state.last_mode,
271
+ st.session_state.last_domain
272
+ )
273
+
274
+ st.text_area(
275
+ "LinkedIn Post Text",
276
+ value=linkedin_text,
277
+ height=200,
278
+ key="linkedin_post"
279
+ )
280
+
281
+ # LinkedIn share button
282
+ linkedin_url = f"https://www.linkedin.com/sharing/share-offsite/?url={urllib.parse.quote('https://huggingface.co/spaces/phxdev/corporate-synergy-bot')}&summary={urllib.parse.quote(linkedin_text[:200])}"
283
+
284
+ st.markdown(f"""
285
+ <a href="{linkedin_url}" target="_blank">
286
+ <button style="
287
+ background-color: #0077B5;
288
+ color: white;
289
+ padding: 10px 20px;
290
+ border: none;
291
+ border-radius: 5px;
292
+ cursor: pointer;
293
+ width: 100%;
294
+ font-size: 16px;
295
+ text-decoration: none;
296
+ ">
297
+ πŸ”— Share on LinkedIn
298
+ </button>
299
+ </a>
300
+ """, unsafe_allow_html=True)
301
+
302
+ # Examples section
303
+ st.subheader("Examples")
304
+ examples = [
305
+ ["let's meet tomorrow", "To Corporate", "general", "mid"],
306
+ ["I need help with this project", "To Corporate", "tech", "senior"],
307
+ ["good job on the presentation", "To Corporate", "consulting", "executive"],
308
+ ["We need to leverage our synergies to maximize stakeholder value", "To Plain English", "general", "mid"],
309
+ ["Let's circle back on the deliverables", "To Plain English", "general", "mid"],
310
+ ]
311
+
312
+ example_cols = st.columns(len(examples))
313
+ for i, (text, ex_mode, ex_domain, ex_seniority) in enumerate(examples):
314
+ with example_cols[i]:
315
+ if st.button(f"Example {i+1}", key=f"example_{i}"):
316
+ st.session_state.input_text = text
317
+ st.session_state.mode = ex_mode
318
+ st.session_state.domain = ex_domain
319
+ st.session_state.seniority = ex_seniority
320
+ st.rerun()
321
+
322
+ if __name__ == "__main__":
323
+ pass