# enhanced_error_analysis.py
import json
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List, Dict, Any
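# Third-party dependencies: plotly and pandas (e.g. `pip install plotly pandas`)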
# Error categories with detailed descriptions
ERROR_CATEGORIES = {
"multi_step": "Chain-of-thought reasoning failures in multi-step problems",
"percentage": "Percentage and ratio calculations",
"logic": "Logical reasoning and problem setup failures",
"unit_conversion": "Measurement and unit conversion errors"
}
# Define colors for each category
ERROR_CATEGORY_COLORS = {
"multi_step": "red",
"percentage": "blue",
"logic": "orange",
"unit_conversion": "purple"
}
def refined_classify_error_type(result: Dict[str, Any]) -> str:
"""
More precise error classification that distinguishes between error types
"""
question = result["question"].lower()
# Check for percentage problems
if any(term in question for term in ['%', 'percent', 'percentage']):
return "percentage"
# Check for unit conversion issues - prioritize this over multi-step
if any(unit in question for unit in ['pound', 'pounds', 'ounce', 'ounces', 'gallon', 'gallons',
'mile', 'miles', 'hour', 'hours', 'minute', 'minutes',
'second', 'seconds', 'dollar', 'dollars', 'cent', 'cents']):
return "unit_conversion"
# Check for multi-step problems (complex wording)
# Exclude simple arithmetic problems that happen to have "twice", "half", etc.
multi_step_indicators = [
len(question.split()) > 30, # Longer questions are more complex
any(connector in question for connector in ['if', 'then', 'after', 'before', 'when', 'however', 'since']),
question.count('and') > 3, # More connections indicate complexity
any(phrase in question for phrase in ['first', 'then', 'next', 'finally', 'after that']),
any(term in question for term in ['each', 'every', 'per', 'total', 'combined'])
]
# Additional check to exclude simple problems with basic operations
simple_arithmetic_indicators = [
question.count('+') > 0,
question.count('-') > 0,
question.count('*') > 0,
question.count('/') > 0,
question.count('times') > 0,
question.count('plus') > 0,
question.count('minus') > 0
]
# Only classify as multi-step if it has multiple complexity indicators
# but doesn't look like simple arithmetic
complexity_score = sum(multi_step_indicators)
if complexity_score >= 3 and sum(simple_arithmetic_indicators) <= 2:
return "multi_step"
# Default to logic if other categories don't fit
return "logic"
def load_results(filename: str = "few_shot_results_errors_only.json") -> List[Dict[str, Any]]:
"""Load the evaluation results from JSON file"""
try:
with open(filename, 'r') as f:
data = json.load(f)
        # Handle both formats: a {"results": [...]} wrapper or a bare list of results
        results = data.get("results", []) if isinstance(data, dict) else data
        # The default input file contains only errors, so mark every entry as
        # incorrect; the downstream filter treats a missing "is_correct" flag as correct.
        for result in results:
            result["is_correct"] = False
        return results
except FileNotFoundError:
print(f"Error: File {filename} not found.")
return []
except json.JSONDecodeError:
print(f"Error: Invalid JSON format in {filename}.")
return []
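# Assumed record layout, inferred from the fields this script accesses, e.g.:
#   {"index": 12, "question": "...", "ground_truth": "18",
#    "predicted_answer": "20", "generated_text": "...", "is_correct": false}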
def create_comprehensive_visualization(results: List[Dict], num_samples: int):
"""Create a comprehensive visualization showing all errors with proper categorization"""
# Categorize errors
errors = [r for r in results if not r.get("is_correct", True)]
error_categories = {category: [] for category in ERROR_CATEGORIES}
for error in errors:
category = refined_classify_error_type(error)
error_categories[category].append(error)
# Prepare data for visualization
categories = list(error_categories.keys())
counts = [len(error_categories[cat]) for cat in categories]
percentages = [count/len(errors)*100 if errors else 0 for count in counts]
# Create detailed error list for scatter plot
error_details = []
for category, errors_list in error_categories.items():
for error in errors_list:
error_details.append({
"sample_index": error.get("index", 0),
"category": category,
"ground_truth": error["ground_truth"],
"predicted": error["predicted_answer"],
"question_preview": error["question"][:50] + "..." if len(error["question"]) > 50 else error["question"]
})
error_df = pd.DataFrame(error_details)
# Create visualization
fig = make_subplots(
rows=2, cols=2,
subplot_titles=(
'Error Category Distribution',
'Error Samples by Index',
'Ground Truth vs Predicted Values',
'Error Analysis Summary'
),
specs=[[{"type": "pie"}, {"type": "scatter"}],
[{"type": "scatter"}, {"type": "table"}]]
)
# Pie chart - Error distribution with specified colors
fig.add_trace(
go.Pie(
labels=categories,
values=counts,
hole=0.4,
textinfo='label+value+percent',
hoverinfo='label+value+percent',
name="Error Types",
marker=dict(colors=[ERROR_CATEGORY_COLORS[cat] for cat in categories])
),
row=1, col=1
)
# Scatter plot - Error samples by index
for category in categories:
category_errors = error_df[error_df['category'] == category]
if not category_errors.empty:
fig.add_trace(
go.Scatter(
x=category_errors['sample_index'],
y=[1] * len(category_errors),
mode='markers',
marker=dict(size=12, color=ERROR_CATEGORY_COLORS[category]),
name=category,
text=category_errors['question_preview'],
hoverinfo='text+x+y+name'
),
row=1, col=2
)
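    # Note: the plot below assumes ground-truth and predicted answers are numeric
    # strings; a non-numeric answer would make astype(float) raise.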
# Scatter plot - Ground truth vs predicted
if not error_df.empty:
fig.add_trace(
go.Scatter(
x=error_df['ground_truth'].astype(float),
y=error_df['predicted'].astype(float),
mode='markers',
marker=dict(
size=10,
color=[ERROR_CATEGORY_COLORS[cat] for cat in error_df['category']],
opacity=0.7
),
text=error_df['category'] + ": Sample " + error_df['sample_index'].astype(str),
hoverinfo='text+x+y',
name='GT vs Predicted'
),
row=2, col=1
)
# Add ideal line
max_val = max(max(error_df['ground_truth'].astype(float)), max(error_df['predicted'].astype(float))) + 10
fig.add_trace(
go.Scatter(
x=[0, max_val],
y=[0, max_val],
mode='lines',
line=dict(dash='dash', color='gray'),
name='Ideal',
showlegend=False
),
row=2, col=1
)
# Table - Error summary
summary_data = []
for category in categories:
for error in error_categories[category]:
summary_data.append([
error.get("index", "N/A"),
category,
error["ground_truth"],
error["predicted_answer"],
"β" if error["ground_truth"] == error["predicted_answer"] else "β"
])
if summary_data:
fig.add_trace(
go.Table(
header=dict(values=['Sample', 'Category', 'Ground Truth', 'Predicted', 'Correct']),
cells=dict(values=[
[row[0] for row in summary_data],
[row[1] for row in summary_data],
[row[2] for row in summary_data],
[row[3] for row in summary_data],
[row[4] for row in summary_data]
]),
name='Error Details'
),
row=2, col=2
)
# Update layout
fig.update_layout(
title=f"Comprehensive Error Analysis - {num_samples} Samples ({len(errors)} Errors)",
height=1000,
width=1400,
showlegend=True
)
fig.update_xaxes(title_text="Sample Index", row=1, col=2)
fig.update_yaxes(title_text="", row=1, col=2, showticklabels=False)
fig.update_xaxes(title_text="Ground Truth", row=2, col=1)
fig.update_yaxes(title_text="Predicted", row=2, col=1)
return fig, error_categories
def generate_detailed_error_report(error_categories: Dict, num_samples: int):
"""Generate a detailed report with analysis of each error category"""
total_errors = sum(len(errors) for errors in error_categories.values())
accuracy = (num_samples - total_errors) / num_samples * 100 if num_samples > 0 else 0
report = ["# Detailed Error Analysis Report", ""]
report.append(f"**Total Samples**: {num_samples}")
report.append(f"**Total Errors**: {total_errors}")
report.append(f"**Overall Accuracy**: {accuracy:.1f}%")
report.append("")
for category, errors in error_categories.items():
if errors:
report.append(f"## {category.upper()} Errors ({len(errors)} errors)")
report.append("")
for i, error in enumerate(errors, 1):
report.append(f"### Error {i}: Sample {error.get('index', 'N/A')}")
report.append("**Question:**")
report.append(f"> {error['question']}")
report.append("")
report.append("**Ground Truth:**")
report.append(f"`{error['ground_truth']}`")
report.append("")
report.append("**Model Prediction:**")
report.append(f"`{error['predicted_answer']}`")
report.append("")
report.append("**Error Analysis:**")
report.append(analyze_specific_error(error, category))
report.append("")
report.append("**Suggested Improvement:**")
report.append(suggest_improvement(error, category))
report.append("---")
report.append("")
return "\n".join(report)
def analyze_specific_error(error: Dict, category: str) -> str:
"""Provide specific analysis for each error"""
question = error["question"]
generated = error.get("generated_text", "")
if category == "percentage":
return "Percentage calculation error - likely misunderstanding of percentage relationships or incorrect application of percentage formulas."
elif category == "multi_step":
# Check for specific multi-step failure patterns
if "then" not in generated.lower() and "so" not in generated.lower():
return "Missing logical connectors - model failed to show step-by-step reasoning process."
elif generated.count('\n') < 3:
return "Insufficient step breakdown - model attempted to solve in too few steps."
else:
return "Complex multi-step reasoning failure - model understood individual steps but failed to combine them correctly."
elif category == "unit_conversion":
return "Unit conversion error - likely misunderstanding of measurement units or incorrect conversion between units."
return "General reasoning error - model struggled with the problem structure."
def suggest_improvement(error: Dict, category: str) -> str:
"""Provide specific improvement suggestions"""
if category == "percentage":
return "Train on more percentage word problems with varied contexts. Implement percentage-specific prompting strategies."
elif category == "multi_step":
return "Use chain-of-thought fine-tuning. Break complex problems into sub-tasks. Add intermediate supervision during training."
elif category == "unit_conversion":
return "Practice unit conversion problems with step-by-step solutions. Focus on measurement and currency conversion scenarios. Train on real-world unit conversion applications."
return "General reasoning training with diverse problem types and increased context understanding."
def main():
# Load results
results = load_results('few_shot_results_errors_only.json')
if not results:
print("β No results found. Exiting.")
return
num_samples = len(results)
print(f"π Performing comprehensive error analysis on {num_samples} error samples...")
fig, error_categories = create_comprehensive_visualization(results, num_samples)
# Save visualization
fig.write_html("enhanced_error_analysis.html")
print("πΎ Enhanced visualization saved to enhanced_error_analysis.html")
# Generate detailed report
report = generate_detailed_error_report(error_categories, num_samples)
with open("detailed_error_report.md", "w") as f:
f.write(report)
print("π Detailed report saved to detailed_error_report.md")
# Print summary
total_errors = sum(len(errors) for errors in error_categories.values())
print(f"\nπ Error Summary:")
for category, errors in error_categories.items():
if errors:
percentage = len(errors)/total_errors*100 if total_errors > 0 else 0
print(f" {category.upper()}: {len(errors)} errors ({percentage:.1f}%)")
print(f" Overall Accuracy: {((num_samples - total_errors) / num_samples * 100):.1f}%" if num_samples > 0 else "N/A")
if __name__ == "__main__":
    main()