HIP Grouped GEMM Status (2025-09-18)
Current toggle
- Set `MEGABLOCKS_GG_USE_HIPBLASLT=1` to force the ROCm build to run the hipBLASLt backend instead of the FP32 fallback in `hipblaslt_gmm_internal`.
- Without the flag, the code uses the stable FP32 `torch::matmul` path that overwrites the destination buffer.
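A minimal sketch of driving the toggle from Python, assuming the extension reads the flag when the grouped GEMM op runs and that the op is exposed as `grouped_gemm.ops.gmm(a, b, batch_sizes)` (both assumptions about this repo's bindings); exporting the variable in the shell before launching Python, as in the pytest invocations below, is the safer route.

```python
# Hedged sketch: force the hipBLASLt backend, then run one small grouped GEMM.
import os
os.environ["MEGABLOCKS_GG_USE_HIPBLASLT"] = "1"  # assumed to be read when the op runs

import torch
from grouped_gemm import ops  # assumed import path for this repo's gg bindings

a = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)   # (tokens, k); "cuda" maps to HIP on ROCm
b = torch.randn(4, 32, 8, device="cuda", dtype=torch.bfloat16)  # (experts, k, n)
batch_sizes = torch.tensor([4, 4, 4, 4], dtype=torch.int64)     # rows of `a` per expert (CPU int64)

out = ops.gmm(a, b, batch_sizes)
print(out.shape, out.isfinite().all().item())  # expect torch.Size([16, 8]) True
```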
What works with hipBLASLt enabled
- `_dev/debug-gg-small.py`, `_dev/debug-tensor-copy.py`, and `_dev/debug-gg-detailed.py` finish with finite outputs (differences stay within roughly 1e-3 to 1e-2 because of BF16).
- `python -m pytest tests/test_gg.py -q` passes with the flag set.
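The kind of numeric check the debug scripts appear to perform can be sketched as a per-expert FP32 `torch.matmul` comparison with BF16-sized tolerances; the helper name and exact tolerances below are illustrative, not the scripts' actual code.

```python
# Hedged sketch of a grouped GEMM correctness check against an FP32 reference.
import torch

def reference_gmm(a, b, batch_sizes):
    """FP32 reference: out[start:start+rows] = a[start:start+rows] @ b[e] per expert e."""
    out, start = [], 0
    for e, rows in enumerate(batch_sizes.tolist()):
        out.append(a[start:start + rows].float() @ b[e].float())
        start += rows
    return torch.cat(out)

# Given `hip_out` from the hipBLASLt path, differences of ~1e-3..1e-2 are expected
# from BF16 accumulation, hence loose tolerances (values illustrative):
# torch.testing.assert_close(hip_out.float(), reference_gmm(a, b, batch_sizes),
#                            rtol=1e-2, atol=1e-2)
```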
Known failures
- `PYTHONPATH=build/... MEGABLOCKS_GG_USE_HIPBLASLT=1 python -m pytest tests/ops_test.py -q` aborts with a HIP memory access fault (`Memory access fault by GPU node-2` during `OpsTest.testGroupedGemm_FixedSizes`).
- The same failure occurs early when the test suite is run via `run-tests.sh`, so hipBLASLt is not yet production-ready.
Next steps
- Reproduce the fault in isolation (likely the large `(z=16, m=128, k=128, n=128)` cases) and inspect the arguments passed into `hipblaslt_run_matmul` (leading dimensions/layout); a hedged repro sketch follows this list.
- Investigate whether hipBLASLt requires column-major layouts or a non-zero workspace to handle the grouped GEMM shapes.
- Consider a hybrid strategy: attempt hipBLASLt per expert and fall back to FP32 for shapes that exceed stability thresholds (e.g., by catching `hipblaslt_run_matmul` errors once we can reliably detect them); see the wrapper-level sketch after the repro below.
- Once hipBLASLt is stable, tighten the tolerances/grad checks in `tests/test_gg.py` and re-enable the high-performance path by default.
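Repro sketch for the suspected failing shape from `OpsTest.testGroupedGemm_FixedSizes`; the `grouped_gemm.ops.gmm` import path is an assumption about this repo's Python bindings and may need adjusting.

```python
# Hedged isolation of the suspected failing case (z=16, m=128, k=128, n=128).
import os
os.environ["MEGABLOCKS_GG_USE_HIPBLASLT"] = "1"

import torch
from grouped_gemm import ops  # assumed import path

z, m, k, n = 16, 128, 128, 128
a = torch.randn(z * m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(z, k, n, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.full((z,), m, dtype=torch.int64)  # one group of m rows per expert (CPU int64)

out = ops.gmm(a, b, batch_sizes)
torch.cuda.synchronize()  # kernel launches are async; sync so a fault surfaces here
print(out.isfinite().all().item())
```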
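A wrapper-level sketch of the hybrid strategy is below. The note proposes implementing this inside the C++ backend by catching `hipblaslt_run_matmul` errors, so this Python form is illustration only; the import path and the stability threshold are placeholders, and a HIP memory access fault currently aborts the process rather than raising, which is why error detection still needs work.

```python
# Hedged sketch: prefer the hipBLASLt-backed op, fall back to an FP32 loop.
import torch
from grouped_gemm import ops  # assumed import path

def gmm_with_fallback(a, b, batch_sizes, max_stable_rows=64):
    """Try the hipBLASLt path for small groups; otherwise use an FP32 per-expert loop."""
    if int(batch_sizes.max()) <= max_stable_rows:  # hypothetical stability threshold
        try:
            return ops.gmm(a, b, batch_sizes)  # hipBLASLt path when the flag is set
        except RuntimeError:
            pass  # recoverable errors only; a memory access fault aborts before this
    # FP32 per-expert fallback, mirroring the torch::matmul path in the extension.
    out, start = [], 0
    for e, rows in enumerate(batch_sizes.tolist()):
        out.append((a[start:start + rows].float() @ b[e].float()).to(a.dtype))
        start += rows
    return torch.cat(out)
```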