# updated_error_analysis.py
import json
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import pandas as pd
from typing import Dict, List, Any
import argparse
# Error categories with detailed descriptions
ERROR_CATEGORIES = {
"multi_step": "Chain-of-thought reasoning failures in multi-step problems",
"percentage": "Percentage and ratio calculations",
"logic": "Logical reasoning and problem setup failures",
"unit_conversion": "Measurement and unit conversion errors"
}
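# Expected shape of each result record, inferred from the fields this script
# reads below (any additional keys in the JSON are simply ignored):
#   {
#       "index": 3,                      # sample position in the evaluation run
#       "question": "...",               # GSM8K problem text
#       "ground_truth": "42",            # reference answer
#       "predicted_answer": "40",        # model's answer
#       "is_correct": false              # forced to False for error-only files
#   }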
def classify_error_type(result: Dict[str, Any]) -> str:
"""
Classify the type of error based on the question, ground truth, and predicted answer
"""
question = result["question"].lower()
ground_truth = str(result["ground_truth"]).lower()
predicted = str(result["predicted_answer"]).lower()
# Check for percentage problems
if any(term in question for term in ['%', 'percent', 'percentage']):
return "percentage"
# Check for unit conversion issues - prioritize this over multi-step
if any(unit in question for unit in ['pound', 'pounds', 'ounce', 'ounces', 'gallon', 'gallons',
'mile', 'miles', 'hour', 'hours', 'minute', 'minutes',
'second', 'seconds', 'dollar', 'dollars', 'cent', 'cents']):
return "unit_conversion"
    # Check for multi-step problems (complex wording).
    # A separate arithmetic check below keeps straightforward calculation
    # questions from being labelled as multi-step.
multi_step_indicators = [
len(question.split()) > 30, # Longer questions are more complex
any(connector in question for connector in ['if', 'then', 'after', 'before', 'when', 'however', 'since']),
question.count('and') > 3, # More connections indicate complexity
any(phrase in question for phrase in ['first', 'then', 'next', 'finally', 'after that']),
any(term in question for term in ['each', 'every', 'per', 'total', 'combined'])
]
# Additional check to exclude simple problems with basic operations
simple_arithmetic_indicators = [
question.count('+') > 0,
question.count('-') > 0,
question.count('*') > 0,
question.count('/') > 0,
question.count('times') > 0,
question.count('plus') > 0,
question.count('minus') > 0
]
# Only classify as multi-step if it has multiple complexity indicators
# but doesn't look like simple arithmetic
complexity_score = sum(multi_step_indicators)
if complexity_score >= 3 and sum(simple_arithmetic_indicators) <= 2:
return "multi_step"
# Default to logic if other categories don't fit
return "logic"
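# Illustrative behaviour of the rules above (examples are hypothetical, not taken
# from the evaluation data):
#   "What is 15% of 80?"                        -> "percentage"
#   "A car travels 60 miles in 2 hours ..."     -> "unit_conversion"
#   a long, multi-clause word problem           -> "multi_step" (complexity_score >= 3)
#   anything that matches none of the rules     -> "logic"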
def load_results(filename: str = "few_shot_results_errors_only.json") -> List[Dict[str, Any]]:
"""Load the evaluation results from JSON file"""
try:
with open(filename, 'r') as f:
data = json.load(f)
# Handle both formats: full results and error-only results
if "results" in data:
results = data.get("results", [])
# For error-only files, ensure all are marked as incorrect
for result in results:
result["is_correct"] = False
return results
else:
# Assume it's a list of results directly
return data
except FileNotFoundError:
print(f"Error: File {filename} not found.")
return []
except json.JSONDecodeError:
print(f"Error: Invalid JSON format in {filename}.")
return []
def analyze_errors(results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Perform comprehensive error analysis"""
total_samples = len(results)
    # Correct/incorrect counts are derived after categorization so the function
    # also handles mixed files, not only error-only exports.
# Categorize errors
error_categories = {category: [] for category in ERROR_CATEGORIES}
error_categories["correct"] = []
for result in results:
if result["is_correct"]:
error_categories["correct"].append(result)
else:
error_type = classify_error_type(result)
error_categories[error_type].append(result)
    # Calculate statistics
    correct_count = len(error_categories["correct"])
    incorrect_count = total_samples - correct_count
    accuracy = (correct_count / total_samples) * 100 if total_samples > 0 else 0
category_stats = {
"total_samples": total_samples,
"correct_count": correct_count,
"incorrect_count": incorrect_count,
"accuracy": accuracy,
"category_counts": {cat: len(errors) for cat, errors in error_categories.items()},
"category_percentages": {
cat: (len(errors) / total_samples * 100) if total_samples > 0 else 0
for cat, errors in error_categories.items()
},
"detailed_errors": error_categories
}
return category_stats
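# For an error-only run the returned dict might look like (illustrative counts):
#   {"total_samples": 16, "correct_count": 0, "incorrect_count": 16, "accuracy": 0.0,
#    "category_counts": {"multi_step": 5, "percentage": 4, "logic": 4,
#                        "unit_conversion": 3, "correct": 0}, ...}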
def create_visualizations(stats: Dict[str, Any], num_samples: int):
"""Create interactive Plotly visualizations with specified color scheme"""
# Prepare data for visualizations
categories = list(ERROR_CATEGORIES.keys()) + ["correct"]
counts = [stats["category_counts"].get(cat, 0) for cat in categories]
percentages = [stats["category_percentages"].get(cat, 0) for cat in categories]
# Create subplots
fig = make_subplots(
rows=2, cols=2,
subplot_titles=(
f'Error Distribution (n={num_samples})',
'Error Category Breakdown',
'Error Percentage by Category',
'Sample Index vs Error Type'
),
specs=[[{"type": "pie"}, {"type": "bar"}],
[{"type": "scatter"}, {"type": "scatter"}]]
)
# Define color scheme
color_map = {
'multi_step': 'red',
'percentage': 'blue',
'logic': 'orange',
'unit_conversion': 'purple',
'correct': 'green'
}
# Pie chart - Error distribution
error_categories_pie = [cat for cat in categories if cat != "correct" and stats["category_counts"].get(cat, 0) > 0]
error_counts_pie = [stats["category_counts"].get(cat, 0) for cat in error_categories_pie]
pie_colors = [color_map.get(cat, 'gray') for cat in error_categories_pie]
fig.add_trace(
go.Pie(
labels=error_categories_pie,
values=error_counts_pie,
name="Error Types",
hole=0.4,
textinfo='label+percent',
hoverinfo='label+value+percent',
marker=dict(colors=pie_colors),
showlegend=False
),
row=1, col=1
)
# Bar chart - Category counts
non_zero_categories = [cat for cat in categories if cat != "correct" and stats["category_counts"].get(cat, 0) > 0]
non_zero_counts = [stats["category_counts"].get(cat, 0) for cat in non_zero_categories]
non_zero_percentages = [stats["category_percentages"].get(cat, 0) for cat in non_zero_categories]
bar_colors = [color_map.get(cat, 'gray') for cat in non_zero_categories]
fig.add_trace(
go.Bar(
x=non_zero_categories,
y=non_zero_counts,
name="Count by Category",
marker_color=bar_colors,
text=[f"{count}<br>{percent:.1f}%" for count, percent in zip(non_zero_counts, non_zero_percentages)],
textposition='auto',
hoverinfo='x+y'
),
row=1, col=2
)
# Scatter plot - Error percentage by category
error_df = pd.DataFrame({
'Category': categories,
'Percentage': percentages,
'Count': counts
})
error_df = error_df[error_df['Count'] > 0] # Only show categories with samples
scatter_colors = [color_map.get(cat, 'gray') for cat in error_df['Category']]
fig.add_trace(
go.Scatter(
x=error_df['Category'],
y=error_df['Percentage'],
mode='markers+text',
marker=dict(
size=error_df['Count']*2 + 10,
color=scatter_colors,
opacity=0.8
),
text=error_df['Count'],
textposition='middle center',
name='Error Percentage',
hoverinfo='text',
hovertext=[f"{cat}: {pct:.1f}% ({cnt} samples)" for cat, pct, cnt in
zip(error_df['Category'], error_df['Percentage'], error_df['Count'])]
),
row=2, col=1
)
# Scatter plot - Sample index vs error type
all_results = []
for category in ERROR_CATEGORIES:
for result in stats["detailed_errors"][category]:
all_results.append((result, category))
# Sort by sample index
all_results.sort(key=lambda x: x[0].get("index", 0))
sample_indices = [result[0].get("index", i+1) for i, result in enumerate(all_results)]
error_types = [result[1] for result in all_results]
error_colors = [color_map.get(error_type, 'gray') for error_type in error_types]
fig.add_trace(
go.Scatter(
x=sample_indices,
y=[1] * len(all_results), # All are errors, so y=1
mode='markers',
marker=dict(
color=error_colors,
size=12,
opacity=0.7
),
name='Error Types by Sample',
hoverinfo='x+y+text',
hovertext=[f"Sample {result[0].get('index', 'N/A')}: {result[1]}" for result in all_results]
),
row=2, col=2
)
# Update layout
fig.update_layout(
title=f"Symbolic-Math-Qwen2.5-1.5B-LoRA Error Analysis (n={num_samples} errors)",
height=1000,
width=1200,
showlegend=False
)
fig.update_xaxes(title_text="Error Categories", row=1, col=2)
fig.update_yaxes(title_text="Number of Errors", row=1, col=2)
fig.update_xaxes(title_text="Categories", row=2, col=1)
fig.update_yaxes(title_text="Percentage (%)", row=2, col=1)
fig.update_xaxes(title_text="Sample Index", row=2, col=2)
fig.update_yaxes(title_text="Error Type", row=2, col=2, tickvals=[1], ticktext=["Errors"])
return fig
def generate_detailed_report(stats: Dict[str, Any]):
"""Generate a detailed text report of the analysis"""
report = []
report.append("=" * 60)
report.append("SYMBOLIC-MATH-QWEN2.5-1.5B-LoRA ERROR ANALYSIS REPORT")
report.append("=" * 60)
report.append(f"Total Error Samples: {stats['total_samples']}")
report.append(f"Overall Accuracy: {stats['accuracy']:.2f}%")
report.append("")
report.append("ERROR CATEGORY BREAKDOWN:")
report.append("-" * 40)
for category, count in stats['category_counts'].items():
if category != "correct" and count > 0:
percentage = stats['category_percentages'][category]
report.append(f"{category.upper():<20}: {count:>3} errors ({percentage:>5.1f}%)")
report.append("")
report.append("RECOMMENDATIONS:")
report.append("-" * 40)
# Generate recommendations based on error patterns
if stats['category_counts'].get('percentage', 0) > 0:
        report.append("• Add percentage calculation training examples")
        report.append("• Implement percentage-specific prompting strategies")
        report.append("• Train on more percentage word problems with varied contexts")
if stats['category_counts'].get('multi_step', 0) > 0:
        report.append("• Focus on chain-of-thought reasoning training")
        report.append("• Break down complex problems into sub-steps")
        report.append("• Add intermediate supervision during training")
if stats['category_counts'].get('logic', 0) > 0:
        report.append("• Improve logical reasoning capabilities")
        report.append("• Train on problem setup and structure understanding")
        report.append("• General reasoning training with diverse problem types")
if stats['category_counts'].get('unit_conversion', 0) > 0:
        report.append("• Practice unit conversion problems")
        report.append("• Focus on measurement and currency conversions")
        report.append("• Train on real-world unit conversion scenarios")
return "\n".join(report)
def main():
parser = argparse.ArgumentParser(description="Analyze Symbolic-Math-Qwen2.5-1.5B-LoRA errors on GSM8K dataset")
parser.add_argument("--samples", type=int, default=16, help="Number of error samples analyzed")
parser.add_argument("--input", type=str, default="few_shot_results_errors_only.json", help="Input JSON file")
parser.add_argument("--output", type=str, default="error_analysis.html", help="Output HTML file")
args = parser.parse_args()
print(f"πŸ” Analyzing {args.samples} error samples...")
print(f"πŸ“ Loading results from {args.input}")
# Load and analyze results
results = load_results(args.input)
if not results:
print("❌ No results found. Exiting.")
return
# Limit to specified number of samples
results = results[:args.samples]
print("πŸ“Š Performing error analysis...")
stats = analyze_errors(results)
# Generate visualizations
print("πŸ“ˆ Creating visualizations...")
fig = create_visualizations(stats, args.samples)
# Save visualizations
fig.write_html(args.output)
print(f"πŸ’Ύ Visualizations saved to {args.output}")
# Generate and print detailed report
report = generate_detailed_report(stats)
print("\n" + report)
# Save report to file
report_filename = f"error_analysis_report_{args.samples}_errors.txt"
with open(report_filename, 'w') as f:
f.write(report)
print(f"πŸ“ Detailed report saved to {report_filename}")
print(f"\n🎯 Error analysis complete! Found {stats['total_samples']} error samples")
print("πŸ‘‰ Open the HTML file in your browser to view interactive visualizations")
if __name__ == "__main__":
main()
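# Example invocation (paths match the argparse defaults above):
#   python updated_error_analysis.py --samples 16 \
#       --input few_shot_results_errors_only.json --output error_analysis.html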