File size: 35,208 Bytes
1e407f0
 
 
 
 
 
 
 
 
aeb3812
1e407f0
867401e
 
 
 
1e407f0
 
 
 
5ce4c31
 
 
867401e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e407f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeb3812
1e407f0
 
 
 
 
 
 
 
 
 
 
867401e
1e407f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867401e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e407f0
 
 
 
 
 
 
 
 
 
 
 
 
 
867401e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e407f0
 
 
 
 
 
 
 
aeb3812
 
1e407f0
aeb3812
 
 
 
867401e
aeb3812
 
 
1e407f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "grouped_gemm.h"

#ifdef __HIP_PLATFORM_AMD__

#include "gpu_backend_hip.h"
#include <ATen/hip/HIPContext.h>
#include <hipblaslt/hipblaslt.h>
#include <torch/autograd.h>
#include <vector>
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <string>

namespace grouped_gemm {
namespace {

// Reports whether the experimental hipBLASLt backend is enabled.
//
// The decision comes from the MEGABLOCKS_GG_USE_HIPBLASLT environment variable,
// matched case-insensitively against "1", "true", "yes" and "on"; any other
// value, or an unset variable, leaves the backend disabled. The variable is read
// once on the first call and the answer is cached for the life of the process.
// The backend is off by default because the hipBLASLt path still fails on the
// largest `tests/ops_test.py` configurations.
bool use_hipblaslt_backend() {
  static const bool enabled = [] {
    const char* setting = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
    if (!setting) {
      return false;
    }
    std::string lowered;
    for (const char* p = setting; *p != '\0'; ++p) {
      lowered.push_back(static_cast<char>(std::tolower(static_cast<unsigned char>(*p))));
    }
    return lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on";
  }();
  return enabled;
}

// Raises a c10::Error (through TORCH_CHECK) naming the failing expression and
// its status whenever a hipBLASLt call does not return HIPBLAS_STATUS_SUCCESS.
inline void hipblaslt_check(hipblasStatus_t status, const char* expr) {
  if (status != HIPBLAS_STATUS_SUCCESS) {
    TORCH_CHECK(false, "hipBLASLt call failed with status ", status, " when executing ", expr);
  }
}
// Wraps a hipBLASLt call, passing its source text along for the error message.
#define HIPBLASLT_CHECK(cmd) hipblaslt_check((cmd), #cmd)

// Returns the hipBLASLt handle shared by every call in this process.
// It is created lazily on first use (function-local static, so initialization is
// thread-safe) and is never destroyed.
// NOTE(review): the handle is created while whichever device is current on the
// first call; confirm this path is only used with a single device per process.
hipblasLtHandle_t hipblaslt_handle() {
  static const hipblasLtHandle_t shared_handle = [] {
    hipblasLtHandle_t created;
    HIPBLASLT_CHECK(hipblasLtCreate(&created));
    return created;
  }();
  return shared_handle;
}

void hipblaslt_run_matmul(const void* a_ptr,
                          const void* b_ptr,
                          const void* c_ptr,
                          void* d_ptr,
                          int64_t rows_a,
                          int64_t cols_a,
                          int64_t rows_b,
                          int64_t cols_b,
                          int64_t rows_d,
                          int64_t cols_d,
                          int64_t lda,
                          int64_t ldb,
                          int64_t ldc,
                          int64_t ldd,
                          hipblasOperation_t op_a,
                          hipblasOperation_t op_b,
                          bool accumulate) {
  if (rows_a == 0 || cols_a == 0 || rows_b == 0 || cols_b == 0 || rows_d == 0 || cols_d == 0)
    return;

  auto handle = hipblaslt_handle();
  auto stream = c10::hip::getCurrentHIPStream();

  hipblasLtMatmulDesc_t matmul_desc;
  HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul_desc, HIPBLAS_COMPUTE_32F, HIP_R_32F));
  HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
      matmul_desc, HIPBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a)));
  HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
      matmul_desc, HIPBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b)));
  hipblasLtPointerMode_t pointer_mode = HIPBLASLT_POINTER_MODE_HOST;
  HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
      matmul_desc, HIPBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode)));

  hipblasLtOrder_t order = HIPBLASLT_ORDER_ROW;

  hipblasLtMatrixLayout_t layout_a;
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_a, HIP_R_16BF, rows_a, cols_a, lda));
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
      layout_a, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));

  hipblasLtMatrixLayout_t layout_b;
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_b, HIP_R_16BF, rows_b, cols_b, ldb));
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
      layout_b, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));

  hipblasLtMatrixLayout_t layout_c;
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_c, HIP_R_16BF, rows_d, cols_d, ldc));
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
      layout_c, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));

  hipblasLtMatrixLayout_t layout_d;
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&layout_d, HIP_R_16BF, rows_d, cols_d, ldd));
  HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
      layout_d, HIPBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order)));

  hipblasLtMatmulPreference_t preference;
  HIPBLASLT_CHECK(hipblasLtMatmulPreferenceCreate(&preference));
  uint64_t workspace_size = 0;
  HIPBLASLT_CHECK(hipblasLtMatmulPreferenceSetAttribute(
      preference,
      HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspace_size,
      sizeof(workspace_size)));

  hipblasLtMatmulHeuristicResult_t heuristic{};
  int returned_results = 0;
  HIPBLASLT_CHECK(hipblasLtMatmulAlgoGetHeuristic(
      handle,
      matmul_desc,
      layout_a,
      layout_b,
      layout_c,
      layout_d,
      preference,
      1,
      &heuristic,
      &returned_results));
  TORCH_CHECK(returned_results > 0, "hipBLASLt could not find a suitable algorithm");

  const float alpha = 1.0f;
  const float beta = accumulate ? 1.0f : 0.0f;

  HIPBLASLT_CHECK(hipblasLtMatmul(handle,
                                  matmul_desc,
                                  &alpha,
                                  a_ptr,
                                  layout_a,
                                  b_ptr,
                                  layout_b,
                                  &beta,
                                  c_ptr,
                                  layout_c,
                                  d_ptr,
                                  layout_d,
                                  &heuristic.algo,
                                  nullptr,
                                  0,
                                  stream));

  hipblasLtMatmulPreferenceDestroy(preference);
  hipblasLtMatrixLayoutDestroy(layout_d);
  hipblasLtMatrixLayoutDestroy(layout_c);
  hipblasLtMatrixLayoutDestroy(layout_b);
  hipblasLtMatrixLayoutDestroy(layout_a);
  hipblasLtMatmulDescDestroy(matmul_desc);
}

}  // namespace

