"""Error analysis for few-shot evaluation error results.

Loads the error samples from a JSON results file, classifies each error into a
category, builds an interactive Plotly dashboard, and writes a detailed
markdown report.
"""

import json
from typing import List, Dict, Any

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Error taxonomy used to bucket each failed sample.
ERROR_CATEGORIES = {
    "multi_step": "Chain-of-thought reasoning failures in multi-step problems",
    "percentage": "Percentage and ratio calculation errors",
    "logic": "Logical reasoning and problem setup failures",
    "unit_conversion": "Measurement and unit conversion errors",
}

# Color used for each category across all plots.
ERROR_CATEGORY_COLORS = {
    "multi_step": "red",
    "percentage": "blue",
    "logic": "orange",
    "unit_conversion": "purple",
}
|
|
def refined_classify_error_type(result: Dict[str, Any]) -> str:
    """Classify an error sample into one of ERROR_CATEGORIES.

    Classification is based on keyword heuristics applied to the question text
    (see the illustrative examples below this function).
    """
    question = result["question"].lower()

    # Percentage problems are identified by explicit percentage vocabulary.
    if any(term in question for term in ['%', 'percent', 'percentage']):
        return "percentage"

    # Unit-conversion problems mention measurement or currency units.
    if any(unit in question for unit in ['pound', 'pounds', 'ounce', 'ounces', 'gallon', 'gallons',
                                         'mile', 'miles', 'hour', 'hours', 'minute', 'minutes',
                                         'second', 'seconds', 'dollar', 'dollars', 'cent', 'cents']):
        return "unit_conversion"

    # Signals that the question requires chained, multi-step reasoning.
    multi_step_indicators = [
        len(question.split()) > 30,
        any(connector in question for connector in ['if', 'then', 'after', 'before', 'when', 'however', 'since']),
        question.count('and') > 3,
        any(phrase in question for phrase in ['first', 'then', 'next', 'finally', 'after that']),
        any(term in question for term in ['each', 'every', 'per', 'total', 'combined']),
    ]

    # Signals that the question is plain arithmetic rather than multi-step reasoning.
    simple_arithmetic_indicators = [
        question.count('+') > 0,
        question.count('-') > 0,
        question.count('*') > 0,
        question.count('/') > 0,
        question.count('times') > 0,
        question.count('plus') > 0,
        question.count('minus') > 0,
    ]

    # Mostly multi-step signals with little plain arithmetic -> multi-step failure.
    complexity_score = sum(multi_step_indicators)
    if complexity_score >= 3 and sum(simple_arithmetic_indicators) <= 2:
        return "multi_step"

    # Everything else falls back to a general logic / problem-setup error.
    return "logic"
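# Illustrative classifications under the heuristics above (assumed example
# questions, not taken from the results file):
#   "What is 15% of 80?"                        -> "percentage"
#   "A car travels 120 miles in 3 hours ..."    -> "unit_conversion"
#   "First she buys 3 boxes, then each box ..." -> "multi_step" (if enough signals fire)
#   anything else                               -> "logic"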
|
|
def load_results(filename: str = "few_shot_results_errors_only.json") -> List[Dict[str, Any]]:
    """Load the evaluation results from a JSON file."""
    try:
        with open(filename, 'r') as f:
            data = json.load(f)

        # Files produced by the evaluation step wrap the samples in a "results" key.
        if "results" in data:
            results = data.get("results", [])
            # This file contains error samples only, so mark every record as incorrect.
            for result in results:
                result["is_correct"] = False
            return results
        else:
            # Otherwise assume the file is already a flat list of result records.
            return data

    except FileNotFoundError:
        print(f"Error: File {filename} not found.")
        return []
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in {filename}.")
        return []
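# Each record handled by this script is expected to look roughly like the sketch
# below (inferred from the fields accessed in this module; the exact schema is
# defined by whichever evaluation script wrote the JSON file):
#
#   {
#       "index": 12,
#       "question": "A store sells ...",
#       "ground_truth": "72",
#       "predicted_answer": "48",
#       "generated_text": "Let's think step by step ...",
#       "is_correct": false
#   }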
|
|
def create_comprehensive_visualization(results: List[Dict], num_samples: int):
    """Create a comprehensive visualization showing all errors with proper categorization."""

    # Keep only the incorrect samples and bucket them by error category.
    # Records without an "is_correct" flag are assumed correct and skipped.
    errors = [r for r in results if not r.get("is_correct", True)]
    error_categories = {category: [] for category in ERROR_CATEGORIES}

    for error in errors:
        category = refined_classify_error_type(error)
        error_categories[category].append(error)

    categories = list(error_categories.keys())
    counts = [len(error_categories[cat]) for cat in categories]

    # Flatten the categorized errors into a DataFrame for plotting.
    error_details = []
    for category, errors_list in error_categories.items():
        for error in errors_list:
            error_details.append({
                "sample_index": error.get("index", 0),
                "category": category,
                "ground_truth": error["ground_truth"],
                "predicted": error["predicted_answer"],
                "question_preview": error["question"][:50] + "..." if len(error["question"]) > 50 else error["question"]
            })

    error_df = pd.DataFrame(error_details)

    # 2x2 dashboard: category pie chart, per-sample scatter, GT-vs-predicted scatter, summary table.
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            'Error Category Distribution',
            'Error Samples by Index',
            'Ground Truth vs Predicted Values',
            'Error Analysis Summary'
        ),
        specs=[[{"type": "pie"}, {"type": "scatter"}],
               [{"type": "scatter"}, {"type": "table"}]]
    )

    # (1,1) Pie chart of the error category distribution.
    fig.add_trace(
        go.Pie(
            labels=categories,
            values=counts,
            hole=0.4,
            textinfo='label+value+percent',
            hoverinfo='label+value+percent',
            name="Error Types",
            marker=dict(colors=[ERROR_CATEGORY_COLORS[cat] for cat in categories])
        ),
        row=1, col=1
    )

    if not error_df.empty:
        # (1,2) One marker per error sample, positioned by sample index and colored by category.
        for category in categories:
            category_errors = error_df[error_df['category'] == category]
            if not category_errors.empty:
                fig.add_trace(
                    go.Scatter(
                        x=category_errors['sample_index'],
                        y=[1] * len(category_errors),
                        mode='markers',
                        marker=dict(size=12, color=ERROR_CATEGORY_COLORS[category]),
                        name=category,
                        text=category_errors['question_preview'],
                        hoverinfo='text+x+y+name'
                    ),
                    row=1, col=2
                )

        # (2,1) Ground truth vs predicted values, colored by category.
        fig.add_trace(
            go.Scatter(
                x=error_df['ground_truth'].astype(float),
                y=error_df['predicted'].astype(float),
                mode='markers',
                marker=dict(
                    size=10,
                    color=[ERROR_CATEGORY_COLORS[cat] for cat in error_df['category']],
                    opacity=0.7
                ),
                text=error_df['category'] + ": Sample " + error_df['sample_index'].astype(str),
                hoverinfo='text+x+y',
                name='GT vs Predicted'
            ),
            row=2, col=1
        )

        # Dashed y = x reference line: a point on this line would be a correct prediction.
        max_val = max(error_df['ground_truth'].astype(float).max(),
                      error_df['predicted'].astype(float).max()) + 10
        fig.add_trace(
            go.Scatter(
                x=[0, max_val],
                y=[0, max_val],
                mode='lines',
                line=dict(dash='dash', color='gray'),
                name='Ideal',
                showlegend=False
            ),
            row=2, col=1
        )

    # (2,2) Per-error summary table.
    summary_data = []
    for category in categories:
        for error in error_categories[category]:
            summary_data.append([
                error.get("index", "N/A"),
                category,
                error["ground_truth"],
                error["predicted_answer"],
                "✓" if error["ground_truth"] == error["predicted_answer"] else "✗"
            ])

    if summary_data:
        fig.add_trace(
            go.Table(
                header=dict(values=['Sample', 'Category', 'Ground Truth', 'Predicted', 'Correct']),
                cells=dict(values=[
                    [row[0] for row in summary_data],
                    [row[1] for row in summary_data],
                    [row[2] for row in summary_data],
                    [row[3] for row in summary_data],
                    [row[4] for row in summary_data]
                ]),
                name='Error Details'
            ),
            row=2, col=2
        )

    fig.update_layout(
        title=f"Comprehensive Error Analysis - {num_samples} Samples ({len(errors)} Errors)",
        height=1000,
        width=1400,
        showlegend=True
    )

    fig.update_xaxes(title_text="Sample Index", row=1, col=2)
    fig.update_yaxes(title_text="", row=1, col=2, showticklabels=False)
    fig.update_xaxes(title_text="Ground Truth", row=2, col=1)
    fig.update_yaxes(title_text="Predicted", row=2, col=1)

    return fig, error_categories
|
|
def generate_detailed_error_report(error_categories: Dict, num_samples: int):
    """Generate a detailed markdown report with analysis of each error category."""

    total_errors = sum(len(errors) for errors in error_categories.values())
    accuracy = (num_samples - total_errors) / num_samples * 100 if num_samples > 0 else 0

    report = ["# Detailed Error Analysis Report", ""]
    report.append(f"**Total Samples**: {num_samples}")
    report.append(f"**Total Errors**: {total_errors}")
    report.append(f"**Overall Accuracy**: {accuracy:.1f}%")
    report.append("")

    # One section per non-empty category, with one subsection per error.
    for category, errors in error_categories.items():
        if errors:
            report.append(f"## {category.upper()} Errors ({len(errors)} errors)")
            report.append("")

            for i, error in enumerate(errors, 1):
                report.append(f"### Error {i}: Sample {error.get('index', 'N/A')}")
                report.append("**Question:**")
                report.append(f"> {error['question']}")
                report.append("")
                report.append("**Ground Truth:**")
                report.append(f"`{error['ground_truth']}`")
                report.append("")
                report.append("**Model Prediction:**")
                report.append(f"`{error['predicted_answer']}`")
                report.append("")
                report.append("**Error Analysis:**")
                report.append(analyze_specific_error(error, category))
                report.append("")
                report.append("**Suggested Improvement:**")
                report.append(suggest_improvement(error, category))
                report.append("---")
                report.append("")

    return "\n".join(report)
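# The report is plain markdown; an illustrative (made-up) excerpt of its shape:
#
#   # Detailed Error Analysis Report
#
#   **Total Samples**: 25
#   **Total Errors**: 25
#   **Overall Accuracy**: 0.0%
#
#   ## PERCENTAGE Errors (7 errors)
#
#   ### Error 1: Sample 4
#   **Question:**
#   > ...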
|
|
def analyze_specific_error(error: Dict, category: str) -> str:
    """Provide specific analysis for each error."""
    generated = error.get("generated_text", "")

    if category == "percentage":
        return "Percentage calculation error - likely misunderstanding of percentage relationships or incorrect application of percentage formulas."

    elif category == "multi_step":
        # Inspect the generated reasoning to guess where the multi-step chain broke down.
        if "then" not in generated.lower() and "so" not in generated.lower():
            return "Missing logical connectors - model failed to show step-by-step reasoning process."
        elif generated.count('\n') < 3:
            return "Insufficient step breakdown - model attempted to solve in too few steps."
        else:
            return "Complex multi-step reasoning failure - model understood individual steps but failed to combine them correctly."

    elif category == "unit_conversion":
        return "Unit conversion error - likely misunderstanding of measurement units or incorrect conversion between units."

    return "General reasoning error - model struggled with the problem structure."
|
|
def suggest_improvement(error: Dict, category: str) -> str:
    """Provide specific improvement suggestions"""
    if category == "percentage":
        return "Train on more percentage word problems with varied contexts. Implement percentage-specific prompting strategies."

    elif category == "multi_step":
        return "Use chain-of-thought fine-tuning. Break complex problems into sub-tasks. Add intermediate supervision during training."

    elif category == "unit_conversion":
        return "Practice unit conversion problems with step-by-step solutions. Focus on measurement and currency conversion scenarios. Train on real-world unit conversion applications."

    return "General reasoning training with diverse problem types and increased context understanding."
|
|
def main():
    results = load_results('few_shot_results_errors_only.json')

    if not results:
        print("❌ No results found. Exiting.")
        return

    num_samples = len(results)
    print(f"🔍 Performing comprehensive error analysis on {num_samples} error samples...")

    # Build and save the interactive dashboard.
    fig, error_categories = create_comprehensive_visualization(results, num_samples)
    fig.write_html("enhanced_error_analysis.html")
    print("💾 Enhanced visualization saved to enhanced_error_analysis.html")

    # Write the detailed markdown report.
    report = generate_detailed_error_report(error_categories, num_samples)
    with open("detailed_error_report.md", "w") as f:
        f.write(report)
    print("📄 Detailed report saved to detailed_error_report.md")

    # Print a per-category summary to the console.
    total_errors = sum(len(errors) for errors in error_categories.values())
    print("\n📊 Error Summary:")
    for category, errors in error_categories.items():
        if errors:
            percentage = len(errors) / total_errors * 100 if total_errors > 0 else 0
            print(f"  {category.upper()}: {len(errors)} errors ({percentage:.1f}%)")

    # num_samples is always > 0 here because we returned early on empty results.
    print(f"  Overall Accuracy: {((num_samples - total_errors) / num_samples * 100):.1f}%")
|
|
if __name__ == "__main__":
    main()
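# Running this module as a script (assuming few_shot_results_errors_only.json is
# present in the working directory) produces two artifacts:
#   - enhanced_error_analysis.html : interactive Plotly dashboard
#   - detailed_error_report.md     : per-error markdown report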