""" Dataset Visualization Module for Jigsaw Toxic Comment Classification """ import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from wordcloud import WordCloud from typing import Tuple import re import streamlit as st # Set style for better-looking plots sns.set_style("whitegrid") plt.rcParams['figure.figsize'] = (12, 6) def load_dataset(file_path: str) -> pd.DataFrame: """Load the train.csv dataset""" try: df = pd.read_csv(file_path) return df except Exception as e: st.error(f"Error loading dataset: {str(e)}") return None def prepare_data(df: pd.DataFrame) -> Tuple[pd.Series, pd.Series, dict]: """ Prepare data for visualization Returns: toxic_texts, non_toxic_texts, label_counts """ # Get label columns label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] # Calculate label frequencies label_counts = df[label_columns].sum().to_dict() # Create binary column for any toxicity df['any_toxic'] = df[label_columns].max(axis=1) # Separate toxic and non-toxic texts toxic_df = df[df['any_toxic'] == 1] non_toxic_df = df[df['any_toxic'] == 0] # Sample for word clouds if dataset is too large max_samples = 5000 if len(toxic_df) > max_samples: toxic_df = toxic_df.sample(n=max_samples, random_state=42) if len(non_toxic_df) > max_samples: non_toxic_df = non_toxic_df.sample(n=max_samples, random_state=42) # Combine text toxic_texts = ' '.join(toxic_df['comment_text'].astype(str)) non_toxic_texts = ' '.join(non_toxic_df['comment_text'].astype(str)) return toxic_texts, non_toxic_texts, label_counts def clean_text_for_wordcloud(text: str) -> str: """Clean text for word cloud generation""" # Remove URLs text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE) # Remove special characters but keep spaces text = re.sub(r'[^\w\s]', '', text) # Convert to lowercase text = text.lower() return text def create_label_frequency_chart(label_counts: dict): """Create a bar chart showing label frequencies""" labels = list(label_counts.keys()) counts = list(label_counts.values()) plt.figure(figsize=(10, 6)) bars = plt.bar(labels, counts, color=['#ff6b6b', '#4ecdc4', '#45b7d1', '#f9ca24', '#f0932b', '#eb4d4b']) plt.xlabel('Toxicity Type', fontsize=12, fontweight='bold') plt.ylabel('Count', fontsize=12, fontweight='bold') plt.title('📊 Label Distribution in Training Dataset', fontsize=14, fontweight='bold', pad=20) plt.xticks(rotation=45, ha='right') # Add value labels on bars for bar in bars: height = bar.get_height() plt.text(bar.get_x() + bar.get_width()/2., height, f'{int(height):,}', ha='center', va='bottom', fontsize=10) plt.tight_layout() return plt.gcf() def create_wordcloud(text: str, title: str, colors: str, width: int = 800, height: int = 400): """Create a word cloud from text""" # Clean text cleaned_text = clean_text_for_wordcloud(text) # Create word cloud wordcloud = WordCloud( width=width, height=height, background_color='white', colormap=colors, max_words=100, prefer_horizontal=0.7, relative_scaling=0.5, min_font_size=10 ).generate(cleaned_text) # Plot plt.figure(figsize=(12, 6)) plt.imshow(wordcloud, interpolation='bilinear') plt.axis('off') plt.title(title, fontsize=14, fontweight='bold', pad=20) plt.tight_layout() return plt.gcf() def create_toxicity_comparison_chart(df: pd.DataFrame): """Create a pie chart showing toxic vs non-toxic distribution""" label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] df['any_toxic'] = df[label_columns].max(axis=1) toxic_count = df['any_toxic'].sum() 
non_toxic_count = len(df) - toxic_count plt.figure(figsize=(8, 8)) colors = ['#95e1d3', '#f38181'] explode = (0.05, 0.05) plt.pie( [non_toxic_count, toxic_count], labels=['Non-Toxic', 'Toxic'], autopct='%1.1f%%', startangle=90, colors=colors, explode=explode, shadow=True, textprops={'fontsize': 14, 'fontweight': 'bold'} ) plt.title('🧩 Toxic vs Non-Toxic Comments', fontsize=16, fontweight='bold', pad=20) plt.tight_layout() return plt.gcf() def create_overlap_heatmap(df: pd.DataFrame): """Create a heatmap showing label overlaps""" label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] # Calculate pairwise overlaps overlap_matrix = np.zeros((len(label_columns), len(label_columns))) for i, label1 in enumerate(label_columns): for j, label2 in enumerate(label_columns): overlap = ((df[label1] == 1) & (df[label2] == 1)).sum() overlap_matrix[i, j] = overlap # Create heatmap plt.figure(figsize=(10, 8)) mask = np.triu(np.ones_like(overlap_matrix, dtype=bool), k=1) sns.heatmap( overlap_matrix, annot=True, fmt='.0f', cmap='YlOrRd', xticklabels=label_columns, yticklabels=label_columns, square=True, cbar_kws={"shrink": 0.8}, mask=mask, linewidths=0.5 ) plt.title('🔥 Label Co-occurrence Heatmap', fontsize=14, fontweight='bold', pad=20) plt.tight_layout() return plt.gcf() def main_visualization(file_path: str = 'train.csv'): """Main function to generate all visualizations""" # Load data df = load_dataset(file_path) if df is None: return None, None, None, None, None # Prepare data toxic_texts, non_toxic_texts, label_counts = prepare_data(df) # Create visualizations fig1 = create_label_frequency_chart(label_counts) # Create word clouds fig2 = create_wordcloud(toxic_texts, "🔴 Most Common Words in Toxic Comments", 'Reds') fig3 = create_wordcloud(non_toxic_texts, "🟢 Most Common Words in Non-Toxic Comments", 'Greens') # Create pie chart fig4 = create_toxicity_comparison_chart(df) # Create heatmap fig5 = create_overlap_heatmap(df) return fig1, fig2, fig3, fig4, fig5 # Streamlit-specific functions @st.cache_data def load_data_cached(file_path: str): """Cached version of load_dataset for Streamlit""" return load_dataset(file_path) @st.cache_data def generate_wordcloud_cached(text: str, colors: str, width: int = 800, height: int = 400): """Cached wordcloud generation""" cleaned_text = clean_text_for_wordcloud(text) wordcloud = WordCloud( width=width, height=height, background_color='white', colormap=colors, max_words=100, prefer_horizontal=0.7, relative_scaling=0.5, min_font_size=10 ).generate(cleaned_text) return wordcloud
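

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming this file is launched directly with
# `streamlit run` and that 'train.csv' (the Jigsaw training split) sits in the
# working directory -- both are assumptions, not fixed by the module itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    st.title("Jigsaw Toxic Comment Dataset Overview")
    figures = main_visualization('train.csv')  # path is an assumed default
    if all(fig is not None for fig in figures):
        # Render the bar chart, both word clouds, the pie chart, and the heatmap
        for fig in figures:
            st.pyplot(fig)
    else:
        st.warning("Could not generate the visualizations; check the dataset path.")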