Fix corner case when m is too small
Browse files- adaptive_schedule.py +59 -12
adaptive_schedule.py
CHANGED
|
@@ -345,10 +345,7 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
|
|
| 345 |
def squeeze_without_change_order(schedules, m):
|
| 346 |
p = len(schedules)
|
| 347 |
squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
|
| 348 |
-
max_len =
|
| 349 |
-
for seq in squeezed:
|
| 350 |
-
assert max_len == 0 or max_len == len(seq)
|
| 351 |
-
max_len = max(max_len, len(seq))
|
| 352 |
|
| 353 |
identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
| 354 |
identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
|
|
@@ -389,6 +386,9 @@ def squeeze_without_change_order(schedules, m):
|
|
| 389 |
identifier_cnt[i][identifier] = _cnt + 1
|
| 390 |
identifier_index[_cnt * p + i][identifier] = index
|
| 391 |
stage_index[i] = index + 1
|
|
|
|
|
|
|
|
|
|
| 392 |
return squeezed
|
| 393 |
|
| 394 |
|
|
@@ -454,6 +454,7 @@ def process_cooldown(schedules, m):
|
|
| 454 |
schedules[i][index] = 'B'
|
| 455 |
|
| 456 |
# 2: add W back in cooldown phase
|
|
|
|
| 457 |
for i in range(p):
|
| 458 |
c_w, c_ww = 0, 0
|
| 459 |
last_w_index = -1
|
|
@@ -478,12 +479,57 @@ def process_cooldown(schedules, m):
|
|
| 478 |
elif c_ww > 0:
|
| 479 |
schedules[i][j] = 'w'
|
| 480 |
c_ww -= 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
schedules = squeeze_without_change_order(schedules, m)
|
| 483 |
return schedules
|
| 484 |
|
| 485 |
|
| 486 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
"""
|
| 488 |
We iterate all the cells from left to right. If a vacant cell (which means a bubble) is encountered, we try to
|
| 489 |
find a computation pass to fill this bubble. We iterate all the following computation passes in the same device,
|
|
@@ -491,17 +537,15 @@ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index =
|
|
| 491 |
to the vacant cell, and the bubble is filled.
|
| 492 |
"""
|
| 493 |
p = len(schedules)
|
| 494 |
-
max_len = 0
|
| 495 |
-
for seq in schedules:
|
| 496 |
-
assert max_len == 0 or max_len == len(seq)
|
| 497 |
-
max_len = max(max_len, len(seq))
|
| 498 |
if starting_index is not None:
|
| 499 |
assert isinstance(starting_index, list) and len(starting_index) == p
|
| 500 |
if ending_index is not None:
|
| 501 |
assert isinstance(ending_index, list) and len(ending_index) == p
|
|
|
|
|
|
|
|
|
|
| 502 |
starting_index = starting_index or [0] * p
|
| 503 |
ending_index = ending_index or [max_len] * p
|
| 504 |
-
|
| 505 |
last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
|
| 506 |
for i in range(p):
|
| 507 |
for j in range(max_len):
|
|
@@ -510,7 +554,6 @@ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index =
|
|
| 510 |
continue
|
| 511 |
last_index[i][identifier] = j
|
| 512 |
|
| 513 |
-
peak_mem = get_peak_mem(schedules)
|
| 514 |
stage_mem = [0] * p
|
| 515 |
def update_mem(stage_i, pass_c):
|
| 516 |
if pass_c in "Ff":
|
|
@@ -645,6 +688,7 @@ def check_correctness(schedules, m, raise_exception=False):
|
|
| 645 |
return False
|
| 646 |
return True
|
| 647 |
|
|
|
|
| 648 |
def relabel_w(schedules, m):
|
| 649 |
p = len(schedules)
|
| 650 |
c_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
|
@@ -654,7 +698,7 @@ def relabel_w(schedules, m):
|
|
| 654 |
continue
|
| 655 |
c_cnt[i][schedules[i][j]] += 1
|
| 656 |
for c in "FfBbWw":
|
| 657 |
-
assert c_cnt[i][c] == m
|
| 658 |
for i in range(p):
|
| 659 |
w_queue = deque(maxlen=2 * m)
|
| 660 |
for j in range(len(schedules[i])):
|
|
@@ -722,6 +766,8 @@ def schedule_by_building_block(p, m, building_block, max_mem, keep_stable_phase=
|
|
| 722 |
if m < redundant_m:
|
| 723 |
# 4. remove redundancy
|
| 724 |
schedules = remove_redundancy(schedules, m)
|
|
|
|
|
|
|
| 725 |
schedules = squeeze_without_change_order(schedules, m)
|
| 726 |
print_schedules(schedules, "after removing redundancy")
|
| 727 |
init_peak_mem = peak_mem = get_peak_mem(schedules)
|
|
@@ -820,6 +866,7 @@ def schedule(p, m, cost, max_mem):
|
|
| 820 |
[4, -1, 4, -1],
|
| 821 |
[5, -1, 5, -1]
|
| 822 |
]
|
|
|
|
| 823 |
|
| 824 |
best_schedule = None
|
| 825 |
best_bubble = None
|
|
|
|
| 345 |
def squeeze_without_change_order(schedules, m):
|
| 346 |
p = len(schedules)
|
| 347 |
squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
|
| 348 |
+
max_len = check_and_get_schedule_len(schedules)
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
| 351 |
identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
|
|
|
|
| 386 |
identifier_cnt[i][identifier] = _cnt + 1
|
| 387 |
identifier_index[_cnt * p + i][identifier] = index
|
| 388 |
stage_index[i] = index + 1
|
| 389 |
+
new_len = max(stage_index)
|
| 390 |
+
for i in range(p):
|
| 391 |
+
squeezed[i] = squeezed[i][:new_len]
|
| 392 |
return squeezed
|
| 393 |
|
| 394 |
|
|
|
|
| 454 |
schedules[i][index] = 'B'
|
| 455 |
|
| 456 |
# 2: add W back in cooldown phase
|
| 457 |
+
max_len = 0
|
| 458 |
for i in range(p):
|
| 459 |
c_w, c_ww = 0, 0
|
| 460 |
last_w_index = -1
|
|
|
|
| 479 |
elif c_ww > 0:
|
| 480 |
schedules[i][j] = 'w'
|
| 481 |
c_ww -= 1
|
| 482 |
+
for _ in range(c_w):
|
| 483 |
+
schedules[i].append('W')
|
| 484 |
+
for _ in range(c_ww):
|
| 485 |
+
schedules[i].append('w')
|
| 486 |
+
max_len = max(max_len, len(schedules[i]))
|
| 487 |
+
for i in range(p):
|
| 488 |
+
for _ in range(len(schedules[i]), max_len):
|
| 489 |
+
schedules[i].append(' ')
|
| 490 |
|
| 491 |
schedules = squeeze_without_change_order(schedules, m)
|
| 492 |
return schedules
|
| 493 |
|
| 494 |
|
| 495 |
+
def check_and_get_schedule_len(schedules):
    """
    Return the common length shared by every per-stage sequence in `schedules`.

    `schedules` is an iterable of sized sequences (one per stage). All of them
    must have the same length; an empty `schedules` yields 0.

    Raises AssertionError when the sequences do not all have the same length.
    The previous running-max check treated a current maximum of 0 as "nothing
    seen yet", so an empty sequence followed by a non-empty one slipped through
    while the reverse order failed. Collecting every length first makes the
    check independent of ordering.
    """
    lengths = {len(seq) for seq in schedules}
    assert len(lengths) <= 1, f"schedules have inconsistent lengths: {sorted(lengths)}"
    return lengths.pop() if lengths else 0
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def release_w_in_warmup_if_under_memory(schedules, peak_mem = None):
    """
    Defer leading W/w passes of each stage while the stage is under the memory budget.

    Illustrations carried over from the original documentation:
        FF fBWfBW bwbw      ->  FF fBfBWW bwbw
        FF f fBW BW bwbw    ->  FF f fBWBW bwbw
        FF f f BW BbWbww    ->  FF f f BWBbWbww
        FfFf BbWBbwWw       ->  FfFf BbBbWwWw

    Each row is scanned from left to right. A cell holding 'W' or 'w' is
    blanked and its pass is copied into a tail that is appended to the row.
    For a given stage this continues until (stage peak + number of moved
    passes) reaches `peak_mem`. Every row grows by the same amount,
    `peak_mem - min(per-stage peaks)`, so all rows keep a common length;
    unused tail cells stay blank.

    `schedules` is modified in place and also returned. A falsy `peak_mem`
    (None or 0) falls back to the largest per-stage peak.

    NOTE(review): `get_peak_mem(schedules, return_all=True)` is presumably a
    per-stage list of peak memory values - confirm against its definition.
    """
    num_stages = len(schedules)
    width = check_and_get_schedule_len(schedules)
    stage_peaks = get_peak_mem(schedules, return_all=True)
    peak_mem = peak_mem or max(stage_peaks)
    tail_width = peak_mem - min(stage_peaks)

    for stage in range(num_stages):
        row = schedules[stage]
        tail = [" "] * tail_width
        moved = 0
        col = 0
        # Stop as soon as this stage has used up its remaining memory headroom.
        while col < width and stage_peaks[stage] + moved < peak_mem:
            cell = row[col]
            if cell in "Ww":
                tail[moved] = cell
                row[col] = ' '
                moved += 1
            col += 1
        row.extend(tail)
    return schedules
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index = None, ending_index = None, peak_mem = None):
|
| 533 |
"""
|
| 534 |
We iterate all the cells from left to right. If a vacant cell (which means a bubble) is encountered, we try to
|
| 535 |
find a computation pass to fill this bubble. We iterate all the following computation passes in the same device,
|
|
|
|
| 537 |
to the vacant cell, and the bubble is filled.
|
| 538 |
"""
|
| 539 |
p = len(schedules)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
if starting_index is not None:
|
| 541 |
assert isinstance(starting_index, list) and len(starting_index) == p
|
| 542 |
if ending_index is not None:
|
| 543 |
assert isinstance(ending_index, list) and len(ending_index) == p
|
| 544 |
+
|
| 545 |
+
peak_mem = peak_mem or get_peak_mem(schedules)
|
| 546 |
+
max_len = check_and_get_schedule_len(schedules)
|
| 547 |
starting_index = starting_index or [0] * p
|
| 548 |
ending_index = ending_index or [max_len] * p
|
|
|
|
| 549 |
last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
|
| 550 |
for i in range(p):
|
| 551 |
for j in range(max_len):
|
|
|
|
| 554 |
continue
|
| 555 |
last_index[i][identifier] = j
|
| 556 |
|
|
|
|
| 557 |
stage_mem = [0] * p
|
| 558 |
def update_mem(stage_i, pass_c):
|
| 559 |
if pass_c in "Ff":
|
|
|
|
| 688 |
return False
|
| 689 |
return True
|
| 690 |
|
| 691 |
+
|
| 692 |
def relabel_w(schedules, m):
|
| 693 |
p = len(schedules)
|
| 694 |
c_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
|
|
|
| 698 |
continue
|
| 699 |
c_cnt[i][schedules[i][j]] += 1
|
| 700 |
for c in "FfBbWw":
|
| 701 |
+
assert c_cnt[i][c] == m, f"{i}, {c}, {c_cnt[i][c]}"
|
| 702 |
for i in range(p):
|
| 703 |
w_queue = deque(maxlen=2 * m)
|
| 704 |
for j in range(len(schedules[i])):
|
|
|
|
| 766 |
if m < redundant_m:
|
| 767 |
# 4. remove redundancy
|
| 768 |
schedules = remove_redundancy(schedules, m)
|
| 769 |
+
if m <= p and 2 * m <= max_mem:
|
| 770 |
+
schedules = release_w_in_warmup_if_under_memory(schedules, peak_mem=min(2 * p, peak_mem))
|
| 771 |
schedules = squeeze_without_change_order(schedules, m)
|
| 772 |
print_schedules(schedules, "after removing redundancy")
|
| 773 |
init_peak_mem = peak_mem = get_peak_mem(schedules)
|
|
|
|
| 866 |
[4, -1, 4, -1],
|
| 867 |
[5, -1, 5, -1]
|
| 868 |
]
|
| 869 |
+
# available_starting_patterns = available_starting_patterns[:1]
|
| 870 |
|
| 871 |
best_schedule = None
|
| 872 |
best_bubble = None
|