// Grouped GEMM over experts for BF16 inputs.
//
// Experts own consecutive row ranges of `a`: expert e owns the next
// batch_sizes[e] rows. The supported layouts, as used by the indexing below, are:
//   trans_a:  a (tokens, K), b (tokens, N) -> out (E, K, N),   out[e] = a_e^T * b_e
//   trans_b:  a (tokens, K), b (E, N, K)   -> out (tokens, N), out_e  = a_e * b[e]^T
//   default:  a (tokens, K), b (E, K, N)   -> out (tokens, N), out_e  = a_e * b[e]
// trans_a takes precedence when both flags are set.
//
// Each per-expert product runs through hipBLASLt when use_hipblaslt_backend() is
// true. Otherwise it runs through an FP32 torch::matmul whose result is rounded
// to a's dtype.
//
// Params:
//   a, b         BF16 device tensors (made contiguous internally).
//   batch_sizes  CPU int64 tensor with one non-negative token count per expert.
//   c_opt        Optional preallocated contiguous output. It is validated and
//                written in place; otherwise a new tensor is allocated.
// Returns the output tensor. Raises c10::Error on invalid inputs.
torch::Tensor hipblaslt_gmm_internal(torch::Tensor a,
                                      torch::Tensor b,
                                      torch::Tensor batch_sizes,
                                      bool trans_a,
                                      bool trans_b,
                                      c10::optional<torch::Tensor> c_opt) {
  torch::NoGradGuard no_grad;
  TORCH_CHECK(a.is_cuda(), "hipblaslt_gmm requires CUDA tensors");
  TORCH_CHECK(b.is_cuda(), "hipblaslt_gmm requires CUDA tensors");
  TORCH_CHECK(a.scalar_type() == torch::kBFloat16, "hipblaslt_gmm expects BF16 inputs");
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16, "hipblaslt_gmm expects BF16 weights");
  TORCH_CHECK(batch_sizes.device().is_cpu(), "batch_sizes must reside on CPU");
  TORCH_CHECK(a.dim() == 2, "hipblaslt_gmm expects a 2-D `a`, got shape ", a.sizes());

  a = a.contiguous();
  b = b.contiguous();
  // The counts are read through a raw pointer below, which needs dense storage.
  batch_sizes = batch_sizes.contiguous();

  const auto device = a.device();
  const auto dtype = a.scalar_type();
  const bool use_hip = use_hipblaslt_backend();

  const int64_t num_experts = batch_sizes.size(0);
  const int64_t* counts = batch_sizes.data_ptr<int64_t>();
  int64_t tokens = 0;
  for (int64_t expert = 0; expert < num_experts; ++expert) {
    TORCH_CHECK(counts[expert] >= 0,
                "batch_sizes must be non-negative, got ", counts[expert], " for expert ", expert);
    tokens += counts[expert];
  }
  TORCH_CHECK(a.size(0) == tokens, "tokens mismatch with batch sizes");

  // Visits experts in order with (expert, first token row, token count).
  // Returning from `visit` skips to the next expert.
  auto for_each_expert = [&](auto&& visit) {
    int64_t start = 0;
    for (int64_t expert = 0; expert < num_experts; ++expert) {
      const int64_t rows = counts[expert];
      visit(expert, start, rows);
      start += rows;
    }
  };

  // Returns the validated caller-supplied output, or allocates a fresh one.
  // The hipBLASLt path writes raw BF16 through data_ptr(), so a mis-shaped,
  // mis-placed or mis-typed output would otherwise be overrun or corrupted.
  auto resolve_output = [&](c10::IntArrayRef expected) -> torch::Tensor {
    torch::Tensor result = c_opt.has_value() ? *c_opt : torch::empty(expected, a.options());
    TORCH_CHECK(result.sizes().equals(expected),
                "Output tensor has shape ", result.sizes(), " but expected ", expected);
    TORCH_CHECK(!use_hip || result.device() == device,
                "Output tensor must be on ", device);
    TORCH_CHECK(!use_hip || result.scalar_type() == dtype,
                "hipBLASLt path requires a BF16 output tensor");
    TORCH_CHECK(result.is_contiguous(), "Output tensor must be contiguous");
    return result;
  };

  if (trans_a) {
    TORCH_CHECK(b.dim() == 2 && b.size(0) == tokens,
                "trans_a expects b of shape (", tokens, ", N), got ", b.sizes());
    const int64_t hidden_in = a.size(1);
    const int64_t hidden_out = b.size(1);
    torch::Tensor out = resolve_output({num_experts, hidden_in, hidden_out});

    for_each_expert([&](int64_t expert, int64_t start, int64_t rows) {
      auto out_chunk = out.select(0, expert);
      if (rows == 0) {
        // No tokens: the product over an empty range is an all-zero block.
        out_chunk.zero_();
        return;
      }
      auto a_chunk = a.narrow(0, start, rows).contiguous();
      auto b_chunk = b.narrow(0, start, rows).contiguous();
      if (use_hip) {
        hipblaslt_run_matmul(a_chunk.data_ptr(),
                             b_chunk.data_ptr(),
                             out_chunk.data_ptr(),
                             out_chunk.data_ptr(),
                             /*rows_a=*/rows, /*cols_a=*/hidden_in,
                             /*rows_b=*/rows, /*cols_b=*/hidden_out,
                             /*rows_d=*/hidden_in, /*cols_d=*/hidden_out,
                             /*lda=*/hidden_in, /*ldb=*/hidden_out,
                             /*ldc=*/hidden_out, /*ldd=*/hidden_out,
                             HIPBLAS_OP_T,
                             HIPBLAS_OP_N,
                             /*accumulate=*/false);
      } else {
        auto prod = torch::matmul(a_chunk.to(torch::kFloat32).transpose(0, 1),
                                  b_chunk.to(torch::kFloat32));
        out_chunk.copy_(prod.to(dtype));
      }
    });
    return out;
  }

  if (trans_b) {
    TORCH_CHECK(b.dim() == 3 && b.size(0) == num_experts && b.size(2) == a.size(1),
                "trans_b expects b of shape (", num_experts, ", N, ", a.size(1), "), got ", b.sizes());
    const int64_t hidden_in = a.size(1);
    const int64_t hidden_out = b.size(1);
    torch::Tensor out = resolve_output({tokens, hidden_out});
    if (hidden_in == 0) {
      // Empty contraction: every output element is an empty sum. hipBLASLt is
      // skipped for zero-sized dimensions, so materialize the zeros here.
      out.zero_();
      return out;
    }

    for_each_expert([&](int64_t expert, int64_t start, int64_t rows) {
      if (rows == 0) {
        return;
      }
      auto a_chunk = a.narrow(0, start, rows).contiguous();
      auto b_chunk = b.select(0, expert).contiguous();
      auto out_chunk = out.narrow(0, start, rows);
      if (use_hip) {
        hipblaslt_run_matmul(a_chunk.data_ptr(),
                             b_chunk.data_ptr(),
                             out_chunk.data_ptr(),
                             out_chunk.data_ptr(),
                             /*rows_a=*/rows, /*cols_a=*/hidden_in,
                             /*rows_b=*/hidden_out, /*cols_b=*/hidden_in,
                             /*rows_d=*/rows, /*cols_d=*/hidden_out,
                             /*lda=*/hidden_in, /*ldb=*/hidden_in,
                             /*ldc=*/hidden_out, /*ldd=*/hidden_out,
                             HIPBLAS_OP_N,
                             HIPBLAS_OP_T,
                             /*accumulate=*/false);
      } else {
        auto prod = torch::matmul(a_chunk.to(torch::kFloat32),
                                  b_chunk.to(torch::kFloat32).transpose(0, 1));
        out_chunk.copy_(prod.to(dtype));
      }
    });
    return out;
  }

  // Default layout: a (tokens, inner) times b[e] (inner, hidden_out).
  TORCH_CHECK(b.dim() == 3 && b.size(0) == num_experts && b.size(1) == a.size(1),
              "gmm expects b of shape (", num_experts, ", ", a.size(1), ", N), got ", b.sizes());
  const int64_t inner = a.size(1);
  const int64_t hidden_out = b.size(2);
  torch::Tensor out = resolve_output({tokens, hidden_out});
  if (inner == 0) {
    // Empty contraction: see the trans_b branch.
    out.zero_();
    return out;
  }

  for_each_expert([&](int64_t expert, int64_t start, int64_t rows) {
    if (rows == 0) {
      return;
    }
    auto a_chunk = a.narrow(0, start, rows).contiguous();
    auto b_chunk = b.select(0, expert).contiguous();
    auto out_chunk = out.narrow(0, start, rows);
    if (use_hip) {
      hipblaslt_run_matmul(a_chunk.data_ptr(),
                           b_chunk.data_ptr(),
                           out_chunk.data_ptr(),
                           out_chunk.data_ptr(),
                           /*rows_a=*/rows, /*cols_a=*/inner,
                           /*rows_b=*/inner, /*cols_b=*/hidden_out,
                           /*rows_d=*/rows, /*cols_d=*/hidden_out,
                           /*lda=*/inner, /*ldb=*/hidden_out,
                           /*ldc=*/hidden_out, /*ldd=*/hidden_out,
                           HIPBLAS_OP_N,
                           HIPBLAS_OP_N,
                           /*accumulate=*/false);
    } else {
      auto prod = torch::matmul(a_chunk.to(torch::kFloat32), b_chunk.to(torch::kFloat32));
      out_chunk.copy_(prod.to(dtype));
    }
  });
  return out;
}

