Sandeep Chowdhary
commited on
Update model card
Browse files
README.md
CHANGED
|
@@ -12,210 +12,118 @@ license: mit
|
|
| 12 |
|
| 13 |
# Veganism & Vegetarianism Classifier (Distilbert)
|
| 14 |
|
| 15 |
-
This model classifies content related veganism, vegetarianism, and sustainable food choices.
|
| 16 |
|
| 17 |
## Model Details
|
| 18 |
|
| 19 |
-
-
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
-
|
| 23 |
-
-
|
| 24 |
-
-
|
| 25 |
-
|
| 26 |
-
### Training Data Sources
|
| 27 |
-
- **Source Subreddits**: r/climate, r/climateactionplan, r/climatechange, r/climatechaos, r/climateco, r/climatecrisis, r/climatecrisiscanada, r/climatedisalarm, r/climatejobslist, r/climatejustice, r/climatememes, r/climatenews, r/climateoffensive, r/climatepolicy, r/climate_science
|
| 28 |
-
- **Data Period**: 2010-2023
|
| 29 |
-
- **Focus Areas**: Plant-based nutrition, sustainable agriculture, ethical food choices, alternative proteins, food policy
|
| 30 |
-
- **Data Type**: Posts and comments filtered by regex patterns
|
| 31 |
-
- **Labeling Method**: GPT-assisted multilabel classification
|
| 32 |
|
|
|
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
## Labels
|
| 36 |
|
| 37 |
-
The model predicts
|
| 38 |
-
|
| 39 |
-
1. **Animal Welfare**
|
| 40 |
-
2. **Environmental Impact**
|
| 41 |
-
3. **Health**
|
| 42 |
-
4. **Lab Grown And Alt Proteins**
|
| 43 |
-
5. **Psychology And Identity**
|
| 44 |
-
6. **Systemic Vs Individual Action**
|
| 45 |
-
7. **Taste And Convenience**
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
**⚠️ Important**: The order of labels in the output predictions corresponds exactly to the order listed above. When using the model, ensure your label list matches this order.
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
**🔧 Threshold Optimization**: The optimal thresholds were computed using Jaccard similarity optimization on the validation set. For your own dataset, consider re-optimizing thresholds using the same method from `paper4_BERT_finetuning.py`.
|
| 53 |
|
| 54 |
## Usage
|
| 55 |
|
| 56 |
```python
|
| 57 |
import torch
|
| 58 |
-
import numpy as np
|
| 59 |
from transformers import DistilBertTokenizer
|
| 60 |
-
import sys
|
| 61 |
-
import os
|
| 62 |
import tempfile
|
| 63 |
from huggingface_hub import snapshot_download
|
| 64 |
|
| 65 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 66 |
|
| 67 |
-
def print_sorted_label_scores(label_scores):
|
| 68 |
-
# Sort label_scores dict by score descending
|
| 69 |
-
sorted_items = sorted(label_scores.items(), key=lambda x: x[1], reverse=True)
|
| 70 |
-
for label, score in sorted_items:
|
| 71 |
-
print(f" {label}: {score:.6f}")
|
| 72 |
-
|
| 73 |
# Download and load model
|
| 74 |
model_link = "sanchow/veganism_and_vegetarianism-distilbert-classifier"
|
| 75 |
with tempfile.TemporaryDirectory() as temp_dir:
|
| 76 |
-
snapshot_download(
|
| 77 |
-
repo_id=model_link,
|
| 78 |
-
local_dir=temp_dir,
|
| 79 |
-
local_dir_use_symlinks=False
|
| 80 |
-
)
|
| 81 |
|
| 82 |
-
# Import the model class
|
| 83 |
sys.path.insert(0, temp_dir)
|
| 84 |
from model_class import MultilabelClassifier
|
| 85 |
|
| 86 |
-
# Load tokenizer and model
|
| 87 |
tokenizer = DistilBertTokenizer.from_pretrained(temp_dir)
|
| 88 |
-
checkpoint = torch.load(os.path.join(temp_dir, 'model.pt'), map_location='cpu'
|
| 89 |
model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
|
| 90 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 91 |
model.to(device)
|
| 92 |
model.eval()
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
]
|
| 102 |
-
|
| 103 |
-
print(f"\n{sector_info['name']} classifier results:\n")
|
| 104 |
-
for i, test_text in enumerate(examples):
|
| 105 |
-
inputs = tokenizer(
|
| 106 |
-
test_text,
|
| 107 |
-
return_tensors="pt",
|
| 108 |
-
truncation=True,
|
| 109 |
-
max_length=512,
|
| 110 |
-
padding=True
|
| 111 |
-
).to(device)
|
| 112 |
-
with torch.no_grad():
|
| 113 |
-
outputs = model(**inputs)
|
| 114 |
-
predictions = outputs.cpu().numpy() if isinstance(outputs, (tuple, list)) else outputs.cpu().numpy()
|
| 115 |
-
label_scores = {label: float(score) for label, score in zip(checkpoint['label_names'], predictions[0])}
|
| 116 |
-
print(f"Example {i+1}: '{test_text}'")
|
| 117 |
-
print("Predictions (all label scores, highest first):")
|
| 118 |
-
print_sorted_label_scores(label_scores)
|
| 119 |
-
print("-" * 40)
|
| 120 |
```
|
| 121 |
|
| 122 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
| 127 |
- F1 Score: 0.8906
|
| 128 |
- Accuracy: 0.8906
|
| 129 |
-
- Precision: 0.8906
|
| 130 |
-
- Recall: 0.8906
|
| 131 |
|
| 132 |
-
|
| 133 |
-
- Training Samples: ~600 per sector
|
| 134 |
-
- Validation Samples: ~150 per sector
|
| 135 |
-
- Test Samples: ~150 per sector
|
| 136 |
-
- Total GPT-labeled samples: ~900 per sector
|
| 137 |
|
| 138 |
|
| 139 |
|
| 140 |
## Optimal Thresholds
|
| 141 |
|
| 142 |
-
|
| 143 |
|
| 144 |
-
-
|
| 145 |
-
-
|
| 146 |
-
-
|
| 147 |
-
-
|
| 148 |
-
-
|
| 149 |
-
-
|
| 150 |
-
-
|
| 151 |
|
| 152 |
-
|
| 153 |
```python
|
| 154 |
-
# Define optimal thresholds for this model
|
| 155 |
optimal_thresholds = {'Animal Welfare': 0.48107979620047003, 'Environmental Impact': 0.45919171852850427, 'Health': 0.20115313966833437, 'Lab Grown And Alt Proteins': 0.3414601502146817, 'Psychology And Identity': 0.5246278637433214, 'Systemic Vs Individual Action': 0.37517437676211585, 'Taste And Convenience': 0.6635140143644325}
|
| 156 |
-
|
| 157 |
-
# Apply thresholds to get binary predictions
|
| 158 |
-
for i, (label, score) in enumerate(zip(label_names, predictions[0])):
|
| 159 |
threshold = optimal_thresholds.get(label, 0.5)
|
| 160 |
if score > threshold:
|
| 161 |
-
print(f"{label}: {score:.3f}
|
| 162 |
```
|
| 163 |
|
| 164 |
|
| 165 |
-
## Threshold Optimization
|
| 166 |
-
|
| 167 |
-
The optimal thresholds provided above were computed using **Jaccard similarity optimization** on the validation dataset. This method finds the best threshold for each label that maximizes the Jaccard similarity between predicted and true labels.
|
| 168 |
-
|
| 169 |
-
### Optimization Method Used
|
| 170 |
-
|
| 171 |
-
The thresholds were optimized using the `find_optimal_thresholds_jaccard_global` function from `paper4_multilabel_threshold_optimizer.py`, which:
|
| 172 |
-
|
| 173 |
-
1. **Grid Search**: Tests threshold values from 0.1 to 0.9 in 0.05 increments
|
| 174 |
-
2. **Jaccard Optimization**: Maximizes micro-averaged Jaccard similarity
|
| 175 |
-
3. **Per-Label Optimization**: Finds optimal threshold for each label independently
|
| 176 |
-
4. **Global Optimization**: Considers the overall multilabel performance
|
| 177 |
-
|
| 178 |
-
### Re-optimizing for Your Dataset
|
| 179 |
-
|
| 180 |
-
For best results on your specific dataset, consider re-optimizing thresholds:
|
| 181 |
-
|
| 182 |
-
```python
|
| 183 |
-
from paper4_multilabel_threshold_optimizer import find_optimal_thresholds_jaccard_global
|
| 184 |
-
|
| 185 |
-
# Load your validation data
|
| 186 |
-
validation_data = {
|
| 187 |
-
'texts': ['your text 1', 'your text 2', ...],
|
| 188 |
-
'true_labels': [['label1', 'label2'], ['label3'], ...]
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
# Create sector models dict (as expected by the optimizer)
|
| 192 |
-
sector_models = {
|
| 193 |
-
'your_sector': {
|
| 194 |
-
'model': model,
|
| 195 |
-
'tokenizer': tokenizer,
|
| 196 |
-
'label_names': label_names
|
| 197 |
-
}
|
| 198 |
-
}
|
| 199 |
-
|
| 200 |
-
# Find optimal thresholds for your data
|
| 201 |
-
optimal_thresholds = find_optimal_thresholds_jaccard_global(
|
| 202 |
-
sector_models,
|
| 203 |
-
validation_data,
|
| 204 |
-
device=device
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
# Use the optimized thresholds
|
| 208 |
-
thresholds = optimal_thresholds['your_sector']
|
| 209 |
-
```
|
| 210 |
-
|
| 211 |
-
### Alternative Optimization Methods
|
| 212 |
-
|
| 213 |
-
You can also implement other threshold optimization strategies:
|
| 214 |
-
|
| 215 |
-
- **F1-score optimization**: Maximize F1-score instead of Jaccard
|
| 216 |
-
- **Precision/Recall trade-off**: Optimize for specific precision/recall requirements
|
| 217 |
-
- **Cost-sensitive optimization**: Weight different types of errors differently
|
| 218 |
-
|
| 219 |
## Citation
|
| 220 |
|
| 221 |
If you use this model in your research, please cite:
|
|
@@ -231,9 +139,9 @@ If you use this model in your research, please cite:
|
|
| 231 |
}
|
| 232 |
```
|
| 233 |
|
| 234 |
-
##
|
| 235 |
|
| 236 |
- Trained on Reddit data from specific subreddits
|
| 237 |
-
- May not generalize to other platforms
|
| 238 |
-
- Performance depends on
|
| 239 |
-
- Limited to English
|
|
|
|
| 12 |
|
| 13 |
# Veganism & Vegetarianism Classifier (Distilbert)
|
| 14 |
|
| 15 |
+
This model classifies content related to plant-based diets, sustainable food systems, and ethical eating. It analyzes discussions about veganism, vegetarianism, and sustainable food choices.
|
| 16 |
|
| 17 |
## Model Details
|
| 18 |
|
| 19 |
+
- Model Type: Distilbert
|
| 20 |
+
- Task: Multilabel text classification
|
| 21 |
+
- Sector: Veganism & Vegetarianism
|
| 22 |
+
- Base Model: Distilbert base uncased
|
| 23 |
+
- Labels: 7
|
| 24 |
+
- Training Data: Reddit posts from climate subreddits (2010-2023)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
## Training
|
| 27 |
|
| 28 |
+
Trained on GPT-labeled Reddit data:
|
| 29 |
+
1. Data collection from climate subreddits
|
| 30 |
+
2. Regex filtering for sector-specific content
|
| 31 |
+
3. GPT labeling for multilabel classification
|
| 32 |
+
4. 80/10/10 train/validation/test split
|
| 33 |
+
5. Fine-tuning with threshold optimization
|
| 34 |
|
| 35 |
## Labels
|
| 36 |
|
| 37 |
+
The model predicts 7 labels simultaneously:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
1. Animal Welfare
|
| 40 |
+
2. Environmental Impact
|
| 41 |
+
3. Health
|
| 42 |
+
4. Lab Grown And Alt Proteins
|
| 43 |
+
5. Psychology And Identity
|
| 44 |
+
6. Systemic Vs Individual Action
|
| 45 |
+
7. Taste And Convenience
|
| 46 |
|
|
|
|
| 47 |
|
| 48 |
+
Note: Label order in predictions matches the order above.
|
|
|
|
|
|
|
| 49 |
|
| 50 |
## Usage
|
| 51 |
|
| 52 |
```python
|
| 53 |
import torch
|
|
|
|
| 54 |
from transformers import DistilBertTokenizer
|
|
|
|
|
|
|
| 55 |
import tempfile
|
| 56 |
from huggingface_hub import snapshot_download
|
| 57 |
|
| 58 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# Download and load model
|
| 61 |
model_link = "sanchow/veganism_and_vegetarianism-distilbert-classifier"
|
| 62 |
with tempfile.TemporaryDirectory() as temp_dir:
|
| 63 |
+
snapshot_download(repo_id=model_link, local_dir=temp_dir, local_dir_use_symlinks=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
|
|
|
| 65 |
sys.path.insert(0, temp_dir)
|
| 66 |
from model_class import MultilabelClassifier
|
| 67 |
|
|
|
|
| 68 |
tokenizer = DistilBertTokenizer.from_pretrained(temp_dir)
|
| 69 |
+
checkpoint = torch.load(os.path.join(temp_dir, 'model.pt'), map_location='cpu')
|
| 70 |
model = MultilabelClassifier(checkpoint['model_name'], len(checkpoint['label_names']))
|
| 71 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 72 |
model.to(device)
|
| 73 |
model.eval()
|
| 74 |
+
|
| 75 |
+
# Predict
|
| 76 |
+
text = "Your text here"
|
| 77 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
predictions = model(**inputs).cpu().numpy()
|
| 80 |
+
|
| 81 |
+
# Get scores
|
| 82 |
+
label_scores = {label: float(score) for label, score in zip(checkpoint['label_names'], predictions[0])}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
```
|
| 84 |
|
| 85 |
+
## Applications
|
| 86 |
+
|
| 87 |
+
- Content analysis of social media discussions
|
| 88 |
+
- Research on public sentiment and discourse
|
| 89 |
+
- Policy analysis of key topics and concerns
|
| 90 |
+
- Market research on trends and interests
|
| 91 |
+
|
| 92 |
|
| 93 |
+
## Performance
|
| 94 |
+
|
| 95 |
+
Best model performance:
|
| 96 |
+
- Micro Jaccard: 0.5584
|
| 97 |
+
- Macro Jaccard: 0.6710
|
| 98 |
- F1 Score: 0.8906
|
| 99 |
- Accuracy: 0.8906
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
Dataset: ~900 GPT-labeled samples per sector (600 train, 150 validation, 150 test)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
|
| 105 |
## Optimal Thresholds
|
| 106 |
|
| 107 |
+
Use these thresholds for best performance:
|
| 108 |
|
| 109 |
+
- Animal Welfare: 0.481
|
| 110 |
+
- Environmental Impact: 0.459
|
| 111 |
+
- Health: 0.201
|
| 112 |
+
- Lab Grown And Alt Proteins: 0.341
|
| 113 |
+
- Psychology And Identity: 0.525
|
| 114 |
+
- Systemic Vs Individual Action: 0.375
|
| 115 |
+
- Taste And Convenience: 0.664
|
| 116 |
|
| 117 |
+
Usage:
|
| 118 |
```python
|
|
|
|
| 119 |
optimal_thresholds = {'Animal Welfare': 0.48107979620047003, 'Environmental Impact': 0.45919171852850427, 'Health': 0.20115313966833437, 'Lab Grown And Alt Proteins': 0.3414601502146817, 'Psychology And Identity': 0.5246278637433214, 'Systemic Vs Individual Action': 0.37517437676211585, 'Taste And Convenience': 0.6635140143644325}
|
| 120 |
+
for label, score in zip(label_names, predictions[0]):
|
|
|
|
|
|
|
| 121 |
threshold = optimal_thresholds.get(label, 0.5)
|
| 122 |
if score > threshold:
|
| 123 |
+
print(f"{label}: {score:.3f}")
|
| 124 |
```
|
| 125 |
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
## Citation
|
| 128 |
|
| 129 |
If you use this model in your research, please cite:
|
|
|
|
| 139 |
}
|
| 140 |
```
|
| 141 |
|
| 142 |
+
## Limitations
|
| 143 |
|
| 144 |
- Trained on Reddit data from specific subreddits
|
| 145 |
+
- May not generalize to other platforms
|
| 146 |
+
- Performance depends on GPT-generated labels
|
| 147 |
+
- Limited to English content
|