|
|
|
|
|
import json |
|
|
import plotly.graph_objects as go |
|
|
from plotly.subplots import make_subplots |
|
|
import plotly.express as px |
|
|
import pandas as pd |
|
|
from typing import Dict, List, Any |
|
|
import argparse |
|
|
|
|
|
|
|
|
# Error taxonomy used throughout the analysis: maps each category key to a
# human-readable description. classify_error_type() returns one of these keys,
# and the visualization/report code iterates this dict for category order.
ERROR_CATEGORIES = {
    "multi_step": "Chain-of-thought reasoning failures in multi-step problems",
    "percentage": "Percentage and ratio calculations",
    "logic": "Logical reasoning and problem setup failures",
    "unit_conversion": "Measurement and unit conversion errors"
}
|
|
|
|
|
def classify_error_type(result: Dict[str, Any]) -> str:
    """
    Classify the type of error based on the question text.

    Categories are checked in priority order: percentage problems first,
    then unit/measurement problems, then multi-step reasoning, with
    "logic" as the catch-all fallback.

    Args:
        result: A single evaluation record; only the "question" field is used.
            (The previous version also read "ground_truth" and
            "predicted_answer" into locals that were never used.)

    Returns:
        One of the ERROR_CATEGORIES keys: "percentage", "unit_conversion",
        "multi_step", or "logic".
    """
    question = result["question"].lower()

    # Percentage/ratio vocabulary is the most specific signal, so check first.
    if any(term in question for term in ['%', 'percent', 'percentage']):
        return "percentage"

    # Measurement / currency vocabulary -> unit-conversion problem.
    # Substring matching means each singular form also matches its plural
    # ('pound' matches 'pounds'), so listing plurals separately is redundant.
    # NOTE(review): substring matching is deliberately loose and can
    # over-match (e.g. 'cent' inside 'recent') -- kept as-is as a heuristic.
    units = ['pound', 'ounce', 'gallon', 'mile', 'hour', 'minute',
             'second', 'dollar', 'cent']
    if any(unit in question for unit in units):
        return "unit_conversion"

    # Heuristic signals that the problem requires chained reasoning steps.
    multi_step_indicators = [
        len(question.split()) > 30,
        any(connector in question for connector in ['if', 'then', 'after', 'before', 'when', 'however', 'since']),
        question.count('and') > 3,
        any(phrase in question for phrase in ['first', 'then', 'next', 'finally', 'after that']),
        any(term in question for term in ['each', 'every', 'per', 'total', 'combined'])
    ]

    # Signals that the problem is plain single-operation arithmetic.
    simple_arithmetic_indicators = [
        question.count('+') > 0,
        question.count('-') > 0,
        question.count('*') > 0,
        question.count('/') > 0,
        question.count('times') > 0,
        question.count('plus') > 0,
        question.count('minus') > 0
    ]

    # Complex wording without strong plain-arithmetic markers -> multi-step.
    complexity_score = sum(multi_step_indicators)
    if complexity_score >= 3 and sum(simple_arithmetic_indicators) <= 2:
        return "multi_step"

    # Default bucket: general logical reasoning / problem-setup failure.
    return "logic"
|
|
|
|
|
def load_results(filename: str = "few_shot_results_errors_only.json") -> List[Dict[str, Any]]:
    """Load the evaluation results from a JSON file.

    Supports two layouts:
      * a dict with a "results" key (errors-only export) -- every record is
        force-marked incorrect, since the file contains only errors;
      * a bare list of records -- returned with a default is_correct=False
        filled in only where the key is missing, so downstream code
        (analyze_errors) can rely on the flag being present.

    Args:
        filename: Path to the JSON results file.

    Returns:
        A list of result records, or an empty list on any load failure
        (missing file or malformed JSON), after printing an error message.
    """
    try:
        with open(filename, 'r') as f:
            data = json.load(f)
    except FileNotFoundError:
        # BUGFIX: message previously printed a placeholder instead of the
        # actual filename.
        print(f"Error: File {filename} not found.")
        return []
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in {filename}.")
        return []

    if isinstance(data, dict) and "results" in data:
        results = data.get("results", [])
        # Errors-only export: every record is, by construction, incorrect.
        for result in results:
            result["is_correct"] = False
        return results

    # Bare-list layout: fill a default flag without clobbering existing values.
    for result in data:
        result.setdefault("is_correct", False)
    return data
