"""JIT-build the megablocks ROCm extension and optionally stage it as a package.

The extension is compiled with ``torch.utils.cpp_extension.load``. When the
``kernels`` package is importable, the built library is also copied into
``build/<variant>/megablocks`` together with the Python sources from
``torch-ext/megablocks`` and a generated ``_ops.py`` that loads it.
"""

import os
import pathlib
import shutil

from torch.utils.cpp_extension import load

try:
    from kernels.utils import build_variant
except ImportError:  # fallback when kernels is unavailable: build only, skip staging
    build_variant = None

repo = pathlib.Path(__file__).resolve().parent

# Keep JIT build artifacts inside the repo unless the caller already picked a
# location. This must be set before load() below is called.
os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions"))

sources = [
    repo / "torch-ext" / "torch_binding.cpp",
    repo / "csrc" / "new_cumsum.cu",
    repo / "csrc" / "new_histogram.cu",
    repo / "csrc" / "new_indices.cu",
    repo / "csrc" / "new_replicate.cu",
    repo / "csrc" / "new_sort.cu",
    repo / "csrc" / "grouped_gemm" / "grouped_gemm.cu",
]

mod = load(
    name="_megablocks_rocm",
    sources=[str(s) for s in sources],
    extra_include_paths=[str(repo / "csrc")],
    extra_cflags=["-O3", "-std=c++17"],
    extra_cuda_cflags=["-O3"],  # torch switches this to hipcc flags on ROCm builds
    # Link hipBLASLt; presumably required by grouped_gemm.cu -- confirm.
    extra_ldflags=["-lhipblaslt"],
    verbose=True,
    # Build a plain shared library (ops registered via torch.ops) rather than
    # a Python extension module.
    is_python_module=False,
)

# With is_python_module=False, load() returns the built library path as a str;
# the __file__ branch only covers the case where a module object comes back.
module_path = pathlib.Path(mod if isinstance(mod, str) else mod.__file__)
print("built:", module_path)

# Source of the generated _ops.py. __LIB_NAME__ is replaced with the filename
# of the artifact actually built above, so the loader always matches the copy
# placed next to it instead of a hard-coded "_megablocks_rocm.so".
_OPS_TEMPLATE = '''\
import torch
from pathlib import Path

_LIB_NAME = __LIB_NAME__


def _load_ops():
    lib_path = Path(__file__).with_name(_LIB_NAME)
    torch.ops.load_library(str(lib_path))
    return torch.ops._megablocks_rocm


ops = _load_ops()


def add_op_namespace_prefix(op_name: str) -> str:
    return f"_megablocks_rocm::{op_name}"
'''

if build_variant is None:
    print("kernels not available; skipping package staging")
else:
    variant = build_variant()
    package_root = repo / "build" / variant / "megablocks"
    # Start from a clean directory so files from a previous staging run
    # cannot linger in the package.
    if package_root.exists():
        shutil.rmtree(package_root)
    shutil.copytree(
        repo / "torch-ext" / "megablocks",
        package_root,
        ignore=shutil.ignore_patterns("__pycache__"),
    )
    ops_py = package_root / "_ops.py"
    ops_py.write_text(
        _OPS_TEMPLATE.replace("__LIB_NAME__", repr(module_path.name)),
        encoding="utf-8",
    )
    # Copy under the same filename that _ops.py was told to load.
    shutil.copy2(module_path, package_root / module_path.name)
    print(f"staged local kernel under build/{variant}/megablocks")