Sandeep Chowdhary committed
Commit f2690f9 · verified · 1 Parent(s): 50f0d1d

Update model card

Files changed (1)
  1. README.md +152 -42
README.md CHANGED
@@ -50,45 +50,166 @@ Note: Label order in predictions matches the order above.
  ## Usage

  ```python
- import torch
  from transformers import DistilBertTokenizer
- import tempfile
  from huggingface_hub import snapshot_download

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- # Download and load model
- model_link = "sanchow/veganism_and_vegetarianism-distilbert-classifier"
  with tempfile.TemporaryDirectory() as temp_dir:
-     snapshot_download(repo_id=model_link, local_dir=temp_dir, local_dir_use_symlinks=False)
-
-     sys.path.insert(0, temp_dir)
-     from model_class import MultilabelClassifier
-
-     tokenizer = DistilBertTokenizer.from_pretrained(temp_dir)
-     checkpoint = torch.load(os.path.join(temp_dir, 'model.pt'), map_location='cpu')
-     model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
-     model.load_state_dict(checkpoint['model_state_dict'])
-     model.to(device)
-     model.eval()
-
-     # Predict
-     text = "Your text here"
-     inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
-     with torch.no_grad():
-         predictions = model(**inputs).cpu().numpy()
-
-     # Get scores
-     label_scores = {label: float(score) for label, score in zip(checkpoint['label_names'], predictions[0])}
  ```

- ## Applications
-
- - Content analysis of social media discussions
- - Research on public sentiment and discourse
- - Policy analysis of key topics and concerns
- - Market research on trends and interests
-

  ## Performance

@@ -104,17 +225,6 @@ Dataset: ~900 GPT-labeled samples per sector (600 train, 150 validation, 150 tes

  ## Optimal Thresholds

- Use these thresholds for best performance:
-
- - Animal Welfare: 0.481
- - Environmental Impact: 0.459
- - Health: 0.201
- - Lab Grown And Alt Proteins: 0.341
- - Psychology And Identity: 0.525
- - Systemic Vs Individual Action: 0.375
- - Taste And Convenience: 0.664
-
- Usage:
  ```python
  optimal_thresholds = {'Animal Welfare': 0.48107979620047003, 'Environmental Impact': 0.45919171852850427, 'Health': 0.20115313966833437, 'Lab Grown And Alt Proteins': 0.3414601502146817, 'Psychology And Identity': 0.5246278637433214, 'Systemic Vs Individual Action': 0.37517437676211585, 'Taste And Convenience': 0.6635140143644325}
  for label, score in zip(label_names, predictions[0]):
 
  ## Usage

  ```python
+ import torch, sys, os, tempfile
  from transformers import DistilBertTokenizer
  from huggingface_hub import snapshot_download

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

+ def print_sorted_label_scores(label_scores):
+     # Sort label_scores dict by score descending
+     sorted_items = sorted(label_scores.items(), key=lambda x: x[1], reverse=True)
+     for label, score in sorted_items:
+         print(f" {label}: {score:.6f}")
+
+ # ------------------ TRANSPORT ------------------
+ transport_model_link = 'sanchow/electric_vehicles-distilbert-classifier'
+ transport_examples = [
+     "Switching to electric cars can cut down on smog and carbon output."
+ ]
+
+ print(f"\n{'='*60}")
+ print("MODEL: TRANSPORT SECTOR")
+ print(f"{'='*60}")
+ print(f"Downloading model: {transport_model_link}")
  with tempfile.TemporaryDirectory() as temp_dir:
+     snapshot_download(
+         repo_id=transport_model_link,
+         local_dir=temp_dir,
+         local_dir_use_symlinks=False
+     )
+     model_class_path = os.path.join(temp_dir, 'model_class.py')
+     if not os.path.exists(model_class_path):
+         print(f"model_class.py not found in downloaded files")
+         print(f" Available files: {os.listdir(temp_dir)}")
+     else:
+         sys.path.insert(0, temp_dir)
+         from model_class import MultilabelClassifier
+         tokenizer = DistilBertTokenizer.from_pretrained(temp_dir)
+         checkpoint = torch.load(os.path.join(temp_dir, 'model.pt'), map_location='cpu', weights_only=False)
+         model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
+         model.load_state_dict(checkpoint['model_state_dict'])
+         model.to(device)
+         model.eval()
+         print("Model loaded successfully")
+         print(f" Labels: {checkpoint['label_names']}")
+         print("\nTransport classifier results for transport_examples:\n")
+         for i, test_text in enumerate(transport_examples):
+             inputs = tokenizer(
+                 test_text,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=512,
+                 padding=True
+             ).to(device)
+             with torch.no_grad():
+                 outputs = model(**inputs)
+             predictions = outputs[0].cpu().numpy() if isinstance(outputs, (tuple, list)) else outputs.cpu().numpy()
+             label_scores = {label: float(score) for label, score in zip(checkpoint['label_names'], predictions[0])}
+             print(f"Example {i+1}: '{test_text}'")
+             print("Predictions (all label scores, highest first):")
+             print_sorted_label_scores(label_scores)
+             print("-" * 40)
+
+ # ------------------ HOUSING ------------------
+ housing_model_link = 'sanchow/solar_energy-distilbert-classifier'
+ housing_examples = [
+     "Solar panels on rooftops can significantly reduce electricity bills."
+ ]
+
+ print(f"\n{'='*60}")
+ print("MODEL: HOUSING SECTOR")
+ print(f"{'='*60}")
+ print(f"Downloading model: {housing_model_link}")
+ with tempfile.TemporaryDirectory() as temp_dir:
+     snapshot_download(
+         repo_id=housing_model_link,
+         local_dir=temp_dir,
+         local_dir_use_symlinks=False
+     )
+     model_class_path = os.path.join(temp_dir, 'model_class.py')
+     if not os.path.exists(model_class_path):
+         print(f"model_class.py not found in downloaded files")
+         print(f" Available files: {os.listdir(temp_dir)}")
+     else:
+         sys.path.insert(0, temp_dir)
+         from model_class import MultilabelClassifier
+         tokenizer = DistilBertTokenizer.from_pretrained(temp_dir)
+         checkpoint = torch.load(os.path.join(temp_dir, 'model.pt'), map_location='cpu', weights_only=False)
+         model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
+         model.load_state_dict(checkpoint['model_state_dict'])
+         model.to(device)
+         model.eval()
+         print("Model loaded successfully")
+         print(f" Labels: {checkpoint['label_names']}")
+         print("\nHousing classifier results for housing_examples:\n")
+         for i, test_text in enumerate(housing_examples):
+             inputs = tokenizer(
+                 test_text,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=512,
+                 padding=True
+             ).to(device)
+             with torch.no_grad():
+                 outputs = model(**inputs)
+             predictions = outputs[0].cpu().numpy() if isinstance(outputs, (tuple, list)) else outputs.cpu().numpy()
+             label_scores = {label: float(score) for label, score in zip(checkpoint['label_names'], predictions[0])}
+             print(f"Example {i+1}: '{test_text}'")
+             print("Predictions (all label scores, highest first):")
+             print_sorted_label_scores(label_scores)
+             print("-" * 40)
+
+ # ------------------ FOOD ------------------
+ food_model_link = 'sanchow/veganism_and_vegetarianism-distilbert-classifier'
+ food_examples = [
+     "Plant-based diets can help reduce environmental impact of food production."
+ ]
+
+ print(f"\n{'='*60}")
+ print("MODEL: FOOD SECTOR")
+ print(f"{'='*60}")
+ print(f"Downloading model: {food_model_link}")
+ with tempfile.TemporaryDirectory() as temp_dir:
+     snapshot_download(
+         repo_id=food_model_link,
+         local_dir=temp_dir,
+         local_dir_use_symlinks=False
+     )
+     model_class_path = os.path.join(temp_dir, 'model_class.py')
+     if not os.path.exists(model_class_path):
+         print(f"model_class.py not found in downloaded files")
+         print(f" Available files: {os.listdir(temp_dir)}")
+     else:
+         sys.path.insert(0, temp_dir)
+         from model_class import MultilabelClassifier
+         tokenizer = DistilBertTokenizer.from_pretrained(temp_dir)
+         checkpoint = torch.load(os.path.join(temp_dir, 'model.pt'), map_location='cpu', weights_only=False)
+         model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
+         model.load_state_dict(checkpoint['model_state_dict'])
+         model.to(device)
+         model.eval()
+         print("Model loaded successfully")
+         print(f" Labels: {checkpoint['label_names']}")
+         print("\nFood classifier results for food_examples:\n")
+         for i, test_text in enumerate(food_examples):
+             inputs = tokenizer(
+                 test_text,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=512,
+                 padding=True
+             ).to(device)
+             with torch.no_grad():
+                 outputs = model(**inputs)
+             predictions = outputs[0].cpu().numpy() if isinstance(outputs, (tuple, list)) else outputs.cpu().numpy()
+             label_scores = {label: float(score) for label, score in zip(checkpoint['label_names'], predictions[0])}
+             print(f"Example {i+1}: '{test_text}'")
+             print("Predictions (all label scores, highest first):")
+             print_sorted_label_scores(label_scores)
+             print("-" * 40)
  ```
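The three sector blocks in the added snippet repeat the same download-and-load steps. As an illustration only (not part of the model card's code; the `load_sector_classifier` helper, its name, and its return values are assumptions based on the snippet above), the shared logic could be factored into a small function:

```python
import os
import sys
import tempfile

import torch
from huggingface_hub import snapshot_download
from transformers import DistilBertTokenizer


def load_sector_classifier(repo_id, device="cpu"):
    """Download one of the sector classifier repos and return (model, tokenizer, label_names)."""
    # mkdtemp (rather than a context manager) keeps the downloaded files around for the model's lifetime
    local_dir = tempfile.mkdtemp()
    snapshot_download(repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False)
    sys.path.insert(0, local_dir)
    from model_class import MultilabelClassifier  # model_class.py ships with each repo (Python caches the first one imported)
    tokenizer = DistilBertTokenizer.from_pretrained(local_dir)
    checkpoint = torch.load(os.path.join(local_dir, "model.pt"), map_location="cpu", weights_only=False)
    model = MultilabelClassifier(checkpoint["model_name"], len(checkpoint["label_names"]))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device).eval()
    return model, tokenizer, checkpoint["label_names"]
```

Each sector block would then reduce to a call such as `model, tokenizer, labels = load_sector_classifier('sanchow/solar_energy-distilbert-classifier', device)` followed by the same tokenize-and-predict loop shown above.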

  ## Performance


  ## Optimal Thresholds

  ```python
  optimal_thresholds = {'Animal Welfare': 0.48107979620047003, 'Environmental Impact': 0.45919171852850427, 'Health': 0.20115313966833437, 'Lab Grown And Alt Proteins': 0.3414601502146817, 'Psychology And Identity': 0.5246278637433214, 'Systemic Vs Individual Action': 0.37517437676211585, 'Taste And Convenience': 0.6635140143644325}
  for label, score in zip(label_names, predictions[0]):
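      # (Sketch only: the hunk shown here is truncated at the loop header, so this body is
      #  not part of the original diff. A typical body compares each label's score against
      #  its tuned threshold to decide whether the label applies.)
      if score >= optimal_thresholds[label]:
          print(f"{label}: {score:.3f} (above threshold)")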