mayankmvp commited on
Commit
8a16d8e
·
verified ·
1 Parent(s): 0032849

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import torch
6
+ from sklearn.model_selection import train_test_split
7
+ from torch.utils.data import Dataset
8
+ from transformers import (
9
+ T5ForConditionalGeneration,
10
+ T5TokenizerFast,
11
+ DataCollatorForSeq2Seq,
12
+ Trainer,
13
+ TrainingArguments,
14
+ pipeline
15
+ )
16
+
17
+ DATA_PATH = "data/train.csv"
18
+ DEFAULT_INPUT_COL = "text"
19
+ DEFAULT_TARGET_COL = "label"
20
+
21
+ class CSVDataset(Dataset):
22
+ def __init__(self, df, tokenizer, input_col, target_col, max_input_len=512, max_target_len=128, prefix="summarize: "):
23
+ self.inputs = df[input_col].astype(str).tolist()
24
+ self.targets = df[target_col].astype(str).tolist()
25
+ self.tokenizer = tokenizer
26
+ self.max_input_len = max_input_len
27
+ self.max_target_len = max_target_len
28
+ self.prefix = prefix
29
+
30
+ def __len__(self):
31
+ return len(self.inputs)
32
+
33
+ def __getitem__(self, idx):
34
+ src = self.prefix + self.inputs[idx]
35
+ tgt = self.targets[idx]
36
+ model_inputs = self.tokenizer(
37
+ src, max_length=self.max_input_len, truncation=True, padding=False, return_tensors=None
38
+ )
39
+ with self.tokenizer.as_target_tokenizer():
40
+ labels = self.tokenizer(
41
+ tgt, max_length=self.max_target_len, truncation=True, padding=False, return_tensors=None
42
+ )
43
+ model_inputs["labels"] = labels["input_ids"]
44
+ return model_inputs
45
+
46
+ def run_training(base_model, epochs, batch_size, lr, warmup_steps, weight_decay, max_input_len, max_target_len, input_col, target_col, eval_ratio, grad_accum, fp16):
47
+ log_lines = []
48
+ def log(msg):
49
+ log_lines.append(msg)
50
+
51
+ if not os.path.exists(DATA_PATH):
52
+ return "data/train.csv not found.", ""
53
+
54
+ try:
55
+ df = pd.read_csv(DATA_PATH)
56
+ except Exception as e:
57
+ return f"Failed reading CSV: {e}", ""
58
+
59
+ for c in [input_col, target_col]:
60
+ if c not in df.columns:
61
+ return f"Column '{c}' not in CSV. Found: {list(df.columns)}", ""
62
+
63
+ log("Loading tokenizer & model...")
64
+ tok = T5TokenizerFast.from_pretrained(base_model)
65
+ mdl = T5ForConditionalGeneration.from_pretrained(base_model)
66
+
67
+ train_df, val_df = train_test_split(df, test_size=float(eval_ratio), random_state=42)
68
+ train_ds = CSVDataset(train_df, tok, input_col, target_col, max_input_len, max_target_len)
69
+ val_ds = CSVDataset(val_df, tok, input_col, target_col, max_input_len, max_target_len)
70
+
71
+ data_collator = DataCollatorForSeq2Seq(tokenizer=tok, model=mdl)
72
+ output_dir = "checkpoint"
73
+
74
+ training_args = TrainingArguments(
75
+ output_dir=output_dir,
76
+ num_train_epochs=int(epochs),
77
+ per_device_train_batch_size=int(batch_size),
78
+ per_device_eval_batch_size=int(batch_size),
79
+ learning_rate=float(lr),
80
+ weight_decay=float(weight_decay),
81
+ warmup_steps=int(warmup_steps),
82
+ predict_with_generate=True,
83
+ evaluation_strategy="epoch",
84
+ save_strategy="epoch",
85
+ logging_steps=10,
86
+ load_best_model_at_end=True,
87
+ metric_for_best_model="eval_loss",
88
+ gradient_accumulation_steps=int(grad_accum),
89
+ fp16=bool(fp16),
90
+ report_to=[],
91
+ )
92
+
93
+ trainer = Trainer(
94
+ model=mdl,
95
+ args=training_args,
96
+ train_dataset=train_ds,
97
+ eval_dataset=val_ds,
98
+ tokenizer=tok,
99
+ data_collator=data_collator
100
+ )
101
+
102
+ log("Starting training...")
103
+ trainer.train()
104
+
105
+ log("Saving model...")
106
+ trainer.save_model(output_dir)
107
+ tok.save_pretrained(output_dir)
108
+
109
+ return "\n".join(log_lines), "Training complete. Model saved to ./checkpoint"
110
+
111
+ def make_pipe_from_checkpoint():
112
+ if not os.path.exists("checkpoint"):
113
+ raise RuntimeError("No checkpoint found. Train first.")
114
+ return pipeline("text2text-generation", model="checkpoint")
115
+
116
+ with gr.Blocks() as demo:
117
+ gr.Markdown("# 🔧 Train & Share: Summarizer (FLAN‑T5)")
118
+ with gr.Tab("Train"):
119
+ gr.Markdown("Use defaults and click **Start Training**. This runs inside the Space.")
120
+ base_model = gr.Dropdown(choices=["google/flan-t5-small","google/flan-t5-base"], value="google/flan-t5-small", label="Base model")
121
+ epochs = gr.Slider(1, 6, value=2, step=1, label="Epochs")
122
+ batch_size = gr.Slider(2, 16, value=8, step=1, label="Batch size")
123
+ lr = gr.Textbox(value="5e-5", label="Learning rate")
124
+ warmup = gr.Textbox(value="100", label="Warmup steps")
125
+ wd = gr.Textbox(value="0.01", label="Weight decay")
126
+ max_in = gr.Slider(128, 1024, value=512, step=32, label="Max input length")
127
+ max_out = gr.Slider(32, 256, value=128, step=8, label="Max target length")
128
+ in_col = gr.Textbox(value=DEFAULT_INPUT_COL, label="Input column")
129
+ out_col = gr.Textbox(value=DEFAULT_TARGET_COL, label="Target column")
130
+ eval_ratio = gr.Textbox(value="0.1", label="Eval ratio (0-1)")
131
+ grad_accum = gr.Slider(1, 8, value=1, step=1, label="Gradient accumulation")
132
+ use_fp16 = gr.Checkbox(value=True, label="Use fp16 (GPU only)")
133
+ train_btn = gr.Button("🚀 Start Training")
134
+ train_log = gr.Textbox(label="Training log", lines=10)
135
+ train_status = gr.Textbox(label="Status")
136
+
137
+ def train_click(bm, e, bs, lrn, wu, wdec, mi, mo, ic, oc, er, ga, fp):
138
+ log, status = run_training(bm, e, bs, lrn, wu, wdec, mi, mo, ic, oc, er, ga, fp)
139
+ return log, status
140
+
141
+ train_btn.click(train_click, [base_model, epochs, batch_size, lr, warmup, wd, max_in, max_out, in_col, out_col, eval_ratio, grad_accum, use_fp16], [train_log, train_status])
142
+
143
+ with gr.Tab("Demo"):
144
+ gr.Markdown("After training, this tab uses the local **checkpoint**.")
145
+ inp = gr.Textbox(label="Input Text", lines=10, placeholder="Paste text here...")
146
+ max_new_tokens = gr.Slider(16, 256, value=128, step=8, label="Max new tokens")
147
+ temperature = gr.Slider(0, 1.0, value=0.0, step=0.1, label="Temperature")
148
+ topp = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
149
+ btn = gr.Button("Summarize")
150
+ out = gr.Textbox(label="Summary", lines=10)
151
+
152
+ pipe_holder = {"pipe": None}
153
+
154
+ def summarize_click(text, max_new_tokens, temperature, top_p):
155
+ if pipe_holder["pipe"] is None:
156
+ pipe_holder["pipe"] = make_pipe_from_checkpoint()
157
+ gen = pipe_holder["pipe"](
158
+ f"summarize: {text}",
159
+ max_new_tokens=int(max_new_tokens),
160
+ do_sample=float(temperature)>0,
161
+ temperature=float(temperature),
162
+ top_p=float(top_p)
163
+ )
164
+ return gen[0]["generated_text"]
165
+
166
+ btn.click(summarize_click, [inp, max_new_tokens, temperature, topp], [out])
167
+
168
+ if __name__ == "__main__":
169
+ demo.launch()