// Public entry point: computes the grouped GEMM of `a` and `b` into the
// preallocated output `c`, which must be contiguous.
// `batch_sizes` is copied to the CPU when it lives elsewhere, because
// hipblaslt_gmm_internal requires it on the host. When the returned tensor is
// not a view of `c`, its contents are copied into `c`.
void GroupedGemm(torch::Tensor a,
                 torch::Tensor b,
                 torch::Tensor c,
                 torch::Tensor batch_sizes,
                 bool trans_a,
                 bool trans_b) {
  torch::Tensor host_batch_sizes = batch_sizes;
  if (!host_batch_sizes.device().is_cpu()) {
    host_batch_sizes = host_batch_sizes.cpu();
  }
  TORCH_CHECK(c.is_contiguous(), "Output tensor must be contiguous");
  const torch::Tensor computed =
      hipblaslt_gmm_internal(a, b, host_batch_sizes, trans_a, trans_b, c);
  if (c.is_alias_of(computed)) {
    return;  // Already written in place.
  }
  c.copy_(computed);
}

}  // namespace grouped_gemm

#else
#include "fill_arguments_hip.cuh"

#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/KernelUtils.h>
#include <c10/util/BFloat16.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <hipcub/hipcub.hpp>
#include <torch/torch.h>

#include "cutlass/bfloat16.h"
#include "cutlass/complex.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"