|
|
|
|
|
def analyze_errors(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Perform comprehensive error analysis.

    Buckets each result into one of the ERROR_CATEGORIES (via
    classify_error_type) or into "correct", then derives summary counts
    and percentages from the buckets.

    Args:
        results: Result records; each must carry an "is_correct" flag
            (load_results guarantees this).

    Returns:
        Dict with total_samples, correct_count, incorrect_count, accuracy
        (percent), per-category counts and percentages, and the bucketed
        records themselves under "detailed_errors".
    """
    total_samples = len(results)

    # One bucket per error category, plus one for correct answers.
    error_categories = {category: [] for category in ERROR_CATEGORIES}
    error_categories["correct"] = []

    for result in results:
        if result["is_correct"]:
            error_categories["correct"].append(result)
        else:
            error_type = classify_error_type(result)
            error_categories[error_type].append(result)

    # BUGFIX: counts were previously hard-coded (correct_count = 0,
    # incorrect_count = total_samples) instead of being derived from the
    # data, so accuracy was always 0 even when correct records existed.
    correct_count = len(error_categories["correct"])
    incorrect_count = total_samples - correct_count
    accuracy = (correct_count / total_samples) * 100 if total_samples > 0 else 0

    return {
        "total_samples": total_samples,
        "correct_count": correct_count,
        "incorrect_count": incorrect_count,
        "accuracy": accuracy,
        "category_counts": {cat: len(errors) for cat, errors in error_categories.items()},
        "category_percentages": {
            cat: (len(errors) / total_samples * 100) if total_samples > 0 else 0
            for cat, errors in error_categories.items()
        },
        "detailed_errors": error_categories
    }
|
|
|
|
|
def create_visualizations(stats: Dict[str, Any], num_samples: int):
    """Create interactive Plotly visualizations with specified color scheme.

    Builds a 2x2 subplot figure:
      (1,1) donut pie chart of the error-type distribution,
      (1,2) bar chart of error counts per category,
      (2,1) bubble scatter of error percentage per category (bubble size
            scales with the error count),
      (2,2) scatter of error occurrences along the sample index, with
            color encoding the error type.

    Args:
        stats: Output of analyze_errors() -- counts, percentages, and the
            bucketed records under "detailed_errors".
        num_samples: Number of samples analyzed; used only in plot titles.

    Returns:
        The assembled plotly Figure (caller is responsible for saving it).
    """
    # Fixed category order: the four error types followed by "correct".
    categories = list(ERROR_CATEGORIES.keys()) + ["correct"]
    counts = [stats["category_counts"].get(cat, 0) for cat in categories]
    percentages = [stats["category_percentages"].get(cat, 0) for cat in categories]

    # 2x2 grid; each cell's trace type must match the spec declared here.
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            f'Error Distribution (n={num_samples})',
            'Error Category Breakdown',
            'Error Percentage by Category',
            'Sample Index vs Error Type'
        ),
        specs=[[{"type": "pie"}, {"type": "bar"}],
               [{"type": "scatter"}, {"type": "scatter"}]]
    )

    # One fixed color per category so all four panels stay consistent.
    color_map = {
        'multi_step': 'red',
        'percentage': 'blue',
        'logic': 'orange',
        'unit_conversion': 'purple',
        'correct': 'green'
    }

    # Panel (1,1): donut pie over non-empty error categories only
    # ("correct" and zero-count categories are excluded).
    error_categories_pie = [cat for cat in categories if cat != "correct" and stats["category_counts"].get(cat, 0) > 0]
    error_counts_pie = [stats["category_counts"].get(cat, 0) for cat in error_categories_pie]
    pie_colors = [color_map.get(cat, 'gray') for cat in error_categories_pie]

    fig.add_trace(
        go.Pie(
            labels=error_categories_pie,
            values=error_counts_pie,
            name="Error Types",
            hole=0.4,  # donut-style hole in the center
            textinfo='label+percent',
            hoverinfo='label+value+percent',
            marker=dict(colors=pie_colors),
            showlegend=False
        ),
        row=1, col=1
    )

    # Panel (1,2): bar chart of error counts, annotated with count + percent.
    non_zero_categories = [cat for cat in categories if cat != "correct" and stats["category_counts"].get(cat, 0) > 0]
    non_zero_counts = [stats["category_counts"].get(cat, 0) for cat in non_zero_categories]
    non_zero_percentages = [stats["category_percentages"].get(cat, 0) for cat in non_zero_categories]
    bar_colors = [color_map.get(cat, 'gray') for cat in non_zero_categories]

    fig.add_trace(
        go.Bar(
            x=non_zero_categories,
            y=non_zero_counts,
            name="Count by Category",
            marker_color=bar_colors,
            text=[f"{count}<br>{percent:.1f}%" for count, percent in zip(non_zero_counts, non_zero_percentages)],
            textposition='auto',
            hoverinfo='x+y'
        ),
        row=1, col=2
    )

    # Panel (2,1): bubble scatter -- y is the percentage, bubble size grows
    # with the raw count (count*2 + 10 keeps small counts visible).
    error_df = pd.DataFrame({
        'Category': categories,
        'Percentage': percentages,
        'Count': counts
    })
    error_df = error_df[error_df['Count'] > 0]  # drop empty categories
    scatter_colors = [color_map.get(cat, 'gray') for cat in error_df['Category']]

    fig.add_trace(
        go.Scatter(
            x=error_df['Category'],
            y=error_df['Percentage'],
            mode='markers+text',
            marker=dict(
                size=error_df['Count']*2 + 10,
                color=scatter_colors,
                opacity=0.8
            ),
            text=error_df['Count'],
            textposition='middle center',
            name='Error Percentage',
            hoverinfo='text',
            hovertext=[f"{cat}: {pct:.1f}% ({cnt} samples)" for cat, pct, cnt in
                       zip(error_df['Category'], error_df['Percentage'], error_df['Count'])]
        ),
        row=2, col=1
    )

    # Panel (2,2): flatten the per-category buckets into (record, category)
    # pairs so each error can be placed along the sample-index axis.
    all_results = []
    for category in ERROR_CATEGORIES:
        for result in stats["detailed_errors"][category]:
            all_results.append((result, category))

    # Sort by original sample index; records without one sort first (0).
    all_results.sort(key=lambda x: x[0].get("index", 0))

    # Fall back to enumeration position when a record carries no index.
    sample_indices = [result[0].get("index", i+1) for i, result in enumerate(all_results)]
    error_types = [result[1] for result in all_results]
    error_colors = [color_map.get(error_type, 'gray') for error_type in error_types]

    fig.add_trace(
        go.Scatter(
            x=sample_indices,
            y=[1] * len(all_results),  # single row: color encodes the type
            mode='markers',
            marker=dict(
                color=error_colors,
                size=12,
                opacity=0.7
            ),
            name='Error Types by Sample',
            hoverinfo='x+y+text',
            hovertext=[f"Sample {result[0].get('index', 'N/A')}: {result[1]}" for result in all_results]
        ),
        row=2, col=2
    )

    # Global layout and per-panel axis labels.
    fig.update_layout(
        title=f"Symbolic-Math-Qwen2.5-1.5B-LoRA Error Analysis (n={num_samples} errors)",
        height=1000,
        width=1200,
        showlegend=False
    )

    fig.update_xaxes(title_text="Error Categories", row=1, col=2)
    fig.update_yaxes(title_text="Number of Errors", row=1, col=2)
    fig.update_xaxes(title_text="Categories", row=2, col=1)
    fig.update_yaxes(title_text="Percentage (%)", row=2, col=1)
    fig.update_xaxes(title_text="Sample Index", row=2, col=2)
    fig.update_yaxes(title_text="Error Type", row=2, col=2, tickvals=[1], ticktext=["Errors"])

    return fig
|
|
|
|
|
def generate_detailed_report(stats: Dict[str, Any]):
    """Build a plain-text report summarizing the error analysis.

    Args:
        stats: Output of analyze_errors() -- totals, accuracy, and
            per-category counts/percentages.

    Returns:
        The full report as a single newline-joined string.
    """
    header = "=" * 60
    divider = "-" * 40

    lines = [
        header,
        "SYMBOLIC-MATH-QWEN2.5-1.5B-LoRA ERROR ANALYSIS REPORT",
        header,
        f"Total Error Samples: {stats['total_samples']}",
        f"Overall Accuracy: {stats['accuracy']:.2f}%",
        "",
        "ERROR CATEGORY BREAKDOWN:",
        divider,
    ]

    # One aligned row per non-empty error category ("correct" is excluded).
    for category, count in stats['category_counts'].items():
        if category == "correct" or not count > 0:
            continue
        pct = stats['category_percentages'][category]
        lines.append(f"{category.upper():<20}: {count:>3} errors ({pct:>5.1f}%)")

    lines.extend(["", "RECOMMENDATIONS:", divider])

    # Canned advice per category, emitted only when that category has errors.
    # Insertion order matches the original output order.
    recommendations = {
        'percentage': [
            "β’ Add percentage calculation training examples",
            "β’ Implement percentage-specific prompting strategies",
            "β’ Train on more percentage word problems with varied contexts",
        ],
        'multi_step': [
            "β’ Focus on chain-of-thought reasoning training",
            "β’ Break down complex problems into sub-steps",
            "β’ Add intermediate supervision during training",
        ],
        'logic': [
            "β’ Improve logical reasoning capabilities",
            "β’ Train on problem setup and structure understanding",
            "β’ General reasoning training with diverse problem types",
        ],
        'unit_conversion': [
            "β’ Practice unit conversion problems",
            "β’ Focus on measurement and currency conversions",
            "β’ Train on real-world unit conversion scenarios",
        ],
    }
    for category, tips in recommendations.items():
        if stats['category_counts'].get(category, 0) > 0:
            lines.extend(tips)

    return "\n".join(lines)
|
|
|
|
|
def main():
    """CLI entry point: load results, analyze errors, plot, and report."""
    arg_parser = argparse.ArgumentParser(description="Analyze Symbolic-Math-Qwen2.5-1.5B-LoRA errors on GSM8K dataset")
    arg_parser.add_argument("--samples", type=int, default=16, help="Number of error samples analyzed")
    arg_parser.add_argument("--input", type=str, default="few_shot_results_errors_only.json", help="Input JSON file")
    arg_parser.add_argument("--output", type=str, default="error_analysis.html", help="Output HTML file")
    opts = arg_parser.parse_args()

    print(f"π Analyzing {opts.samples} error samples...")
    print(f"π Loading results from {opts.input}")

    records = load_results(opts.input)
    if not records:
        print("β No results found. Exiting.")
        return

    # Cap the analysis at the requested number of samples.
    records = records[:opts.samples]

    print("π Performing error analysis...")
    analysis = analyze_errors(records)

    print("π Creating visualizations...")
    figure = create_visualizations(analysis, opts.samples)

    figure.write_html(opts.output)
    print(f"πΎ Visualizations saved to {opts.output}")

    # Echo the full text report to stdout before persisting it.
    report_text = generate_detailed_report(analysis)
    print("\n" + report_text)

    report_path = f"error_analysis_report_{opts.samples}_errors.txt"
    with open(report_path, 'w') as f:
        f.write(report_text)
    print(f"π Detailed report saved to {report_path}")

    print(f"\nπ― Error analysis complete! Found {analysis['total_samples']} error samples")
    print("π Open the HTML file in your browser to view interactive visualizations")
|
|
|
# Script entry point: run the full analysis pipeline when executed directly.
if __name__ == "__main__":
    main()