ariG23498 HF Staff committed on
Commit
b9815de
·
verified ·
1 Parent(s): baabb64

Create moe-in-transformers.py

Browse files
Files changed (1) hide show
  1. moe-in-transformers.py +316 -0
moe-in-transformers.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analyze model introductions in the Transformers repo over the last ~2 years and
3
+ classify each introduced model as "moe" vs "dense" using a heuristic regex.
4
+
5
+ Outputs (in ./moe_dense_analysis):
6
+ - moe_dense_models_raw.csv : all models + inferred intro date + moe/dense label
7
+ - moe_dense_models_2y_window.csv : only models introduced in the last ~2 years
8
+ - moe_dense_2y_timeline.csv : monthly cumulative counts (moe/dense/total) over the window
9
+ - moe_dense_2y_timeline.png : plot of the cumulative MoE count (the dense curve is currently commented out)
10
+ """
11
+
12
+ # -----------------------------------------------------------------------------
13
+ # Imports
14
+ # -----------------------------------------------------------------------------
15
+ import calendar
16
+ import csv
17
+ import datetime as dt
18
+ import re
19
+ import subprocess
20
+ from pathlib import Path
21
+
22
+ import matplotlib
23
+ matplotlib.use("Agg") # headless backend for saving figures on CI/servers
24
+ import matplotlib.dates as mdates
25
+ import matplotlib.pyplot as plt
26
+
27
+
# -----------------------------------------------------------------------------
# Repo paths / output directory
#
# Every path below is derived from the current working directory, so the script
# has to be launched from the root of a transformers checkout.
# -----------------------------------------------------------------------------
repo = Path(".").resolve()
models_root = repo / "src/transformers/models"

# Stop right away, with a readable message, when the models tree is not where
# we expect it relative to the working directory.
if not models_root.exists():
    raise SystemExit("Run this from the transformers repo root.")

# All CSV/PNG artifacts produced further down are written into this directory;
# it is created on demand and reused if it already exists.
out_dir = repo / "moe_dense_analysis"
out_dir.mkdir(parents=True, exist_ok=True)
39
+
40
+
# -----------------------------------------------------------------------------
# Date window: last ~WINDOW_YEARS years from "today"
# -----------------------------------------------------------------------------
# Length of the analysis window, in calendar years.
WINDOW_YEARS = 2


def years_before(day, years):
    """Return the calendar date that lies `years` years before `day`.

    Args:
        day: a `datetime.date`.
        years: number of whole years to go back.

    Returns:
        The same month/day in the earlier year. If `day` is Feb 29 and the
        earlier year has no Feb 29, Feb 28 of that year is returned instead.
    """
    try:
        return day.replace(year=day.year - years)
    except ValueError:
        # Only reachable for Feb 29 when the target year is not a leap year.
        return day.replace(year=day.year - years, day=28)


today = dt.date.today()
start_date = years_before(today, WINDOW_YEARS)
end_date = today
54
+
55
+
# -----------------------------------------------------------------------------
# Discover model directories
#
# A sub-directory of models_root counts as a "model" when it holds a file named
# modeling_<dirname>.py (e.g. src/transformers/models/llama/modeling_llama.py).
# Plain files and dunder directories (names starting with "__") are ignored.
# -----------------------------------------------------------------------------
model_names = [
    candidate.name
    for candidate in sorted(models_root.iterdir())
    if candidate.is_dir()
    and not candidate.name.startswith("__")
    and (candidate / f"modeling_{candidate.name}.py").exists()
]

# Set view of the same names for O(1) membership tests.
model_name_set = set(model_names)
74
+
75
+
# -----------------------------------------------------------------------------
# Infer intro date per model using git:
#
# `git log` is restricted to commits that ADD files under
# src/transformers/models. For every model directory we keep the earliest
# commit date on which any file inside it was added.
#
# NOTE: This is a heuristic, not a perfect "model introduced" definition.
# -----------------------------------------------------------------------------
git_cmd = [
    "git",
    "log",
    "--diff-filter=A",  # keep only "file added" changes
    "--name-only",  # print the affected paths, one per line
    "--format=DATE %ad",  # emit a "DATE <date>" marker line per commit
    "--date=short",  # dates rendered as YYYY-MM-DD
    "--",
    "src/transformers/models",
]
git_out = subprocess.run(
    git_cmd,
    cwd=repo,
    check=True,
    text=True,
    capture_output=True,
).stdout

intro_dates = {}  # model_name -> earliest YYYY-MM-DD string seen so far
current_date = None  # date of the commit whose paths are being read

models_prefix = "src/transformers/models/"
for entry in map(str.strip, git_out.splitlines()):
    if not entry:
        continue

    # Marker line, e.g. "DATE 2024-01-10": remember the date, nothing else to do.
    if entry.startswith("DATE "):
        current_date = entry.split(" ", 1)[1]
        continue

    # Paths are only meaningful once a marker has been seen, and only when they
    # live under the models tree.
    if current_date is None or not entry.startswith(models_prefix):
        continue

    # Path layout: src/transformers/models/<model_name>/...
    segments = entry.split("/")
    if len(segments) < 4 or segments[3] not in model_name_set:
        continue

    # ISO dates order correctly as plain strings, so min() keeps the earliest.
    found_model = segments[3]
    intro_dates[found_model] = min(
        current_date, intro_dates.get(found_model, current_date)
    )
134
+
135
+
# -----------------------------------------------------------------------------
# MoE heuristic:
#
# Scan modeling_<name>.py for class statements that start at the beginning of
# a line, whose name contains MoE/MOE/Moe or Expert/Experts, and whose only
# listed base is nn.Module or torch.nn.Module.
#
# One or more hits -> the model is labelled "moe"; no hit -> "dense".
# -----------------------------------------------------------------------------
moe_class_re = re.compile(
    r"^class\s+([A-Za-z0-9_]*(?:MoE|MOE|Moe|Expert|Experts)[A-Za-z0-9_]*)"
    r"\s*\(\s*(?:nn|torch\.nn)\.Module\s*\)\s*:",
    re.MULTILINE,
)

records = []
for name in model_names:
    introduced = intro_dates.get(name)
    if introduced is None:
        # No "added" commit was attributed to this model: leave it out.
        continue

    source_path = models_root / name / f"modeling_{name}.py"
    # Undecodable bytes are dropped rather than aborting the whole run.
    source = source_path.read_text(encoding="utf-8", errors="ignore")

    # De-duplicated, alphabetically ordered class names that matched.
    moe_classes = sorted(set(moe_class_re.findall(source)))

    records.append(
        {
            "model": name,
            "introduced_date": introduced,  # YYYY-MM-DD (string)
            "is_moe": "moe" if moe_classes else "dense",
            "moe_class_matches": ";".join(moe_classes),
            "modeling_file": str(source_path.relative_to(repo)),
        }
    )

# Order by intro date, then by model name, so the output is stable across runs.
records.sort(key=lambda rec: (rec["introduced_date"], rec["model"]))
175
+
176
+
# -----------------------------------------------------------------------------
# Restrict to 2-year window
#
# Keep only the records whose intro date lies in [start_date, end_date], both
# ends included. Each kept record is a copy carrying an extra "intro_obj" key
# with the parsed date, so the entries in `records` are left untouched.
# -----------------------------------------------------------------------------
window_records = []
for record in records:
    introduced_on = dt.datetime.strptime(record["introduced_date"], "%Y-%m-%d").date()
    if not (start_date <= introduced_on <= end_date):
        continue
    window_records.append({**record, "intro_obj": introduced_on})

window_records.sort(key=lambda rec: (rec["intro_obj"], rec["model"]))
189
+
190
+
# -----------------------------------------------------------------------------
# Build monthly timeline points: start_date, then each next month, ending at end_date
#
# The day-of-month of start_date is the anchor for every point (e.g. the 19th of
# each month). Months that are too short get their last day instead, but the
# anchor itself is never lowered: a window starting on the 31st yields
# Jan 31, Feb 28/29, Mar 31, ... rather than sticking to the 28th after February.
# -----------------------------------------------------------------------------
def monthly_points(start, end):
    """Return the monthly sampling dates between `start` and `end`.

    Args:
        start: first date of the timeline (`datetime.date`); its day-of-month
            is the anchor reused for every following month.
        end: last date of the timeline (`datetime.date`).

    Returns:
        A list that begins with `start`, continues with one date per following
        month (anchor day, clamped to the month length) while that date is not
        after `end`, and finishes with `end` itself when the last monthly date
        does not already fall on it.
    """
    result = [start]
    year, month = start.year, start.month

    while True:
        # Step to the next calendar month, rolling over the year after December.
        month += 1
        if month > 12:
            year, month = year + 1, 1

        # Clamp from the ORIGINAL anchor day, not from the previous point, so
        # one short month does not permanently pull later points earlier.
        day = min(start.day, calendar.monthrange(year, month)[1])
        candidate = dt.date(year, month, day)
        if candidate > end:
            break
        result.append(candidate)

    # Ensure the last point is exactly `end`.
    if result[-1] != end:
        result.append(end)
    return result


points = monthly_points(start_date, end_date)
215
+
216
+
# -----------------------------------------------------------------------------
# Compute cumulative counts at each timeline point
#
# For every point we count the window records introduced on or before it,
# separately for the "moe" and "dense" labels, and store the two counts and
# their sum next to the ISO-formatted date.
# -----------------------------------------------------------------------------
timeline_rows = []
for cutoff in points:
    # Labels of every model already introduced at this cutoff (inclusive).
    labels_so_far = [
        rec["is_moe"] for rec in window_records if rec["intro_obj"] <= cutoff
    ]
    moe_cum = labels_so_far.count("moe")
    dense_cum = labels_so_far.count("dense")

    timeline_rows.append(
        {
            "date": cutoff.isoformat(),
            "moe_cumulative": moe_cum,
            "dense_cumulative": dense_cum,
            "total_cumulative": moe_cum + dense_cum,
        }
    )
241
+
242
+
# -----------------------------------------------------------------------------
# Write CSV outputs
#
# Three files are produced in out_dir: all classified models, the subset inside
# the date window, and the cumulative timeline.
# -----------------------------------------------------------------------------
def _write_csv(path, fieldnames, rows):
    """Write `rows` (an iterable of dicts) to `path` as UTF-8 CSV with a header row."""
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


# The per-model files share the same column layout.
model_fields = ["model", "introduced_date", "is_moe", "moe_class_matches", "modeling_file"]

raw_csv = out_dir / "moe_dense_models_raw.csv"
_write_csv(raw_csv, model_fields, records)

window_csv = out_dir / "moe_dense_models_2y_window.csv"
_write_csv(
    window_csv,
    model_fields,
    # "intro_obj" is an internal-only helper field and is not a CSV column.
    [
        {key: value for key, value in rec.items() if key != "intro_obj"}
        for rec in window_records
    ],
)

timeline_csv = out_dir / "moe_dense_2y_timeline.csv"
_write_csv(
    timeline_csv,
    ["date", "moe_cumulative", "dense_cumulative", "total_cumulative"],
    timeline_rows,
)
275
+
276
+
# -----------------------------------------------------------------------------
# Plot cumulative counts over time
#
# Only the MoE curve is drawn. The dense series is still computed in
# timeline_rows (and written to the CSV) but is deliberately not plotted here.
# -----------------------------------------------------------------------------
x = [dt.date.fromisoformat(row["date"]) for row in timeline_rows]
y_moe = [row["moe_cumulative"] for row in timeline_rows]

fig, ax = plt.subplots(figsize=(11, 6))
ax.plot(x, y_moe, label="MoE cumulative", linewidth=2.2)

ax.set_title(f"MoE model introductions ({start_date} to {end_date})")
ax.set_xlabel("Date")
ax.set_ylabel("Model count")
ax.grid(alpha=0.3)
ax.legend()

# One major tick every second month, labelled as YYYY-MM and slanted so the
# labels do not overlap.
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
plt.xticks(rotation=45, ha="right")
fig.tight_layout()

plot_png = out_dir / "moe_dense_2y_timeline.png"
fig.savefig(plot_png, dpi=180)
303
+
304
+
# -----------------------------------------------------------------------------
# Print summary
#
# Report the window bounds, how many models of each label were introduced inside
# it, and the path of every artifact that was written.
# -----------------------------------------------------------------------------
window_labels = [rec["is_moe"] for rec in window_records]
dense_total = window_labels.count("dense")
moe_total = window_labels.count("moe")

print(f"Window: {start_date} -> {end_date}")
print(f"Introduced in window: dense={dense_total}, moe={moe_total}, total={dense_total + moe_total}")
for written_path in (raw_csv, window_csv, timeline_csv, plot_png):
    print(f"Wrote {written_path}")