#include <type_traits>

namespace grouped_gemm {

// Evaluates a HIP runtime call exactly once and raises a c10::Error carrying
// hipGetErrorString(status) when it does not return hipSuccess. TORCH_CHECK
// formats its message only on failure, so the success path allocates nothing.
// The local has a prefixed name so it cannot capture a caller variable named
// `status` appearing inside `code`.
#define CUDA_CALL(code)                                                          \
  do {                                                                           \
    const hipError_t grouped_gemm_hip_status_ = (code);                          \
    TORCH_CHECK(grouped_gemm_hip_status_ == hipSuccess,                          \
                hipGetErrorString(grouped_gemm_hip_status_));                    \
  } while (0)

// Raises a c10::Error when a hipBLAS call does not return HIPBLAS_STATUS_SUCCESS.
// The numeric status is appended so different failures can be told apart.
#define CUBLAS_CALL(code)                                                        \
  do {                                                                           \
    const hipblasStatus_t grouped_gemm_blas_status_ = (code);                    \
    TORCH_CHECK(grouped_gemm_blas_status_ == HIPBLAS_STATUS_SUCCESS,             \
                "CuBLAS Error: status ",                                         \
                static_cast<int>(grouped_gemm_blas_status_));                    \
  } while (0)

// Two-level stringification: the argument is macro-expanded before being quoted.
#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
#define GROUPED_GEMM_STRINGIFY(x) \
  GROUPED_GEMM_STRINGIFY_HELPER(x)

// Operand layout selector for A and B: a transposed operand is read as
// ColumnMajor, a non-transposed one as RowMajor.
template <bool trans>
using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;

// CUTLASS default tile/stage/alignment configuration for a tensor-op SM80 GEMM
// with BF16 A, B and C and float accumulation. Its members (alignments, tile
// shapes, epilogue, stage count) parameterize GroupedGemmKernel below.
using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
  ::cutlass::arch::OpClassTensorOp,
  ::cutlass::arch::Sm80,
  ::cutlass::bfloat16_t,
  ::cutlass::bfloat16_t,
  ::cutlass::bfloat16_t,
  float
>;

// Grouped-GEMM kernel type for one (trans_a, trans_b) combination: BF16 inputs
// with the layouts chosen by GroupedGemmInputLayout, a RowMajor BF16 output, and
// `float` in the accumulator slot of DefaultGemmGrouped. The template-argument
// order is dictated by CUTLASS and must not be rearranged.
// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
template <bool trans_a, bool trans_b>
using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
  // A operand.
  ::cutlass::bfloat16_t,
  GroupedGemmInputLayout<trans_a>,
  ::cutlass::ComplexTransform::kNone,
  GroupedGemmConfig::kAlignmentA,
  // B operand.
  ::cutlass::bfloat16_t,
  GroupedGemmInputLayout<trans_b>,
  ::cutlass::ComplexTransform::kNone,
  GroupedGemmConfig::kAlignmentB,
  // C operand.
  ::cutlass::bfloat16_t,
  ::cutlass::layout::RowMajor,
  float,
  ::cutlass::arch::OpClassTensorOp,
  ::cutlass::arch::Sm80,
  GroupedGemmConfig::ThreadblockShape,
  GroupedGemmConfig::WarpShape,
  GroupedGemmConfig::InstructionShape,
  GroupedGemmConfig::EpilogueOutputOp,
  // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
  // This parameter is passed in at present to match the APIs of other kernels. The parameter
  // is unused within the kernel.
  ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
  // TODO(tgale): Tune this for SM90.
  GroupedGemmConfig::kStages>::GemmKernel;

// Device-level launcher wrapping GroupedGemmKernel; CutlassGroupedGemm
// instantiates it per (trans_a, trans_b).
template <bool trans_a, bool trans_b>
using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;

// Uploads the raw bytes of host vector `x` into a newly allocated int8 tensor on
// `device`. The copy is enqueued on the current HIP stream, and callers
// reinterpret the buffer through data_ptr().
// NOTE(review): `x` is pageable host memory that the caller may free right after
// this returns; this relies on hipMemcpyAsync staging pageable sources before it
// returns. Confirm that for the targeted HIP runtime.
template <typename T>
torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
  const size_t num_bytes = sizeof(T) * x.size();
  const auto byte_options = torch::TensorOptions().dtype(torch::kInt8).device(device);
  torch::Tensor buffer = torch::empty(num_bytes, byte_options);

  CUDA_CALL(hipMemcpyAsync(buffer.data_ptr(),
                           x.data(),
                           num_bytes,
                           hipMemcpyHostToDevice,
                           c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()));
  return buffer;
}

// Permutes the first indices.size() elements of `data` in place so that
// data[i] becomes the old data[indices[i]]. Elements past indices.size() are
// left untouched. Throws std::out_of_range if an index is not below
// indices.size().
template <typename T>
static void ReorderArray(T* data, const std::vector<size_t>& indices) {
    // Work from a snapshot so that reads are never affected by earlier writes.
    const std::vector<T> snapshot(data, data + indices.size());
    size_t dst = 0;
    for (const size_t src : indices) {
        data[dst++] = snapshot.at(src);
    }
}

// Allocates an uninitialized int8 byte buffer on `device` large enough for
// `numel` values of type T. The tensor's dtype is int8, not T.
template <typename T>
torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
    const size_t byte_count = sizeof(T) * numel;
    const auto byte_options = torch::dtype(torch::kInt8).device(device);
    return torch::empty(byte_count, byte_options);
}

