import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Default stacking order of the market-share bands (first entry is drawn first,
# i.e. at the bottom of the stack).
_DEFAULT_METRIC_ORDER = (
    'Top 1',
    'Top 1 - 10',
    'Top 10 - 100',
    'Top 100 - 1000',
    'Top 1000 - 10000',
    'Rest',
)


def _add_overlay_line(fig, df, name, color, scale=1):
    """Add one time-sorted line trace from ``df`` to the secondary y-axis of ``fig``.

    ``df`` must provide 'time' and 'value' columns; the plotted y values are
    ``value * scale``.
    """
    data = df.sort_values('time')
    fig.add_trace(
        go.Scatter(
            x=data['time'],
            y=data['value'] * scale,
            name=name,
            mode='lines',
            line=dict(color=color, width=3),
            hovertemplate=name + '<br>Time: %{x}<br>Value: %{y:.3f}',
        ),
        # make_subplots assigns the secondary axis itself; no explicit yaxis needed.
        secondary_y=True,
    )


def create_plotly_stacked_area_chart(
    model_topk_df,
    model_gini_df,
    model_hhi_df,
    TEMP_MODEL_EVENTS,
    PALETTE_0,
    *,
    metric_order=None,
):
    """Build a Plotly stacked-area chart of model market share with concentration overlays.

    Convert the visualization_util stacked area chart to Plotly.

    Args:
        model_topk_df: DataFrame with 'time', 'metric' and 'value' columns; one
            stacked band is drawn per metric in ``metric_order``.
        model_gini_df: DataFrame with 'time' and 'value' columns, drawn as the
            "Gini Coefficient" line on the secondary (right) y-axis.
        model_hhi_df: DataFrame with 'time' and 'value' columns, drawn multiplied
            by 10 as the "HHI (×10)" line on the secondary y-axis.
        TEMP_MODEL_EVENTS: Mapping of event label -> x position (e.g. a date);
            each entry gets a dashed vertical line and a text label above the plot.
        PALETTE_0: Sequence of colour strings; band ``i`` uses
            ``PALETTE_0[i % len(PALETTE_0)]``.
        metric_order: Optional sequence of metric names to stack, in order.
            Defaults to the Top-1 ... Rest buckets.

    Returns:
        plotly.graph_objects.Figure: the assembled figure.
    """
    if metric_order is None:
        metric_order = _DEFAULT_METRIC_ORDER

    # Single subplot with a secondary y-axis for the concentration indices.
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # Stacked market-share areas on the primary (left) y-axis.
    for i, metric in enumerate(metric_order):
        color = PALETTE_0[i % len(PALETTE_0)]
        metric_data = model_topk_df[model_topk_df['metric'] == metric].sort_values('time')
        fig.add_trace(
            go.Scatter(
                x=metric_data['time'],
                y=metric_data['value'],
                name=metric,
                mode='lines',
                line=dict(width=0, color=color),
                # stackgroup makes Plotly add the y values of the traces in this
                # group; the first band fills down to zero, later bands fill to
                # the band beneath.
                fill='tonexty' if i > 0 else 'tozeroy',
                fillcolor=color,
                stackgroup='one',
                hovertemplate='%{fullData.name}<br>Time: %{x}<br>Value: %{y}',
            ),
            secondary_y=False,
        )

    # Concentration-index overlays on the secondary (right) y-axis.
    _add_overlay_line(fig, model_gini_df, 'Gini Coefficient', '#6b46c1')
    # HHI is plotted multiplied by 10, as its label states.
    _add_overlay_line(fig, model_hhi_df, 'HHI (×10)', '#ec4899', scale=10)

    # Dashed vertical marker plus a label above the plot area for each event.
    for event_name, event_date in TEMP_MODEL_EVENTS.items():
        fig.add_shape(
            type="line",
            x0=event_date,
            x1=event_date,
            y0=0,
            y1=1,
            yref="paper",  # span the full plot height regardless of data range
            line=dict(color='#333333', width=2, dash='dash'),
        )
        fig.add_annotation(
            x=event_date,
            y=1,
            yref="paper",
            text=event_name,
            showarrow=False,
            yshift=10,  # lift the label just above the top edge of the plot
            font=dict(size=12),
        )

    fig.update_layout(
        width=1000,
        height=200,
        font_family="Inter",
        font_size=14,
        showlegend=False,  # Set to True if you want to show legend
        margin=dict(l=60, r=60, t=40, b=60),
        plot_bgcolor='white',
        hovermode='x unified',
    )

    fig.update_xaxes(
        title_text="",
        showgrid=True,
        gridcolor='lightgray',
        gridwidth=1,
    )

    # Primary (left) y-axis: stacked market share.
    fig.update_yaxes(
        title_text="Model Market Share",
        showgrid=True,
        gridcolor='lightgray',
        gridwidth=1,
        secondary_y=False,
    )

    # Secondary (right) y-axis: concentration indices; no grid to avoid clutter.
    fig.update_yaxes(
        title_text="Concentration Indices",
        showgrid=False,
        secondary_y=True,
    )

    return fig