# HIP Grouped GEMM Status (2025-09-18)

## Current toggle

- Set `MEGABLOCKS_GG_USE_HIPBLASLT=1` to force the ROCm build to run the hipBLASLt backend instead of the FP32 fallback in `hipblaslt_gmm_internal`.
- Without the flag, the code uses the stable FP32 `torch::matmul` path, which overwrites the destination buffer. A sketch of the gate follows this list.
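
A minimal sketch of how the gate can be read on the C++ side (the helper name is hypothetical; the actual dispatch in `hipblaslt_gmm_internal` may differ):

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical helper: true when MEGABLOCKS_GG_USE_HIPBLASLT=1.
// Cached so the environment is consulted only once per process.
static bool use_hipblaslt_backend() {
  static const bool enabled = [] {
    const char* v = std::getenv("MEGABLOCKS_GG_USE_HIPBLASLT");
    return v != nullptr && std::strcmp(v, "1") == 0;
  }();
  return enabled;
}
```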

## What works with hipBLASLt enabled

- `_dev/debug-gg-small.py`, `_dev/debug-tensor-copy.py`, and `_dev/debug-gg-detailed.py` finish with finite outputs; differences stay within roughly 1e-3 to 1e-2, consistent with BF16 precision.
- `python -m pytest tests/test_gg.py -q` passes with the flag set.

## Known failures

- `PYTHONPATH=build/... MEGABLOCKS_GG_USE_HIPBLASLT=1 python -m pytest tests/ops_test.py -q` aborts with a HIP memory access fault ("Memory access fault by GPU node-2") during `OpsTest.testGroupedGemm_FixedSizes`.
- The same failure occurs early when the suite is run via `run-tests.sh`, so hipBLASLt is not yet production-ready.

## Next steps

- Reproduce the fault in isolation (likely the large `(z=16, m=128, k=128, n=128)` cases) and inspect the arguments passed into `hipblaslt_run_matmul` (leading dimensions/layout); see the logging sketch after this list.
- Investigate whether hipBLASLt requires column-major layouts or a non-zero workspace to handle the grouped GEMM shapes; the operand-swap sketch below shows the usual column-major workaround.
- Consider a hybrid strategy: attempt hipBLASLt per expert and fall back to FP32 for shapes that exceed stability thresholds (e.g., by catching `hipblaslt_run_matmul` errors once we can reliably detect them); see the fallback sketch below.
- Once hipBLASLt is stable, tighten the tolerances and gradient checks in `tests/test_gg.py` and re-enable the high-performance path by default.
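
For the reproduction step, a hedged sketch of per-expert argument logging ahead of the hipBLASLt call (`log_gemm_args` and its parameter list are illustrative, not the internal API):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative: dump the shape/stride arguments for one expert so the
// faulting case (e.g. z=16, m=128, k=128, n=128) can be diffed against
// a case that is known to succeed.
void log_gemm_args(int expert, int64_t m, int64_t n, int64_t k,
                   int64_t lda, int64_t ldb, int64_t ldc) {
  std::fprintf(stderr,
               "[gg] expert=%d m=%lld n=%lld k=%lld lda=%lld ldb=%lld ldc=%lld\n",
               expert, (long long)m, (long long)n, (long long)k,
               (long long)lda, (long long)ldb, (long long)ldc);
}
```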
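
If column-major layouts turn out to be required, the standard zero-copy workaround is to compute the transposed product with swapped operands: a row-major C = A·B occupies the same memory as a column-major Cᵀ = Bᵀ·Aᵀ. A sketch, assuming a simple argument bundle (`GemmArgs` and `to_column_major` are hypothetical):

```cpp
#include <cstdint>

// Hypothetical bundle of the arguments handed to hipblaslt_run_matmul.
struct GemmArgs {
  int64_t m, n, k;       // C is m x n, A is m x k, B is k x n
  int64_t lda, ldb, ldc; // leading dimensions
  const void *a, *b;
  void *c;
};

// Reinterpret a row-major GEMM as a column-major one without copying:
// row-major C(m x n) = A(m x k) * B(k x n) is the same memory as
// column-major C^T(n x m) = B^T(n x k) * A^T(k x m), so swap the
// operands and the m/n dimensions and reuse the leading dimensions.
GemmArgs to_column_major(const GemmArgs& rm) {
  return GemmArgs{/*m=*/rm.n, /*n=*/rm.m, /*k=*/rm.k,
                  /*lda=*/rm.ldb, /*ldb=*/rm.lda, /*ldc=*/rm.ldc,
                  /*a=*/rm.b, /*b=*/rm.a, /*c=*/rm.c};
}
```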
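
And a sketch of the hybrid fallback, assuming hipBLASLt failures become detectable as exceptions (`hipblaslt_matmul_or_throw` is a hypothetical wrapper; reliable error detection is exactly the open question above):

```cpp
#include <exception>
#include <torch/torch.h>

// Hypothetical wrapper around hipblaslt_run_matmul that throws on failure.
void hipblaslt_matmul_or_throw(const torch::Tensor& a,
                               const torch::Tensor& b,
                               torch::Tensor& out);

// Per-expert dispatch: try the fast hipBLASLt path first, then fall back
// to the known-good FP32 torch::matmul path, overwriting the destination
// just like the default path described above.
void grouped_matmul_hybrid(const torch::Tensor& a,  // one expert's BF16 slice
                           const torch::Tensor& b,
                           torch::Tensor& out) {
  try {
    hipblaslt_matmul_or_throw(a, b, out);
  } catch (const std::exception&) {
    out.copy_(torch::matmul(a.to(torch::kFloat), b.to(torch::kFloat))
                  .to(out.scalar_type()));
  }
}
```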