// Per-expert CUTLASS grouped-GEMM argument arrays, each held in an int8 byte
// buffer on the device (see TypedEmpty / CopyToDevice), plus the number of
// threadblocks to launch.
// Member order matters: instances are built with designated initializers.
struct RawGemmArguments {
  torch::Tensor lda;
  torch::Tensor ldb;
  torch::Tensor ldc;
  torch::Tensor ptr_a;
  torch::Tensor ptr_b;
  torch::Tensor ptr_c;
  torch::Tensor problem_sizes;
  int threadblock_count = 0;
};

// Allocates (without filling) the per-expert argument buffers used when
// batch_sizes lives on the device. Their contents are written later on the
// device (FillCutlassArguments). The threadblock count is sized from occupancy
// alone because problem shapes are unknown on the host.
// Raises c10::Error when num_experts exceeds kMaxExperts, the block size used by
// the fill kernel.
template <
  typename Gemm,
  typename ElementA, typename ElementB, typename ElementC
>
RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
    TORCH_CHECK(
        num_experts <= kMaxExperts,
        "At most ", kMaxExperts,
        " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
    );

    RawGemmArguments args;
    args.lda = TypedEmpty<int64_t>(num_experts, device);
    args.ldb = TypedEmpty<int64_t>(num_experts, device);
    args.ldc = TypedEmpty<int64_t>(num_experts, device);
    args.ptr_a = TypedEmpty<ElementA*>(num_experts, device);
    args.ptr_b = TypedEmpty<ElementB*>(num_experts, device);
    args.ptr_c = TypedEmpty<ElementC*>(num_experts, device);
    args.problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device);
    args.threadblock_count = Gemm::sufficient();
    return args;
}

// Builds every per-expert CUTLASS argument on the host from CPU-resident
// batch_sizes and uploads the arrays to a's device.
//
// Each problem starts from `coord_template`; batch_sizes[i] replaces K when
// kDynamicK, otherwise M. Operands of consecutive experts are packed back to back
// in a, b and c, so each expert's pointers are running element offsets.
// Problems with K == 0 have their output zeroed here and are then turned into
// no-ops (m = n = 0). With kDynamicK, problems are stably sorted by descending K.
// Raises c10::Error if a batch size is negative or does not fit in `int`, the
// GemmCoord index type.
template <
  bool kDynamicK,
  typename Gemm,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC
>
RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
                                     torch::Tensor b,
                                     torch::Tensor c,
                                     torch::Tensor batch_sizes,
                                     ::cutlass::gemm::GemmCoord coord_template,
                                     int64_t num_experts) {
  std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);

  // Host arrays of leading dimensions and per-operand pointers, each pointer
  // array typed by its own operand's element type.
  std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
  std::vector<ElementA*> ptr_a_host(num_experts);
  std::vector<ElementB*> ptr_b_host(num_experts);
  std::vector<ElementC*> ptr_c_host(num_experts);

  // Element offsets into a, b and c. GemmCoord indices are `int`, so the
  // per-problem products are formed in int64_t to avoid overflowing before
  // they are accumulated.
  int64_t elements_a = 0, elements_b = 0, elements_c = 0;

  const int64_t* counts = batch_sizes.data_ptr<int64_t>();
  for (int64_t i = 0; i < num_experts; ++i) {
    const int64_t count = counts[i];
    TORCH_CHECK(count >= 0 && static_cast<int64_t>(static_cast<int>(count)) == count,
                "batch_sizes[", i, "] = ", count, " does not fit a CUTLASS problem dimension");

    auto& problem = problem_sizes_host[i];
    problem = coord_template;
    (kDynamicK ? problem.k() : problem.m()) = static_cast<int>(count);

    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);

    ptr_a_host[i] = static_cast<ElementA*>(a.data_ptr()) + elements_a;
    ptr_b_host[i] = static_cast<ElementB*>(b.data_ptr()) + elements_b;
    ptr_c_host[i] = static_cast<ElementC*>(c.data_ptr()) + elements_c;

    const int64_t m = problem.m();
    const int64_t n = problem.n();
    const int64_t k = problem.k();
    elements_a += m * k;
    elements_b += k * n;
    elements_c += m * n;

    if (k == 0) {
      // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
      // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
      //   * set the output to zero with `hipMemsetAsync()`
      //   * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
      CUDA_CALL(hipMemsetAsync(ptr_c_host[i],
        0,
        static_cast<size_t>(m * n) * sizeof(ElementC),
        c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()));

      problem.m() = 0;
      problem.n() = 0;
    }
  }

  // Only sort problems when K are different
  if (kDynamicK) {
      std::vector<size_t> indices(num_experts);
      std::iota(indices.begin(), indices.end(), 0);
      std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
          return problem_sizes_host[i].k() > problem_sizes_host[j].k();
      });

      ReorderArray(problem_sizes_host.data(), indices);
      ReorderArray(lda_host.data(), indices);
      ReorderArray(ldb_host.data(), indices);
      ReorderArray(ldc_host.data(), indices);
      ReorderArray(ptr_a_host.data(), indices);
      ReorderArray(ptr_b_host.data(), indices);
      ReorderArray(ptr_c_host.data(), indices);
  }

  // Copy the problem sizes, pointers and leading dimension data to the device.
  return RawGemmArguments {
    .lda = CopyToDevice(lda_host, a.device()),
    .ldb = CopyToDevice(ldb_host, a.device()),
    .ldc = CopyToDevice(ldc_host, a.device()),
    .ptr_a = CopyToDevice(ptr_a_host, a.device()),
    .ptr_b = CopyToDevice(ptr_b_host, a.device()),
    .ptr_c = CopyToDevice(ptr_c_host, a.device()),
    .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),

    // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
    .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
  };
}

// Produces the CUTLASS grouped-GEMM Arguments for a, b, c and batch_sizes.
// With device-resident batch_sizes the argument buffers are only allocated here
// (they are filled later by a kernel). With host-resident batch_sizes everything
// is computed on the CPU and uploaded. The epilogue uses alpha = 1, beta = 0,
// and C doubles as the D output.
// Raises c10::Error when no threadblocks can be scheduled on this hardware.
template <
  bool kDynamicK,
  typename Gemm,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC
