# open-model-evolution / graphs/model_market_share.py
# Commit fd38574 "initial charts" by emsesc (4.13 kB)
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Market-share buckets, drawn bottom-to-top in the stacked area. The strings
# must match the values in model_topk_df['metric']; buckets with no rows
# simply produce an empty trace.
_TOPK_METRIC_ORDER = (
    'Top 1',
    'Top 1 - 10',
    'Top 10 - 100',
    'Top 100 - 1000',
    'Top 1000 - 10000',
    'Rest',
)


def _palette_color(palette, index):
    """Return palette[index], cycling through palette; None (Plotly's default colour) if palette is empty."""
    if not palette:
        # Avoids ZeroDivisionError from `index % len(palette)`.
        return None
    return palette[index % len(palette)]


def _add_topk_area_traces(fig, model_topk_df, palette):
    """Add one stacked-area trace per market-share bucket on the primary y-axis."""
    for i, metric in enumerate(_TOPK_METRIC_ORDER):
        metric_data = model_topk_df[model_topk_df['metric'] == metric].sort_values('time')
        color = _palette_color(palette, i)
        fig.add_trace(
            go.Scatter(
                x=metric_data['time'],
                y=metric_data['value'],
                name=metric,
                mode='lines',
                # Zero-width outline: only the filled band is visible.
                line=dict(width=0, color=color),
                # First bucket fills down to zero; later ones fill to the band below.
                fill='tonexty' if i > 0 else 'tozeroy',
                # Fully opaque fill: the palette colour is used as-is (no alpha applied).
                fillcolor=color,
                stackgroup='one',
                hovertemplate=(
                    '<b>%{fullData.name}</b><br>'
                    'Time: %{x}<br>'
                    'Value: %{y}<extra></extra>'
                ),
            ),
            secondary_y=False,
        )


def _add_concentration_line(fig, df, name, color, scale=1):
    """Add a time-sorted concentration-index line on the secondary y-axis, with values multiplied by scale."""
    data = df.sort_values('time')
    fig.add_trace(
        go.Scatter(
            x=data['time'],
            y=data['value'] * scale,
            name=name,
            mode='lines',
            line=dict(color=color, width=3),
            hovertemplate=(
                '<b>' + name + '</b><br>'
                'Time: %{x}<br>'
                'Value: %{y:.3f}<extra></extra>'
            ),
        ),
        # secondary_y=True routes the trace to the right-hand axis ('y2');
        # no explicit yaxis= is needed.
        secondary_y=True,
    )


def _add_event_markers(fig, events):
    """Draw a dashed vertical line at each event's x position, labelled just above the plot area."""
    for event_name, event_date in events.items():
        fig.add_shape(
            type="line",
            x0=event_date, x1=event_date,
            # yref="paper" spans the full plot height regardless of data range.
            y0=0, y1=1,
            yref="paper",
            line=dict(color='#333333', width=2, dash='dash'),
        )
        fig.add_annotation(
            x=event_date,
            y=1,
            yref="paper",
            text=event_name,
            showarrow=False,
            yshift=10,  # nudge the label above the top edge of the plot
            font=dict(size=12),
        )


def _style_chart(fig, width, height):
    """Apply figure size, fonts, grid and axis titles."""
    fig.update_layout(
        width=width,
        height=height,
        font_family="Inter",
        font_size=14,
        showlegend=False,  # set to True to show the legend
        margin=dict(l=60, r=60, t=40, b=60),
        plot_bgcolor='white',
        hovermode='x unified',
    )
    fig.update_xaxes(
        title_text="",
        showgrid=True,
        gridcolor='lightgray',
        gridwidth=1,
    )
    # Primary (left) axis: stacked market-share areas.
    fig.update_yaxes(
        title_text="Model Market Share",
        showgrid=True,
        gridcolor='lightgray',
        gridwidth=1,
        secondary_y=False,
    )
    # Secondary (right) axis: Gini / HHI overlays; grid off to avoid clashing with the left grid.
    fig.update_yaxes(
        title_text="Concentration Indices",
        showgrid=False,
        secondary_y=True,
    )


def create_plotly_stacked_area_chart(
    model_topk_df,
    model_gini_df,
    model_hhi_df,
    TEMP_MODEL_EVENTS,
    PALETTE_0,
    *,
    width=1000,
    height=200,
):
    """Build a Plotly stacked-area chart of model market share with concentration overlays.

    This is a Plotly port of the visualization_util stacked area chart.

    Parameters
    ----------
    model_topk_df : DataFrame
        Rows with 'time', 'metric' and 'value' columns. 'metric' selects one of
        the market-share buckets ('Top 1', 'Top 1 - 10', ..., 'Rest').
    model_gini_df : DataFrame
        Rows with 'time' and 'value' columns; plotted as the Gini coefficient
        on the right-hand axis.
    model_hhi_df : DataFrame
        Rows with 'time' and 'value' columns; plotted ×10 as HHI on the
        right-hand axis.
    TEMP_MODEL_EVENTS : mapping
        Event label -> x position (presumably a date matching 'time'; verify
        against callers). Each one gets a dashed vertical marker.
    PALETTE_0 : sequence
        Colours for the stacked buckets, cycled if shorter than the bucket
        list. When empty, Plotly's default colours are used.
    width, height : int, keyword-only
        Figure size in pixels (defaults 1000 x 200).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # One subplot with a secondary y-axis for the concentration indices.
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    _add_topk_area_traces(fig, model_topk_df, PALETTE_0)
    _add_concentration_line(fig, model_gini_df, 'Gini Coefficient', '#6b46c1')
    # HHI is scaled ×10 so it shares a readable range with the Gini line.
    _add_concentration_line(fig, model_hhi_df, 'HHI (×10)', '#ec4899', scale=10)
    _add_event_markers(fig, TEMP_MODEL_EVENTS)
    _style_chart(fig, width=width, height=height)
    return fig