# updated_enhanced_error_analysis.py
import json
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List, Dict, Any
# Error categories with detailed descriptions
ERROR_CATEGORIES = {
"multi_step": "Chain-of-thought reasoning failures in multi-step problems",
"percentage": "Percentage and ratio calculations",
"logic": "Logical reasoning and problem setup failures",
"unit_conversion": "Measurement and unit conversion errors"
}
# Define colors for each category
ERROR_CATEGORY_COLORS = {
"multi_step": "red",
"percentage": "blue",
"logic": "orange",
"unit_conversion": "purple"
}
def refined_classify_error_type(result: Dict[str, Any]) -> str:
    """
    Classify an error into one of ERROR_CATEGORIES using keyword and
    structure heuristics on the question text. Precedence: percentage,
    then unit conversion, then multi-step, with "logic" as the fallback.
    """
    question = result["question"].lower()
# Check for percentage problems
if any(term in question for term in ['%', 'percent', 'percentage']):
return "percentage"
# Check for unit conversion issues - prioritize this over multi-step
if any(unit in question for unit in ['pound', 'pounds', 'ounce', 'ounces', 'gallon', 'gallons',
'mile', 'miles', 'hour', 'hours', 'minute', 'minutes',
'second', 'seconds', 'dollar', 'dollars', 'cent', 'cents']):
return "unit_conversion"
# Check for multi-step problems (complex wording)
# Exclude simple arithmetic problems that happen to have "twice", "half", etc.
multi_step_indicators = [
len(question.split()) > 30, # Longer questions are more complex
any(connector in question for connector in ['if', 'then', 'after', 'before', 'when', 'however', 'since']),
question.count('and') > 3, # More connections indicate complexity
any(phrase in question for phrase in ['first', 'then', 'next', 'finally', 'after that']),
any(term in question for term in ['each', 'every', 'per', 'total', 'combined'])
]
    # Additional check to exclude simple problems with basic operations.
    # Note: the '-' check also matches hyphenated words, so these are weak
    # signals rather than hard rules.
    simple_arithmetic_indicators = [
        '+' in question,
        '-' in question,
        '*' in question,
        '/' in question,
        'times' in question,
        'plus' in question,
        'minus' in question
    ]
# Only classify as multi-step if it has multiple complexity indicators
# but doesn't look like simple arithmetic
complexity_score = sum(multi_step_indicators)
if complexity_score >= 3 and sum(simple_arithmetic_indicators) <= 2:
return "multi_step"
# Default to logic if other categories don't fit
return "logic"
def load_results(filename: str = "few_shot_results_errors_only.json") -> List[Dict[str, Any]]:
"""Load the evaluation results from JSON file"""
try:
with open(filename, 'r') as f:
data = json.load(f)
        # Handle both formats: a dict with a "results" key, or a bare list
        if isinstance(data, dict):
            results = data.get("results", [])
        else:
            # Assume the file is a list of results directly
            results = data
        # These files hold errors only, so mark every entry as incorrect
        # regardless of which format was loaded
        for result in results:
            result["is_correct"] = False
        return results
except FileNotFoundError:
print(f"Error: File {filename} not found.")
return []
except json.JSONDecodeError:
print(f"Error: Invalid JSON format in {filename}.")
return []
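
# Assumed on-disk schema (a minimal sketch inferred from the fields accessed
# in this script; values are invented):
# {
#   "results": [
#     {"index": 3,
#      "question": "Natalia sold clips to ...",
#      "ground_truth": "72",
#      "predicted_answer": "70",
#      "generated_text": "First, ...",
#      "is_correct": false}
#   ]
# }
# A bare JSON list of the same objects (without the "results" wrapper) is
# also accepted.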
def create_comprehensive_visualization(results: List[Dict], num_samples: int):
"""Create a comprehensive visualization showing all errors with proper categorization"""
# Categorize errors
errors = [r for r in results if not r.get("is_correct", True)]
error_categories = {category: [] for category in ERROR_CATEGORIES}
for error in errors:
category = refined_classify_error_type(error)
error_categories[category].append(error)
    # Prepare data for visualization (percentages are computed by the pie
    # chart itself via textinfo, so only raw counts are needed here)
    categories = list(error_categories.keys())
    counts = [len(error_categories[cat]) for cat in categories]
# Create detailed error list for scatter plot
error_details = []
for category, errors_list in error_categories.items():
for error in errors_list:
error_details.append({
"sample_index": error.get("index", 0),
"category": category,
"ground_truth": error["ground_truth"],
"predicted": error["predicted_answer"],
"question_preview": error["question"][:50] + "..." if len(error["question"]) > 50 else error["question"]
})
error_df = pd.DataFrame(error_details)
# Create visualization
fig = make_subplots(
rows=2, cols=2,
subplot_titles=(
'Error Category Distribution',
'Error Samples by Index',
'Ground Truth vs Predicted Values',
'Error Analysis Summary'
),
specs=[[{"type": "pie"}, {"type": "scatter"}],
[{"type": "scatter"}, {"type": "table"}]]
)
# Pie chart - Error distribution with specified colors
fig.add_trace(
go.Pie(
labels=categories,
values=counts,
hole=0.4,
textinfo='label+value+percent',
hoverinfo='label+value+percent',
name="Error Types",
marker=dict(colors=[ERROR_CATEGORY_COLORS[cat] for cat in categories])
),
row=1, col=1
)
# Scatter plot - Error samples by index
for category in categories:
category_errors = error_df[error_df['category'] == category]
if not category_errors.empty:
fig.add_trace(
go.Scatter(
x=category_errors['sample_index'],
y=[1] * len(category_errors),
mode='markers',
marker=dict(size=12, color=ERROR_CATEGORY_COLORS[category]),
name=category,
text=category_errors['question_preview'],
hoverinfo='text+x+y+name'
),
row=1, col=2
)
    # Scatter plot - Ground truth vs predicted
    if not error_df.empty:
        # Coerce answers to numbers; non-numeric answers become NaN and are
        # silently skipped instead of raising inside astype(float)
        gt_vals = pd.to_numeric(error_df['ground_truth'], errors='coerce')
        pred_vals = pd.to_numeric(error_df['predicted'], errors='coerce')
        fig.add_trace(
            go.Scatter(
                x=gt_vals,
                y=pred_vals,
                mode='markers',
                marker=dict(
                    size=10,
                    color=[ERROR_CATEGORY_COLORS[cat] for cat in error_df['category']],
                    opacity=0.7
                ),
                text=error_df['category'] + ": Sample " + error_df['sample_index'].astype(str),
                hoverinfo='text+x+y',
                name='GT vs Predicted'
            ),
            row=2, col=1
        )
        # Add ideal line (y = x); kept inside the empty check so max() never
        # sees an empty sequence
        max_val = max(gt_vals.max(), pred_vals.max()) + 10
        fig.add_trace(
            go.Scatter(
                x=[0, max_val],
                y=[0, max_val],
                mode='lines',
                line=dict(dash='dash', color='gray'),
                name='Ideal',
                showlegend=False
            ),
            row=2, col=1
        )
# Table - Error summary
summary_data = []
for category in categories:
for error in error_categories[category]:
summary_data.append([
error.get("index", "N/A"),
category,
error["ground_truth"],
error["predicted_answer"],
"βœ“" if error["ground_truth"] == error["predicted_answer"] else "βœ—"
])
if summary_data:
fig.add_trace(
go.Table(
header=dict(values=['Sample', 'Category', 'Ground Truth', 'Predicted', 'Correct']),
cells=dict(values=[
[row[0] for row in summary_data],
[row[1] for row in summary_data],
[row[2] for row in summary_data],
[row[3] for row in summary_data],
[row[4] for row in summary_data]
]),
name='Error Details'
),
row=2, col=2
)
# Update layout
fig.update_layout(
title=f"Comprehensive Error Analysis - {num_samples} Samples ({len(errors)} Errors)",
height=1000,
width=1400,
showlegend=True
)
fig.update_xaxes(title_text="Sample Index", row=1, col=2)
fig.update_yaxes(title_text="", row=1, col=2, showticklabels=False)
fig.update_xaxes(title_text="Ground Truth", row=2, col=1)
fig.update_yaxes(title_text="Predicted", row=2, col=1)
return fig, error_categories
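
# Minimal usage sketch (assuming `results` is already loaded): the returned
# figure can be inspected interactively instead of being written to disk.
#   fig, error_categories = create_comprehensive_visualization(results, len(results))
#   fig.show()  # plotly's built-in browser renderer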
def generate_detailed_error_report(error_categories: Dict, num_samples: int):
"""Generate a detailed report with analysis of each error category"""
total_errors = sum(len(errors) for errors in error_categories.values())
accuracy = (num_samples - total_errors) / num_samples * 100 if num_samples > 0 else 0
report = ["# Detailed Error Analysis Report", ""]
report.append(f"**Total Samples**: {num_samples}")
report.append(f"**Total Errors**: {total_errors}")
report.append(f"**Overall Accuracy**: {accuracy:.1f}%")
report.append("")
for category, errors in error_categories.items():
if errors:
report.append(f"## {category.upper()} Errors ({len(errors)} errors)")
report.append("")
for i, error in enumerate(errors, 1):
report.append(f"### Error {i}: Sample {error.get('index', 'N/A')}")
report.append("**Question:**")
report.append(f"> {error['question']}")
report.append("")
report.append("**Ground Truth:**")
report.append(f"`{error['ground_truth']}`")
report.append("")
report.append("**Model Prediction:**")
report.append(f"`{error['predicted_answer']}`")
report.append("")
report.append("**Error Analysis:**")
report.append(analyze_specific_error(error, category))
report.append("")
report.append("**Suggested Improvement:**")
report.append(suggest_improvement(error, category))
report.append("---")
report.append("")
return "\n".join(report)
def analyze_specific_error(error: Dict, category: str) -> str:
"""Provide specific analysis for each error"""
    generated = error.get("generated_text", "")
if category == "percentage":
return "Percentage calculation error - likely misunderstanding of percentage relationships or incorrect application of percentage formulas."
elif category == "multi_step":
# Check for specific multi-step failure patterns
if "then" not in generated.lower() and "so" not in generated.lower():
return "Missing logical connectors - model failed to show step-by-step reasoning process."
elif generated.count('\n') < 3:
return "Insufficient step breakdown - model attempted to solve in too few steps."
else:
return "Complex multi-step reasoning failure - model understood individual steps but failed to combine them correctly."
elif category == "unit_conversion":
return "Unit conversion error - likely misunderstanding of measurement units or incorrect conversion between units."
return "General reasoning error - model struggled with the problem structure."
def suggest_improvement(error: Dict, category: str) -> str:
"""Provide specific improvement suggestions"""
if category == "percentage":
return "Train on more percentage word problems with varied contexts. Implement percentage-specific prompting strategies."
elif category == "multi_step":
return "Use chain-of-thought fine-tuning. Break complex problems into sub-tasks. Add intermediate supervision during training."
elif category == "unit_conversion":
return "Practice unit conversion problems with step-by-step solutions. Focus on measurement and currency conversion scenarios. Train on real-world unit conversion applications."
return "General reasoning training with diverse problem types and increased context understanding."
def main():
# Load results
results = load_results('few_shot_results_errors_only.json')
if not results:
print("❌ No results found. Exiting.")
return
num_samples = len(results)
print(f"πŸ” Performing comprehensive error analysis on {num_samples} error samples...")
fig, error_categories = create_comprehensive_visualization(results, num_samples)
# Save visualization
fig.write_html("enhanced_error_analysis.html")
print("πŸ’Ύ Enhanced visualization saved to enhanced_error_analysis.html")
# Generate detailed report
report = generate_detailed_error_report(error_categories, num_samples)
with open("detailed_error_report.md", "w") as f:
f.write(report)
print("πŸ“ Detailed report saved to detailed_error_report.md")
# Print summary
total_errors = sum(len(errors) for errors in error_categories.values())
print(f"\nπŸ“Š Error Summary:")
for category, errors in error_categories.items():
if errors:
percentage = len(errors)/total_errors*100 if total_errors > 0 else 0
print(f" {category.upper()}: {len(errors)} errors ({percentage:.1f}%)")
print(f" Overall Accuracy: {((num_samples - total_errors) / num_samples * 100):.1f}%" if num_samples > 0 else "N/A")
if __name__ == "__main__":
main()
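
# Usage (run from the directory containing few_shot_results_errors_only.json):
#   python updated_enhanced_error_analysis.py
# Outputs:
#   enhanced_error_analysis.html  - interactive plotly dashboard
#   detailed_error_report.md      - per-error markdown report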