>
typename Gemm::Arguments MakeArguments(torch::Tensor a,
				       torch::Tensor b,
				       torch::Tensor c,
				       torch::Tensor batch_sizes,
				       ::cutlass::gemm::GemmCoord coord_template,
				       int64_t num_experts) {
  RawGemmArguments raw_args;
  if (batch_sizes.is_cuda()) {
    raw_args = MakeArgumentsOnDevice<
      Gemm, ElementA, ElementB, ElementC
    >(num_experts, a.device());
  } else {
    raw_args = MakeArgumentsOnHost<
      kDynamicK,
      Gemm,
      ElementA, ElementB, ElementC,
      LayoutA, LayoutB, LayoutC
    >(a, b, c, batch_sizes, coord_template, num_experts);
  }

  // A zero threadblock count means the kernel cannot be launched on this device.
  // (No stdout logging here: this runs on every grouped GEMM call.)
  TORCH_CHECK(raw_args.threadblock_count != 0, "Grouped GEMM execution not possible with HW");

  typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
  // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
  // so we can safely pass `nullptr` for `host_problem_sizes`.
  // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
  // know the problem dimensions on the host.
  typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
				     (int)num_experts,
				     (int)raw_args.threadblock_count,
				     epilogue_op,
				     (ElementA**)raw_args.ptr_a.data_ptr(),
				     (ElementB**)raw_args.ptr_b.data_ptr(),
				     (ElementC**)raw_args.ptr_c.data_ptr(),
				     (ElementC**)raw_args.ptr_c.data_ptr(),
				     /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
				     /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
				     /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
				     /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
				     /*host_problem_sizes=*/nullptr);
  return arguments;
}

// Launches FillArguments to write the per-expert CUTLASS arguments on the device
// from device-resident batch sizes (presumably problem sizes, pointers and
// leading dimensions, mirroring MakeArgumentsOnHost). trans_a is forwarded as
// kDynamicK: with a transposed A, the batch size is the K extent.
//
// A single block of kMaxExperts threads is used because:
//   * the number of elements to process is microscopically small
//   * we don't need any additional global memory
template <
  bool trans_a,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC,
  typename Arguments
>
void FillCutlassArguments(int num_experts,
			  torch::Tensor batch_sizes,
			  torch::Tensor a,
			  torch::Tensor b,
			  torch::Tensor c,
			  const Arguments& arguments,
			  ::cutlass::gemm::GemmCoord coord_template) {
  int64_t* const device_batch_sizes = batch_sizes.data_ptr<int64_t>();
  ElementA* const a_base = static_cast<ElementA*>(a.data_ptr());
  ElementB* const b_base = static_cast<ElementB*>(b.data_ptr());
  ElementC* const c_base = static_cast<ElementC*>(c.data_ptr());

  hipLaunchKernelGGL((FillArguments</*kDynamicK*/trans_a,
                                    ElementA, ElementB, ElementC,
                                    LayoutA, LayoutB, LayoutC>),
                     dim3(1), dim3(kMaxExperts), 0,
                     c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(),
                     num_experts, device_batch_sizes,
                     a_base, b_base, c_base,
                     arguments, coord_template);
  C10_HIP_KERNEL_LAUNCH_CHECK();
}

template <typename Args>
void RemoveK0Problems(int num_experts, const Args& arguments) {
  // For zeroing out the outputs (which might be arbitrarily large), we want to use
  // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
  // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
  // should be a good approximation for this.
  // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
 hipLaunchKernelGGL(( ZeroOutK0Outputs<>), 
    dim3(arguments.threadblock_count), dim3(at::cuda::detail::CUDA_NUM_THREADS), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()
  , 
    num_experts, arguments
  );
 hipLaunchKernelGGL(( IgnoreK0Problems<>), 
    dim3(1), dim3(kMaxExperts), 0, c10::hip::getCurrentHIPStreamMasqueradingAsCUDA()
  , 
    num_experts, arguments
  );
}

template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
				 torch::Tensor b,
				 torch::Tensor c,
				 torch::Tensor batch_sizes,
				 ::cutlass::gemm::GemmCoord coord_template) {
  // Runs a single CUTLASS grouped GEMM covering every expert, writes into `c`,
  // and returns `c`.
  //
  // `coord_template` holds the fixed problem dimensions, with `kDynamicDim`
  // standing in for the one that varies per expert (taken from `batch_sizes`).
  // Throws (via TORCH_CHECK) if CUTLASS fails to initialize or run.
  using Gemm = GemmGrouped<trans_a, trans_b>;

  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementC = typename Gemm::ElementC;

  using LayoutA = typename Gemm::LayoutA;
  using LayoutB = typename Gemm::LayoutB;
  using LayoutC = typename Gemm::LayoutC;

  Gemm gemm;
  const int64_t num_experts = batch_sizes.size(0);
  auto arguments = MakeArguments<
    /*kDynamicK*/trans_a,
    Gemm,
    ElementA, ElementB, ElementC,
    LayoutA, LayoutB, LayoutC
  >(a, b, c, batch_sizes, coord_template, num_experts);

  // Scratch space requested by CUTLASS, allocated as raw bytes on `a`'s device.
  const int64_t workspace_size = gemm.get_workspace_size(arguments);
  torch::Tensor workspace = torch::empty(
      workspace_size,
      torch::TensorOptions().dtype(torch::kInt8).device(a.device()));

  // Device-resident batch sizes can only be expanded into problem descriptors
  // on the device. Afterwards, the k == 0 problems are patched up.
  if (batch_sizes.is_cuda()) {
    FillCutlassArguments<
      trans_a,
      ElementA, ElementB, ElementC,
      LayoutA, LayoutB, LayoutC
    >(num_experts, batch_sizes, a, b, c, arguments, coord_template);
    RemoveK0Problems<>(num_experts, arguments);
  }

  // Initialize the kernel.
  const cutlass::Status init_status = gemm.initialize(arguments, workspace.data_ptr());
  TORCH_CHECK(init_status == cutlass::Status::kSuccess,
              "Failed to initialize CUTLASS Grouped GEMM");

  // Execute the kernel in the current stream.
  const cutlass::Status run_status = gemm.run(c10::hip::getCurrentHIPStreamMasqueradingAsCUDA());
  TORCH_CHECK(run_status == cutlass::Status::kSuccess,
              "Failed to run CUTLASS Grouped GEMM");
  return c;
}

