# Training Strategy Guide for Participatory Planning Classifier
## Current Performance (as of Oct 2025)
- **Dataset**: 60 examples (~42 train / 9 val / 9 test)
- **Current Best**: Head-only training - **66.7% accuracy**
- **Baseline**: ~60% (zero-shot BART-mnli)
- **Challenge**: Only ~7 percentage points above the zero-shot baseline; the model is **underfitting**
## Recommended Training Strategies (Ranked)
### 🥇 **Strategy 1: LoRA with Conservative Settings**
**Best for: Your current 60-example dataset**
```yaml
Configuration:
training_mode: lora
lora_rank: 4-8 # Start small!
lora_alpha: 8-16 # 2x rank
lora_dropout: 0.2 # High dropout to prevent overfitting
learning_rate: 1e-4 # Conservative
num_epochs: 5-7 # Watch for overfitting
batch_size: 4 # Smaller batches
```
**Expected Accuracy**: 70-80%
**Why it works:**
- More capacity than head-only (~500K params with r=4)
- Still parameter-efficient enough for 60 examples
- Dropout prevents overfitting
**Try this first!** Your head-only results show you need more model capacity.
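
A minimal sketch of wiring these settings up with the Hugging Face `peft` library, assuming the analyzer fine-tunes a BART-style sequence-classification model; the actual model class and training loop in `app/analyzer.py` may differ:

```python
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

# Assumption: a fresh 6-way classification head on top of the existing base model
model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli", num_labels=6, ignore_mismatched_sizes=True
)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=4,                                   # start small; raise to 8 if it still underfits
    lora_alpha=8,                          # 2x rank
    lora_dropout=0.2,                      # high dropout to prevent overfitting
    target_modules=["q_proj", "v_proj"],   # attention projections in BART
    modules_to_save=["classification_head"],  # keep the new head trainable too
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # should report well under 1% of total params
```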
---
### 🥈 **Strategy 2: Data Augmentation + LoRA**
**Best for: Improving beyond 80% accuracy**
**Step 1: Augment your dataset to 150-200 examples**
Methods:
1. **Paraphrasing** (use GPT/Claude):
```python
# For each labelled example, generate 2-3 paraphrases that keep the same category:
original = "We need better public transit"
paraphrases = [
    "Public transportation should be improved",
    "The transit system requires enhancement",
]
```
2. **Back-translation**:
   English → Spanish → English (creates natural variations; see the sketch after this list)
3. **Template-based**:
Create templates for each category and fill with variations
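
For the back-translation route, a rough sketch using the `transformers` translation pipeline (the Helsinki-NLP MarianMT checkpoints named here are an assumption; any EN↔ES pair works):

```python
from transformers import pipeline

en_to_es = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
es_to_en = pipeline("translation", model="Helsinki-NLP/opus-mt-es-en")

def back_translate(text: str) -> str:
    """Round-trip through Spanish to get a natural paraphrase with the same label."""
    spanish = en_to_es(text)[0]["translation_text"]
    return es_to_en(spanish)[0]["translation_text"]

print(back_translate("We need better public transit"))
```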
**Step 2: Train LoRA (r=8-16) on augmented data**
- Expected Accuracy: 80-90%
---
### 🥉 **Strategy 3: Two-Stage Progressive Training**
**Best for: Maximizing performance with limited data**
1. **Stage 1**: Head-only (warm-up)
- 3 epochs
- Initialize the classification head
2. **Stage 2**: LoRA fine-tuning
- r=4, low learning rate
- Build on head-only initialization
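
A compact sketch of the two stages, assuming the same `peft`-based setup as Strategy 1; `train(...)` is a hypothetical stand-in for the project's existing training loop:

```python
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli", num_labels=6, ignore_mismatched_sizes=True
)

# Stage 1: freeze the backbone and warm up only the classification head (3 epochs)
for param in model.base_model.parameters():
    param.requires_grad = False
# train(model, epochs=3)  # hypothetical helper; use your existing loop

# Stage 2: add LoRA adapters on top of the warmed-up head, low learning rate
lora_config = LoraConfig(task_type=TaskType.SEQ_CLS, r=4, lora_alpha=8, lora_dropout=0.2)
model = get_peft_model(model, lora_config)
# train(model, epochs=5, learning_rate=1e-4)
```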
---
### 🔧 **Strategy 4: Optimize Category Definitions**
**May help with zero-shot AND fine-tuning**
Your categories might be too similar for the model to separate:
**Current overlaps:**
- Vision vs Objectives (both forward-looking)
- Problem vs Directives (both framed as constraints)
**Better Definitions:**
```python
CATEGORIES = {
'Vision': {
'name': 'Vision & Aspirations',
'description': 'Long-term future state, desired outcomes, what success looks like',
'keywords': ['future', 'aspire', 'imagine', 'dream', 'ideal']
},
'Problem': {
'name': 'Current Problems',
'description': 'Existing issues, frustrations, barriers, root causes',
'keywords': ['problem', 'issue', 'challenge', 'barrier', 'broken']
},
'Objectives': {
'name': 'Specific Goals',
'description': 'Measurable targets, concrete milestones, quantifiable outcomes',
'keywords': ['increase', 'reduce', 'achieve', 'target', 'by 2030']
},
'Directives': {
'name': 'Constraints & Requirements',
'description': 'Must-haves, non-negotiables, compliance requirements',
'keywords': ['must', 'required', 'mandate', 'comply', 'regulation']
},
'Values': {
'name': 'Principles & Values',
'description': 'Core beliefs, ethical guidelines, guiding principles',
'keywords': ['equity', 'sustainability', 'justice', 'fairness', 'inclusive']
},
'Actions': {
'name': 'Concrete Actions',
'description': 'Specific steps, interventions, activities to implement',
'keywords': ['build', 'create', 'implement', 'install', 'construct']
}
}
```
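
One way to feed the richer definitions back into the zero-shot path, assuming the analyzer uses a standard `transformers` zero-shot pipeline (the label formatting below is an illustration, not the project's actual prompt):

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Use name + description as the candidate label text instead of the bare key
candidate_labels = [f"{c['name']}: {c['description']}" for c in CATEGORIES.values()]

result = classifier("We must comply with the new zoning regulations", candidate_labels)
print(result["labels"][0])  # highest-scoring label string
```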
---
## Alternative Base Models to Consider
### **DeBERTa-v3-base** (Better for Classification)
```python
# In app/analyzer.py
model_name = "microsoft/deberta-v3-base"
# Size: 184M params (vs BART's 400M)
# Often outperforms BART for classification
```
### **DistilRoBERTa** (Faster, Lighter)
```python
model_name = "distilroberta-base"
# Size: 82M params
# ~2x faster than RoBERTa-base, ~80% smaller than BART-large
# Good accuracy
```
### **XLM-RoBERTa-base** (Multilingual)
```python
model_name = "xlm-roberta-base"
# If you have multilingual submissions
```
---
## Data Collection Strategy
**Current**: 60 examples → **Target**: 150+ examples
### How to get more data:
1. **Active Learning** (Built into your system!)
- Deploy current model
- Admin reviews and corrects predictions
- Automatically builds training set
2. **Historical Data**
- Import past participatory planning submissions
- Manual labeling (15 min for 50 examples)
3. **Synthetic Generation** (use GPT-4; see the sketch after this list)
```
Prompt: "Generate 10 participatory planning submissions
that express VISION for urban transportation"
```
4. **Crowdsourcing**
   - MTurk (Amazon Mechanical Turk) or internal team
- Label 100 examples: ~$20-50
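
A minimal sketch of scripting the synthetic-generation step, assuming the OpenAI Python client (v1+) is available; the model name and prompt wording are illustrative:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

prompt = (
    "Generate 10 participatory planning submissions that express VISION "
    "for urban transportation. Return one submission per line."
)
response = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": prompt}],
)
synthetic_examples = response.choices[0].message.content.splitlines()
```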
---
## Performance Targets
| Dataset Size | Method | Expected Accuracy | Time to Train |
|-------------|--------|------------------|---------------|
| 60 | Head-only | 65-70% ❌ Current | 2 min |
| 60 | LoRA (r=4) | 70-80% ✅ Try next | 5 min |
| 150 | LoRA (r=8) | 80-85% ⭐ Goal | 10 min |
| 300+ | LoRA (r=16) | 85-90% 🎯 Ideal | 20 min |
---
## Immediate Action Plan
### Week 1: Low-Hanging Fruit
1. ✅ Train with LoRA (r=4, epochs=5)
2. ✅ Compare to head-only baseline
3. ✅ Check per-category F1 scores
### Week 2: Data Expansion
4. Collect 50 more examples (aim for balance)
5. Use data augmentation (paraphrase 60 → 120)
6. Retrain LoRA (r=8)
### Week 3: Optimization
7. Try DeBERTa-v3-base as base model
8. Fine-tune category descriptions
9. Deploy best model
---
## Debugging Low Performance
If accuracy stays below 75%:
### Check 1: Data Quality
```sql
-- Look for texts that were labelled with more than one category
SELECT message, COUNT(DISTINCT corrected_category) AS label_count
FROM training_examples
GROUP BY message
HAVING COUNT(DISTINCT corrected_category) > 1;
```
### Check 2: Class Imbalance
- Ensure each category has 5-10+ examples
- Use weighted loss if imbalanced
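
If the classes are skewed, inverse-frequency weights are a simple fix; a sketch assuming a PyTorch cross-entropy loss in the training loop (the per-category counts below are placeholders):

```python
import torch

# Placeholder per-category example counts, in the same order as the label ids
counts = torch.tensor([18.0, 12.0, 9.0, 8.0, 7.0, 6.0])
weights = counts.sum() / (len(counts) * counts)   # inverse-frequency weighting
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
```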
### Check 3: Category Confusion
- Generate confusion matrix
- Merge categories that are frequently confused
  (e.g., Vision + Objectives → "Future Goals")
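
A quick way to see which pairs get confused, assuming scikit-learn is installed (`y_true` / `y_pred` below are placeholder validation labels and predictions):

```python
from sklearn.metrics import confusion_matrix, classification_report

labels = ["Vision", "Problem", "Objectives", "Directives", "Values", "Actions"]
y_true = ["Vision", "Objectives", "Problem", "Directives"]   # placeholder ground truth
y_pred = ["Objectives", "Objectives", "Problem", "Problem"]  # placeholder predictions

print(confusion_matrix(y_true, y_pred, labels=labels))
print(classification_report(y_true, y_pred, labels=labels, zero_division=0))
```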
### Check 4: Text Quality
- Remove very short texts (< 5 words)
- Remove duplicates
- Check for non-English text
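
A small filtering pass covering these checks (the `examples` list is a placeholder for rows pulled from `training_examples`; a language check with e.g. `langdetect` is omitted):

```python
examples = [
    ("We need better public transit", "Vision"),
    ("ok", "Actions"),                            # too short, will be dropped
    ("We need better public transit", "Vision"),  # duplicate, will be dropped
]

seen = set()
cleaned = []
for text, label in examples:
    if len(text.split()) < 5:        # drop very short texts
        continue
    key = text.strip().lower()
    if key in seen:                  # drop duplicates
        continue
    seen.add(key)
    cleaned.append((text, label))
```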
---
## Advanced: Ensemble Models
If a single model plateaus at 80-85%:
1. Train 3 models with different seeds
2. Use voting or averaging
3. Typical boost: +3-5% accuracy
```python
from collections import Counter

# Sketch: model1-3 are the same architecture trained with different seeds,
# text is the submission to classify
predictions = [
    model1.predict(text),
    model2.predict(text),
    model3.predict(text),
]
final = Counter(predictions).most_common(1)[0][0]  # majority vote
```
---
## Conclusion
**For your current 60 examples:**
1. 🎯 **DO**: Try LoRA with r=4-8 (conservative settings)
2. 📈 **DO**: Collect 50-100 more examples
3. 🔄 **DO**: Try DeBERTa-v3 as an alternative base model
4. ❌ **DON'T**: Rely on head-only training (it underfits on this dataset)
5. ❌ **DON'T**: Use full fine-tuning (likely to overfit 60 examples)
**Expected outcome:** 70-85% accuracy (up from current 66.7%)
**Next milestone:** 150 examples → 85%+ accuracy