Sandeep Chowdhary committed on
Commit
8f86a8d
·
verified ·
1 Parent(s): 7ccc19c

Update model card

Files changed (1)
  1. README.md +178 -19
README.md CHANGED
@@ -6,24 +6,59 @@ tags:
6
  - food
7
  - climate-change
8
  - sustainability
 
9
  license: mit
10
  ---
11
 
12
- # Veganism & Vegetarianism Classifier (DistilBERT)
13
 
14
- This model classifies content related to plant-based diets and sustainable food systems.
15
 
16
  ## Model Details
17
 
18
- - **Model Type**: DistilBERT
19
  - **Task**: Multilabel text classification
20
  - **Sector**: Veganism & Vegetarianism
 
21
  - **Number of Labels**: 7
22
  - **Training Data**: Reddit posts and comments from r/climate, r/climateactionplan, r/climatechange, r/climatechaos, r/climateco, r/climatecrisis, r/climatecrisiscanada, r/climatedisalarm, r/climatejobslist, r/climatejustice, r/climatememes, r/climatenews, r/climateoffensive, r/climatepolicy, r/climate_science
23
 
24
  ## Labels
25
 
26
- The model predicts **7 labels** simultaneously:
27
 
28
  1. **Animal Welfare**
29
  2. **Environmental Impact**
@@ -34,24 +69,38 @@ The model predicts **7 labels** simultaneously:
34
  7. **Taste And Convenience**
35
 
36
 
37
  ## Usage
38
 
39
  ```python
40
  import torch
41
  from transformers import DistilBertTokenizer
42
  import sys
43
  import os
44
 
45
- # Load the custom MultilabelClassifier model
46
- model_name = "sanchow/veganism_and_vegetarianism-distilbert-classifier"
47
  sys.path.append(model_name)
48
  from model_class import MultilabelClassifier
49
 
50
  # Load tokenizer
51
  tokenizer = DistilBertTokenizer.from_pretrained(model_name)
52
 
53
- # Load the model weights
54
- checkpoint = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
 
55
  model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
56
  model.load_state_dict(checkpoint['model_state_dict'])
57
  model.eval()
@@ -60,18 +109,59 @@ model.eval()
60
  text = "Your text here"
61
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
62
  outputs = model(**inputs)
63
- predictions = outputs[0]
64
 
65
  # Get label predictions
 
66
  label_names = ['Animal Welfare', 'Environmental Impact', 'Health', 'Lab Grown And Alt Proteins', 'Psychology And Identity', 'Systemic Vs Individual Action', 'Taste And Convenience']
67
  for i, (label, score) in enumerate(zip(label_names, predictions[0])):
68
      if score > 0.5:  # Threshold for positive classification
69
          print(f"{label}: {score:.3f}")
70
  ```
71
 
72
 
73
  ## Optimal Thresholds
74
 
75
  - **Animal Welfare**: 0.481
76
  - **Environmental Impact**: 0.459
77
  - **Health**: 0.201
@@ -80,22 +170,91 @@ for i, (label, score) in enumerate(zip(label_names, predictions[0])):
80
  - **Systemic Vs Individual Action**: 0.375
81
  - **Taste And Convenience**: 0.664
82
 
83
 
84
- ## Training Methodology
85
 
86
- This model was trained using GPT-labeled data from Reddit discussions (2010-2023) with:
87
- - Regex filtering for sector-specific content
88
- - GPT-assisted multilabel classification
89
- - Optimal threshold tuning using Jaccard similarity optimization
90
 
91
- ## Applications
92
 
93
- - Content analysis of social media discussions
94
- - Research on public sentiment and discourse
95
- - Policy analysis and market research
96
 
97
  ## Model Limitations
98
 
99
  - Trained on Reddit data from specific subreddits
100
- - May not generalize to other platforms
 
101
  - Limited to English language content
 
6
  - food
7
  - climate-change
8
  - sustainability
9
+ - veganism-&-vegetarianism
10
  license: mit
11
  ---
12
 
13
+ # Veganism & Vegetarianism Classifier (DistilBERT)
14
 
15
+ This model classifies content related to plant-based diets, sustainable food systems, and ethical eating. It analyzes discussions about veganism, vegetarianism, and sustainable food choices.
16
 
17
  ## Model Details
18
 
19
+ - **Model Type**: DistilBERT
20
  - **Task**: Multilabel text classification
21
  - **Sector**: Veganism & Vegetarianism
22
+ - **Base Model**: DistilBERT base uncased
23
  - **Number of Labels**: 7
24
  - **Training Data**: Reddit posts and comments from r/climate, r/climateactionplan, r/climatechange, r/climatechaos, r/climateco, r/climatecrisis, r/climatecrisiscanada, r/climatedisalarm, r/climatejobslist, r/climatejustice, r/climatememes, r/climatenews, r/climateoffensive, r/climatepolicy, r/climate_science
25
 
26
+ ## Training Methodology
27
+
28
+ This model was trained using **GPT-labeled data** from Reddit discussions. The training process involved:
29
+
30
+ 1. **Data Collection**: Reddit posts and comments from climate-related subreddits (2010-2023)
31
+ 2. **Regex Filtering**: Content was filtered using sector-specific regex patterns to identify relevant discussions
32
+ 3. **GPT Labeling**: Using GPT models to generate initial labels for training data
33
+ 4. **Data Splitting**: 80% training, 10% validation, 10% test split
34
+ 5. **Model Training**: Fine-tuning DistilBERT on the labeled dataset, followed by per-label threshold tuning
35
+ 6. **Validation**: Performance evaluation on held-out test sets using Jaccard similarity metrics
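
The 80/10/10 split in step 4 can be sketched as below. This is a minimal illustration only: the helper name `split_80_10_10`, the fixed seed, and the plain random shuffle are assumptions, since the card does not say how the split was drawn (for example, whether it was stratified by label).

```python
import random

def split_80_10_10(samples, seed=0):
    """Shuffle a collection and split it into train/validation/test parts (80/10/10)."""
    items = list(samples)
    random.Random(seed).shuffle(items)   # seeded so the split is reproducible
    n_train = len(items) * 8 // 10       # integer arithmetic avoids float rounding
    n_val = len(items) // 10
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]       # whatever remains goes to the test set
    return train, val, test

# Hypothetical usage with 1000 placeholder sample IDs
train, val, test = split_80_10_10(range(1000))
print(len(train), len(val), len(test))
```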
36
+
37
+ ### Training Data Sources
38
+ - **Source Subreddits**: r/climate, r/climateactionplan, r/climatechange, r/climatechaos, r/climateco, r/climatecrisis, r/climatecrisiscanada, r/climatedisalarm, r/climatejobslist, r/climatejustice, r/climatememes, r/climatenews, r/climateoffensive, r/climatepolicy, r/climate_science
39
+ - **Data Period**: 2010-2023
40
+ - **Focus Areas**: Plant-based nutrition, sustainable agriculture, ethical food choices, alternative proteins, food policy
41
+ - **Data Type**: Posts and comments filtered by regex patterns
42
+ - **Labeling Method**: GPT-assisted multilabel classification
43
+
44
+ ### Regex Patterns Used
45
+ The model was trained on content filtered using sector-specific regex patterns:
46
+
47
+ **Transport Sector (EV-related terms):**
48
+ - Strong patterns: ev, electric vehicle, evs, bev, tesla model, supercharger, gigafactory
49
+ - Weak patterns: electric car, charging station, tax credit, e-bike, tesla
50
+
51
+ **Housing Sector (Solar energy terms):**
52
+ - Strong patterns: rooftop solar, solar pv, pv panel, photovoltaics, solar array
53
+ - Weak patterns: solar panel, solar power, battery storage, powerwall, solar tax credit
54
+
55
+ **Food Sector (Plant-based diet terms):**
56
+ - Strong patterns: vegan, plant-based diet, veganism, vegetarian, beyond meat
57
+ - Weak patterns: red meat, dairy free, plant protein, almond milk, flexitarian
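
As a rough illustration of the food-sector filter, the sketch below compiles the listed terms into two case-insensitive regexes and reports which tier a text matches. The function names, the word-boundary matching, and the "strong takes precedence over weak" rule are assumptions; the card does not describe how strong and weak patterns were actually combined.

```python
import re

# Terms copied from the food-sector lists above
STRONG_TERMS = ["vegan", "plant-based diet", "veganism", "vegetarian", "beyond meat"]
WEAK_TERMS = ["red meat", "dairy free", "plant protein", "almond milk", "flexitarian"]

def compile_terms(terms):
    """Build one case-insensitive regex that matches any term as a whole word or phrase."""
    ordered = sorted(terms, key=len, reverse=True)  # try longer terms first
    pattern = r"\b(?:" + "|".join(re.escape(t) for t in ordered) + r")\b"
    return re.compile(pattern, re.IGNORECASE)

STRONG_RE = compile_terms(STRONG_TERMS)
WEAK_RE = compile_terms(WEAK_TERMS)

def match_tier(text):
    """Return 'strong', 'weak', or None (assumed precedence: strong before weak)."""
    if STRONG_RE.search(text):
        return "strong"
    if WEAK_RE.search(text):
        return "weak"
    return None

print(match_tier("I went vegan last year"))
print(match_tier("Almond milk is fine in coffee"))
print(match_tier("Rooftop solar is getting cheaper"))
```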
58
+
59
  ## Labels
60
 
61
+ The model predicts **7 labels** simultaneously (multilabel classification). Each text can be classified into multiple categories:
62
 
63
  1. **Animal Welfare**
64
  2. **Environmental Impact**
 
69
  7. **Taste And Convenience**
70
 
71
 
72
+ **⚠️ Important**: The order of labels in the output predictions corresponds exactly to the order listed above. When using the model, ensure your label list matches this order.
73
+
74
+ **💡 Tip**: For best performance, use the optimal thresholds provided in the "Optimal Thresholds" section below instead of the default 0.5 threshold.
75
+
76
+ **🔧 Threshold Optimization**: The optimal thresholds were computed using Jaccard similarity optimization on the validation set. For your own dataset, consider re-optimizing thresholds using the same method from `paper4_BERT_finetuning.py`.
77
+
78
  ## Usage
79
 
80
  ```python
81
+ from transformers import AutoTokenizer
82
+ import torch
83
+
84
+ # Load model and tokenizer
85
+ model_name = "sanchow/veganism_and_vegetarianism-distilbert-classifier"
86
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
87
+
88
+ # Load the custom MultilabelClassifier model
89
  import torch
90
  from transformers import DistilBertTokenizer
91
  import sys
92
  import os
93
 
94
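+ # Add the model directory to sys.path and import the custom model class.
+ # Note: for this to work, model_name must point to a local copy of the repository (containing model_class.py and model.pt), not only the Hub repo ID.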
 
95
  sys.path.append(model_name)
96
  from model_class import MultilabelClassifier
97
 
98
  # Load tokenizer
99
  tokenizer = DistilBertTokenizer.from_pretrained(model_name)
100
 
101
+ # Load the model weights with weights_only=False for compatibility
102
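+ # Note: PyTorch 2.6+ defaults to weights_only=True, which rejects the NumPy arrays stored in this checkpoint, so weights_only=False is needed here. This unpickles arbitrary objects, so only load checkpoints you trust.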
103
+ checkpoint = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu', weights_only=False)
104
  model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
105
  model.load_state_dict(checkpoint['model_state_dict'])
106
  model.eval()
 
109
  text = "Your text here"
110
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
111
  outputs = model(**inputs)
112
+ predictions = torch.sigmoid(outputs.logits)
113
 
114
  # Get label predictions
115
+ # IMPORTANT: The order of labels must match the training order exactly
116
  label_names = ['Animal Welfare', 'Environmental Impact', 'Health', 'Lab Grown And Alt Proteins', 'Psychology And Identity', 'Systemic Vs Individual Action', 'Taste And Convenience']
117
  for i, (label, score) in enumerate(zip(label_names, predictions[0])):
118
      if score > 0.5:  # Threshold for positive classification
119
          print(f"{label}: {score:.3f}")
120
+
121
+ # For multilabel classification with optimal thresholds
122
+ # Note: Use optimal thresholds from model performance for better results
123
+ # See the "Optimal Thresholds" section in the model card for specific values
124
+ optimal_thresholds = {'Animal Welfare': 0.48107979620047003, 'Environmental Impact': 0.45919171852850427, 'Health': 0.20115313966833437, 'Lab Grown And Alt Proteins': 0.3414601502146817, 'Psychology And Identity': 0.5246278637433214, 'Systemic Vs Individual Action': 0.37517437676211585, 'Taste And Convenience': 0.6635140143644325}
125
+ for i, (label, score) in enumerate(zip(label_names, predictions[0])):
126
+     threshold = optimal_thresholds.get(label, 0.5)  # Use optimal threshold or default to 0.5
127
+     if score > threshold:
128
+         print(f"{label}: {score:.3f} (threshold: {threshold:.3f})")
129
+
130
+ # For production use, consider re-optimizing thresholds on your validation data
131
+ # See the "Threshold Optimization" section for details
132
  ```
133
 
134
+ ## Applications
135
+
136
+ This model is designed for:
137
+ - **Content Analysis**: Analyzing social media discussions about veganism & vegetarianism
138
+ - **Research**: Understanding public sentiment and discourse around plant-based nutrition, sustainable agriculture, ethical food choices, alternative proteins, and food policy
139
+ - **Policy Analysis**: Identifying key topics and concerns in food discussions
140
+ - **Market Research**: Tracking trends and interests in veganism & vegetarianism
141
+
142
+
143
+ ## Performance Metrics
144
+
145
+ **Best Model Performance (Epoch 5):**
146
+ - Micro Jaccard Score: 0.5584
147
+ - Macro Jaccard Score: 0.6710
148
+ - F1 Score: 0.8906
149
+ - Accuracy: 0.8906
150
+ - Precision: 0.8906
151
+ - Recall: 0.8906
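
The card does not state how the micro and macro Jaccard scores were averaged. One common convention is sketched below, assuming label-wise averaging over a binary indicator matrix: micro pools intersections and unions across all labels, while macro averages the per-label scores. The function name `jaccard_scores` and the toy data are illustrative, and labels with an empty union are counted as 0 here by assumption.

```python
import numpy as np

def jaccard_scores(y_true, y_pred):
    """Return (micro, macro) Jaccard for binary indicator matrices of shape (samples, labels)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    inter = (y_true & y_pred).sum(axis=0)   # per-label intersection counts
    union = (y_true | y_pred).sum(axis=0)   # per-label union counts
    micro = inter.sum() / union.sum() if union.sum() else 0.0
    per_label = np.where(union > 0, inter / np.maximum(union, 1), 0.0)
    return float(micro), float(per_label.mean())

# Toy example: 3 samples, 3 labels
y_true_demo = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
y_pred_demo = [[1, 0, 0], [0, 1, 0], [1, 0, 0]]
micro, macro = jaccard_scores(y_true_demo, y_pred_demo)
print(f"micro={micro:.3f} macro={macro:.3f}")
```

In this toy case the per-label scores are 1.0, 0.5, and 0.0, so the macro average is 0.5, while pooling the counts gives a micro score of 3/5.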
152
+
153
+ **Dataset Sizes:**
154
+ - Training Samples: ~600 per sector
155
+ - Validation Samples: ~150 per sector
156
+ - Test Samples: ~150 per sector
157
+ - Total GPT-labeled samples: ~900 per sector
158
+
159
+
160
 
161
  ## Optimal Thresholds
162
 
163
+ For best performance, use these thresholds to convert continuous scores to binary predictions:
164
+
165
  - **Animal Welfare**: 0.481
166
  - **Environmental Impact**: 0.459
167
  - **Health**: 0.201
 
170
  - **Systemic Vs Individual Action**: 0.375
171
  - **Taste And Convenience**: 0.664
172
 
173
+ **Usage with optimal thresholds:**
174
+ ```python
175
+ # Define optimal thresholds for this model
176
+ optimal_thresholds = {'Animal Welfare': 0.48107979620047003, 'Environmental Impact': 0.45919171852850427, 'Health': 0.20115313966833437, 'Lab Grown And Alt Proteins': 0.3414601502146817, 'Psychology And Identity': 0.5246278637433214, 'Systemic Vs Individual Action': 0.37517437676211585, 'Taste And Convenience': 0.6635140143644325}
177
 
178
+ # Apply thresholds to get binary predictions
179
+ for i, (label, score) in enumerate(zip(label_names, predictions[0])):
180
+     threshold = optimal_thresholds.get(label, 0.5)
181
+     if score > threshold:
182
+         print(f"{label}: {score:.3f} (threshold: {threshold:.3f})")
183
+ ```
184
 
185
 
186
+ ## Threshold Optimization
187
+
188
+ The optimal thresholds provided above were computed using **Jaccard similarity optimization** on the validation dataset. For each label, this method searches for the threshold that maximizes the Jaccard similarity between the predicted and true label sets.
189
+
190
+ ### Optimization Method Used
191
+
192
+ The thresholds were optimized using the `find_optimal_thresholds_jaccard_global` function from `paper4_multilabel_threshold_optimizer.py`, which:
193
+
194
+ 1. **Grid Search**: Tests threshold values from 0.1 to 0.9 in 0.05 increments
195
+ 2. **Jaccard Optimization**: Maximizes micro-averaged Jaccard similarity
196
+ 3. **Per-Label Optimization**: Finds optimal threshold for each label independently
197
+ 4. **Global Optimization**: Considers the overall multilabel performance
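
The steps above can be sketched roughly as follows. This is not the `find_optimal_thresholds_jaccard_global` function itself, whose source is not shown here; it is a minimal stand-in that assumes a coordinate-wise search: each label's threshold is tuned in turn over the 0.1 to 0.9 grid while the others are held fixed, and the candidate with the best micro-averaged Jaccard is kept. The names `micro_jaccard` and `grid_search_thresholds` and the 0.5 starting point are hypothetical.

```python
import numpy as np

def micro_jaccard(y_true, y_pred):
    """Micro-averaged Jaccard: pooled intersection over pooled union."""
    inter = np.logical_and(y_true, y_pred).sum()
    union = np.logical_or(y_true, y_pred).sum()
    return float(inter / union) if union else 0.0

def grid_search_thresholds(probs, y_true, grid=None):
    """Tune one threshold per label, scoring each candidate by overall micro Jaccard."""
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=bool)
    if grid is None:
        grid = np.round(np.arange(0.10, 0.90 + 1e-9, 0.05), 2)  # 0.10, 0.15, ..., 0.90
    thresholds = np.full(probs.shape[1], 0.5)  # assumed starting point
    for j in range(probs.shape[1]):
        best_t, best_score = thresholds[j], -1.0
        for t in grid:
            trial = thresholds.copy()
            trial[j] = t
            score = micro_jaccard(y_true, probs > trial)
            if score > best_score:  # strict '>' keeps the lowest best threshold
                best_t, best_score = t, score
        thresholds[j] = best_t
    return thresholds

# Toy validation set: 4 samples, 2 labels
probs = np.array([[0.30, 0.8], [0.35, 0.2], [0.05, 0.7], [0.32, 0.1]])
y_true = np.array([[1, 1], [1, 0], [0, 1], [1, 0]], dtype=bool)
thr = grid_search_thresholds(probs, y_true)
print(thr, micro_jaccard(y_true, probs > thr))
```

With the default 0.5 threshold the first label would be missed on every positive sample in this toy set; the search lowers that label's threshold until all positives are recovered.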
198
+
199
+ ### Re-optimizing for Your Dataset
200
+
201
+ For best results on your specific dataset, consider re-optimizing thresholds:
202
 
203
+ ```python
204
+ from paper4_multilabel_threshold_optimizer import find_optimal_thresholds_jaccard_global
205
+
206
+ # Load your validation data
207
+ validation_data = {
208
+     'texts': ['your text 1', 'your text 2', ...],
209
+     'true_labels': [['label1', 'label2'], ['label3'], ...]
210
+ }
211
+
212
+ # Create sector models dict (as expected by the optimizer)
213
+ sector_models = {
214
+     'your_sector': {
215
+         'model': model,
216
+         'tokenizer': tokenizer,
217
+         'label_names': label_names
218
+     }
219
+ }
220
+
221
+ # Find optimal thresholds for your data
222
+ optimal_thresholds = find_optimal_thresholds_jaccard_global(
223
+     sector_models,
224
+     validation_data,
225
+     device=device  # `device` must be defined beforehand, e.g. torch.device('cpu')
226
+ )
227
+
228
+ # Use the optimized thresholds
229
+ thresholds = optimal_thresholds['your_sector']
230
+ ```
231
+
232
+ ### Alternative Optimization Methods
233
+
234
+ You can also implement other threshold optimization strategies:
235
+
236
+ - **F1-score optimization**: Maximize F1-score instead of Jaccard
237
+ - **Precision/Recall trade-off**: Optimize for specific precision/recall requirements
238
+ - **Cost-sensitive optimization**: Weight different types of errors differently
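
As one example of the first alternative, the sketch below picks each label's threshold independently by maximizing that label's F1 score on validation data. It is illustrative only: `f1_binary`, `best_f1_thresholds`, the reuse of the 0.1 to 0.9 grid, and the toy arrays are assumptions, not part of the released code.

```python
import numpy as np

def f1_binary(y_true, y_pred):
    """F1 for one label, from boolean vectors of true and predicted membership."""
    tp = np.logical_and(y_true, y_pred).sum()
    fp = np.logical_and(~y_true, y_pred).sum()
    fn = np.logical_and(y_true, ~y_pred).sum()
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0

def best_f1_thresholds(probs, y_true, grid=None):
    """Return one threshold per label, each chosen to maximize that label's F1."""
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=bool)
    if grid is None:
        grid = np.round(np.arange(0.10, 0.90 + 1e-9, 0.05), 2)
    best = []
    for j in range(probs.shape[1]):
        scores = [f1_binary(y_true[:, j], probs[:, j] > t) for t in grid]
        best.append(float(grid[int(np.argmax(scores))]))  # first maximum wins ties
    return best

# Toy validation set: 4 samples, 2 labels
val_probs = np.array([[0.9, 0.3], [0.6, 0.8], [0.2, 0.7], [0.1, 0.1]])
val_labels = np.array([[1, 0], [1, 1], [0, 1], [0, 0]], dtype=bool)
f1_thresholds = best_f1_thresholds(val_probs, val_labels)
print(f1_thresholds)
```

Unlike the micro-Jaccard search, this treats every label separately, so rare labels are not outweighed by frequent ones when thresholds are chosen.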
239
+
240
+ ## Citation
241
+
242
+ If you use this model in your research, please cite:
243
+
244
+ ```bibtex
245
+ @misc{veganism_and_vegetarianism_distilbert_classifier,
246
+ title={Veganism \& Vegetarianism Classifier for Climate Change Analysis},
247
+ author={Sandeep Chowdhary},
248
+ year={2024},
249
+ publisher={Hugging Face},
250
+ journal={Hugging Face Hub},
251
+ howpublished={\url{https://huggingface.co/sanchow/veganism_and_vegetarianism-distilbert-classifier}},
252
+ }
253
+ ```
254
 
255
  ## Model Limitations
256
 
257
  - Trained on Reddit data from specific subreddits
258
+ - May not generalize to other platforms or contexts
259
+ - Performance depends on the quality of GPT-generated labels
260
  - Limited to English language content