void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
		c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
		c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
  // Single BF16 GEMM on the current hipBLAS handle: C = op(A) * op(B), with
  // alpha = 1, beta = 0 and an FP32 compute type.
  //
  // hipBLAS uses the column-major BLAS convention, but these buffers are row-major
  // (C's leading dimension is its column count). A row-major matrix read as
  // column-major is its transpose. The call therefore evaluates
  // C^T = op(B)^T * op(A)^T by passing B first and A second.
  //
  // `c_rows` is unused: C's shape follows from op(A) and op(B).
  int m = 0;  // columns of op(B), i.e. of C
  int k = 0;  // contraction dimension
  if (trans_b) {
    m = b_rows;
    k = b_cols;
  } else {
    m = b_cols;
    k = b_rows;
  }
  const int n = trans_a ? a_cols : a_rows;  // rows of op(A), i.e. of C

  // Leading dimension of each operand in its stored (row-major) layout.
  const int lda = trans_a ? n : k;
  const int ldb = trans_b ? k : m;

  const hipblasOperation_t op_b = trans_b ? HIPBLAS_OP_T : HIPBLAS_OP_N;
  const hipblasOperation_t op_a = trans_a ? HIPBLAS_OP_T : HIPBLAS_OP_N;

  const float alpha = 1.0f;
  const float beta = 0.0f;
  CUBLAS_CALL(hipblasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
			   op_b, op_a,
			   m, n, k, &alpha,
			   b, HIP_R_16BF, ldb,
			   a, HIP_R_16BF, lda,
			   &beta,
			   c, HIP_R_16BF, c_cols, HIP_R_32F,
			   HIPBLAS_GEMM_DEFAULT));
}

void CublasGroupedGemm(torch::Tensor a,
		       torch::Tensor b,
		       torch::Tensor c,
		       torch::Tensor batch_sizes,
		       bool trans_b) {
  // Issues one hipBLAS GEMM per expert. Expert i multiplies the next
  // batch_sizes[i] rows of `a` by its own slice b[i] (transposed when `trans_b`)
  // and writes the matching rows of `c`.
  //
  // All operands are walked with raw pointer arithmetic, and `batch_sizes` is
  // read through a host pointer. Contiguous operands and CPU-resident sizes are
  // assumed here (the visible caller, GroupedGemm, checks both).
  const int64_t num_experts = batch_sizes.size(0);
  const int64_t k = a.size(1);
  const int64_t b_rows = b.size(1);
  const int64_t b_cols = b.size(2);
  const int64_t n = trans_b ? b_rows : b_cols;

  c10::BFloat16* a_cursor = a.data_ptr<c10::BFloat16>();
  c10::BFloat16* b_cursor = b.data_ptr<c10::BFloat16>();
  c10::BFloat16* c_cursor = c.data_ptr<c10::BFloat16>();
  for (int64_t expert = 0; expert < num_experts; ++expert) {
    const int64_t rows = batch_sizes.data_ptr<int64_t>()[expert];
    CublasGemm(a_cursor, rows, k, /*trans_a=*/false,
	       b_cursor, b_rows, b_cols, trans_b,
	       c_cursor, rows, n);

    // Advance past this expert's slice of each operand.
    a_cursor += rows * k;
    b_cursor += b_rows * b_cols;
    c_cursor += rows * n;
  }
}

void CublasGroupedGemmVariableK(torch::Tensor a,
				torch::Tensor b,
				torch::Tensor c,
				torch::Tensor batch_sizes) {
  // Variable-k grouped GEMM, one hipBLAS call per expert. The next
  // batch_sizes[i] rows of `a` (a_i) and of `b` (b_i) are contracted into the
  // (m, n) output slice c[i] = a_i^T * b_i.
  //
  // Operands are walked with raw pointers, and `batch_sizes` is read on the host.
  // Contiguous tensors and CPU-resident sizes are assumed.
  const int64_t num_experts = batch_sizes.size(0);
  const int64_t m = a.size(1);
  const int64_t n = b.size(1);

  c10::BFloat16* a_cursor = a.data_ptr<c10::BFloat16>();
  c10::BFloat16* b_cursor = b.data_ptr<c10::BFloat16>();
  c10::BFloat16* c_cursor = c.data_ptr<c10::BFloat16>();
  for (int64_t expert = 0; expert < num_experts; ++expert) {
    const int64_t k = batch_sizes.data_ptr<int64_t>()[expert];
    CublasGemm(a_cursor, k, m, /*trans_a=*/true,
	       b_cursor, k, n, /*trans_b=*/false,
	       c_cursor, m, n);

    // Each expert consumes k rows of both inputs and fills one (m, n) output slice.
    a_cursor += k * m;
    b_cursor += k * n;
    c_cursor += m * n;
  }
}

void GroupedGemmVariableK(torch::Tensor a,
			  torch::Tensor b,
			  torch::Tensor c,
			  torch::Tensor batch_sizes) {
  // Validates the operands of the variable-k grouped GEMM, then runs it with
  // cuBLAS. For each expert i, c[i] = a_i^T * b_i, where a_i / b_i are the next
  // batch_sizes[i] rows of 'a' / 'b'.
  // NOTE(review): the device/dtype/rank of 'a' and the dtype/rank of
  // 'batch_sizes' are assumed to have been checked by the caller (GroupedGemm
  // does so).

  // The cuBLAS helper dereferences the batch sizes through a host pointer, so
  // device-resident sizes would be read as garbage (or fault).
  TORCH_CHECK(batch_sizes.is_cpu());

  // We expected a CUDA tensor with two dimensions and shape
  // (tokens, hidden_out) for 'b'.
  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(b.ndimension() == 2);
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);

  // Validate the dimensions.
  int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
  int64_t m = a.size(1), n = b.size(1);

  // Validate that we have the same contraction dimension.
  TORCH_CHECK(tokens == b.size(0));

  // Validate the output shape.
  TORCH_CHECK(c.is_cuda());
  TORCH_CHECK(c.ndimension() == 3);
  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
  TORCH_CHECK(c.size(0) == num_experts);
  TORCH_CHECK(c.size(1) == m);
  TORCH_CHECK(c.size(2) == n);

  // The cuBLAS helper walks every operand with raw pointer arithmetic. Each one
  // must be densely packed, otherwise the wrong elements would be silently read
  // or written.
  TORCH_CHECK(a.is_contiguous());
  TORCH_CHECK(b.is_contiguous());
  TORCH_CHECK(c.is_contiguous());

  // Run the computation.
  CublasGroupedGemmVariableK(a, b, c, batch_sizes);
}

// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
// assumed to be batched with fixed sized batches.
//
// Grouped GEMM entry point. Validates 'a', 'b', 'c' and 'batch_sizes', then
// writes the per-expert products into 'c'. Without GROUPED_GEMM_CUTLASS this uses
// cuBLAS (one GEMM per expert; 'batch_sizes' must be on the CPU). With it, a
// single CUTLASS grouped GEMM is used ('batch_sizes' on CPU or CUDA).
//   * trans_a: c[i] = a_i^T * b_i   ('a' and 'b' split along tokens)
//   * trans_b: c_i  = a_i * b[i]^T  ('a' and 'c' split along tokens)
//   * neither: c_i  = a_i * b[i]
// At most one of 'trans_a' / 'trans_b' may be set. Failed checks throw via
// TORCH_CHECK.
//
// TODO(tgale): Validate alignment is true for every batch element.
void GroupedGemm(torch::Tensor a,
		 torch::Tensor b,
		 torch::Tensor c,
		 torch::Tensor batch_sizes,
		 bool trans_a, bool trans_b) {
  // NOTE: We only support 'trans_a' or 'trans_b', not both.
  TORCH_CHECK(!(trans_a && trans_b));

#if !defined(GROUPED_GEMM_CUTLASS)
  // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
  TORCH_CHECK(batch_sizes.is_cpu());
#else
  // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
  TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
#endif
  TORCH_CHECK(batch_sizes.ndimension() == 1);
  TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);

  // We expected a CUDA tensor with two dimensions and shape
  // (tokens, hidden_in) for 'a'.
  TORCH_CHECK(a.is_cuda());
  TORCH_CHECK(a.ndimension() == 2);
  TORCH_CHECK(a.scalar_type() == torch::kBFloat16);

#if !defined(GROUPED_GEMM_CUTLASS)
  if (trans_a) {
    // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
    // for the rest of the op.
    GroupedGemmVariableK(a, b, c, batch_sizes);
    return;
  }
#endif

  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(c.is_cuda());
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);

  // The expected shapes of 'b' and 'c' are:
  //   * when 'trans_a' is set: b=(tokens, hidden_out),                 c=(num_experts, hidden_in, hidden_out)
  //   * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
  //   * otherwise:             b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
  // Signed, to compare cleanly with the int64_t returned by Tensor::size().
  int64_t hidden_in = 0;
  int64_t hidden_out = 0;
  if (trans_a) {
    hidden_in = a.size(1);
    hidden_out = b.size(1);

    TORCH_CHECK(b.ndimension() == 2);
    TORCH_CHECK(c.ndimension() == 3);
    TORCH_CHECK(b.size(0) == a.size(0));
    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
    TORCH_CHECK(c.size(1) == hidden_in);
    TORCH_CHECK(c.size(2) == hidden_out);
  } else {
    TORCH_CHECK(b.ndimension() == 3);
    TORCH_CHECK(c.ndimension() == 2);

    // Validate the contraction dimensions match.
    const int64_t tokens = a.size(0);
    const int64_t num_experts = b.size(0);
    hidden_in = trans_b ? b.size(2) : b.size(1);
    hidden_out = trans_b ? b.size(1) : b.size(2);
    TORCH_CHECK(hidden_in == a.size(1));

    // Validate that we have one size per expert.
    TORCH_CHECK(batch_sizes.size(0) == num_experts);

    // Validate the output shape. The kernels below write each expert's rows
    // into 'c' assuming it is (tokens, hidden_out). A smaller or differently
    // shaped 'c' would otherwise be written out of bounds.
    TORCH_CHECK(c.size(0) == tokens);
    TORCH_CHECK(c.size(1) == hidden_out);
  }

  // NOTE: We support transposition through the 'trans_b' flag.
  TORCH_CHECK(a.is_contiguous());
  TORCH_CHECK(b.is_contiguous());
  TORCH_CHECK(c.is_contiguous());

#if !defined(GROUPED_GEMM_CUTLASS)
  CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
  return;
#else
  // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
  // as a placeholder. This placeholder is later expanded into the actual dimension
  // for every element of the batch, either on the host or on the device
  // (if we can't do it on the host).
  const auto coord_template = trans_a
    ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
    : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
  if (trans_a) {
    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
    return;
  }
  if (trans_b) {
    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
    return;
  }
  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
  return;
#endif
}

}  // namespace grouped_gemm

#endif