Commit
·
3f13c8c
1
Parent(s):
e41a37b
add inference demo
Browse files- inference/README.md +13 -0
- inference/config_671B_v3.2.json +26 -0
- inference/convert.py +100 -0
- inference/generate.py +186 -0
- inference/kernel.py +274 -0
- inference/model.py +912 -0
- inference/requirements.txt +5 -0
inference/README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepSeek V3.2
|
| 2 |
+
|
| 3 |
+
First convert huggingface model weight files to the format of this project.
|
| 4 |
+
```bash
|
| 5 |
+
export EXPERTS=256
|
| 6 |
+
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
Then chat with DeepSeek model at will!
|
| 10 |
+
```bash
|
| 11 |
+
export CONFIG=config_671B_v3.2.json
|
| 12 |
+
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
|
| 13 |
+
```
|
inference/config_671B_v3.2.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 129280,
|
| 3 |
+
"dim": 7168,
|
| 4 |
+
"inter_dim": 18432,
|
| 5 |
+
"moe_inter_dim": 2048,
|
| 6 |
+
"n_layers": 61,
|
| 7 |
+
"n_dense_layers": 3,
|
| 8 |
+
"n_heads": 128,
|
| 9 |
+
"n_routed_experts": 256,
|
| 10 |
+
"n_shared_experts": 1,
|
| 11 |
+
"n_activated_experts": 8,
|
| 12 |
+
"n_expert_groups": 8,
|
| 13 |
+
"n_limited_groups": 4,
|
| 14 |
+
"route_scale": 2.5,
|
| 15 |
+
"score_func": "sigmoid",
|
| 16 |
+
"q_lora_rank": 1536,
|
| 17 |
+
"kv_lora_rank": 512,
|
| 18 |
+
"qk_nope_head_dim": 128,
|
| 19 |
+
"qk_rope_head_dim": 64,
|
| 20 |
+
"v_head_dim": 128,
|
| 21 |
+
"dtype": "fp8",
|
| 22 |
+
"scale_fmt": "ue8m0",
|
| 23 |
+
"index_n_heads": 64,
|
| 24 |
+
"index_head_dim": 128,
|
| 25 |
+
"index_topk": 2048
|
| 26 |
+
}
|
inference/convert.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from glob import glob
|
| 5 |
+
from tqdm import tqdm, trange
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import safe_open, save_file
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
mapping = {
|
| 12 |
+
"embed_tokens": ("embed", 0),
|
| 13 |
+
"input_layernorm": ("attn_norm", None),
|
| 14 |
+
"post_attention_layernorm": ("ffn_norm", None),
|
| 15 |
+
"q_proj": ("wq", 0),
|
| 16 |
+
"q_a_proj": ("wq_a", None),
|
| 17 |
+
"q_a_layernorm": ("q_norm", None),
|
| 18 |
+
"q_b_proj": ("wq_b", 0),
|
| 19 |
+
"kv_a_proj_with_mqa": ("wkv_a", None),
|
| 20 |
+
"kv_a_layernorm": ("kv_norm", None),
|
| 21 |
+
"kv_b_proj": ("wkv_b", 0),
|
| 22 |
+
"o_proj": ("wo", 1),
|
| 23 |
+
"gate": ("gate", None),
|
| 24 |
+
"gate_proj": ("w1", 0),
|
| 25 |
+
"down_proj": ("w2", 1),
|
| 26 |
+
"up_proj": ("w3", 0),
|
| 27 |
+
"norm": ("norm", None),
|
| 28 |
+
"lm_head": ("head", 0),
|
| 29 |
+
"scale": ("scale", None),
|
| 30 |
+
"wq_b": ("wq_b", None),
|
| 31 |
+
"wk": ("wk", None),
|
| 32 |
+
"k_norm": ("k_norm", None),
|
| 33 |
+
"weights_proj": ("weights_proj", None),
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(hf_ckpt_path, save_path, n_experts, mp):
|
| 38 |
+
"""
|
| 39 |
+
Converts and saves model checkpoint files into a specified format.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
|
| 43 |
+
save_path (str): Path to the directory where the converted checkpoint files will be saved.
|
| 44 |
+
n_experts (int): Total number of experts in the model.
|
| 45 |
+
mp (int): Model parallelism factor.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
None
|
| 49 |
+
"""
|
| 50 |
+
torch.set_num_threads(8)
|
| 51 |
+
n_local_experts = n_experts // mp
|
| 52 |
+
state_dicts = [{} for _ in range(mp)]
|
| 53 |
+
|
| 54 |
+
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
|
| 55 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 56 |
+
for name in f.keys():
|
| 57 |
+
if "model.layers.61" in name:
|
| 58 |
+
continue
|
| 59 |
+
param: torch.Tensor = f.get_tensor(name)
|
| 60 |
+
if name.startswith("model."):
|
| 61 |
+
name = name[len("model."):]
|
| 62 |
+
name = name.replace("self_attn", "attn")
|
| 63 |
+
name = name.replace("mlp", "ffn")
|
| 64 |
+
name = name.replace("weight_scale_inv", "scale")
|
| 65 |
+
name = name.replace("e_score_correction_bias", "bias")
|
| 66 |
+
key = name.split(".")[-2]
|
| 67 |
+
assert key in mapping, f"Key {key} not found in mapping"
|
| 68 |
+
new_key, dim = mapping[key]
|
| 69 |
+
name = name.replace(key, new_key)
|
| 70 |
+
for i in range(mp):
|
| 71 |
+
new_param = param
|
| 72 |
+
if "experts" in name and "shared_experts" not in name:
|
| 73 |
+
idx = int(name.split(".")[-3])
|
| 74 |
+
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
|
| 75 |
+
continue
|
| 76 |
+
elif dim is not None:
|
| 77 |
+
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
|
| 78 |
+
shard_size = param.size(dim) // mp
|
| 79 |
+
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
|
| 80 |
+
state_dicts[i][name] = new_param
|
| 81 |
+
|
| 82 |
+
os.makedirs(save_path, exist_ok=True)
|
| 83 |
+
|
| 84 |
+
for i in trange(mp):
|
| 85 |
+
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
|
| 86 |
+
|
| 87 |
+
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
|
| 88 |
+
new_file_path = os.path.join(save_path, os.path.basename(file_path))
|
| 89 |
+
shutil.copyfile(file_path, new_file_path)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
parser = ArgumentParser()
|
| 94 |
+
parser.add_argument("--hf-ckpt-path", type=str, required=True)
|
| 95 |
+
parser.add_argument("--save-path", type=str, required=True)
|
| 96 |
+
parser.add_argument("--n-experts", type=int, required=True)
|
| 97 |
+
parser.add_argument("--model-parallel", type=int, required=True)
|
| 98 |
+
args = parser.parse_args()
|
| 99 |
+
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
|
| 100 |
+
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
|
inference/generate.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from safetensors.torch import load_model
|
| 10 |
+
|
| 11 |
+
from model import Transformer, ModelArgs
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sample(logits, temperature: float = 1.0):
|
| 15 |
+
"""
|
| 16 |
+
Samples a token from the logits using temperature scaling.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
logits (torch.Tensor): The logits tensor for token predictions.
|
| 20 |
+
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
torch.Tensor: The sampled token.
|
| 24 |
+
"""
|
| 25 |
+
logits = logits / max(temperature, 1e-5)
|
| 26 |
+
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
| 27 |
+
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.inference_mode()
|
| 31 |
+
def generate(
|
| 32 |
+
model: Transformer,
|
| 33 |
+
prompt_tokens: List[List[int]],
|
| 34 |
+
max_new_tokens: int,
|
| 35 |
+
eos_id: int,
|
| 36 |
+
temperature: float = 1.0
|
| 37 |
+
) -> List[List[int]]:
|
| 38 |
+
"""
|
| 39 |
+
Generates new tokens based on the given prompt tokens using the specified model.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
model (Transformer): The transformer model used for token generation.
|
| 43 |
+
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
|
| 44 |
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
| 45 |
+
eos_id (int): The end-of-sequence token ID.
|
| 46 |
+
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
List[List[int]]: A list of lists containing the generated tokens for each sequence.
|
| 50 |
+
"""
|
| 51 |
+
prompt_lens = [len(t) for t in prompt_tokens]
|
| 52 |
+
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
|
| 53 |
+
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
|
| 54 |
+
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
|
| 55 |
+
for i, t in enumerate(prompt_tokens):
|
| 56 |
+
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
| 57 |
+
prev_pos = 0
|
| 58 |
+
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
|
| 59 |
+
prompt_mask = tokens != -1
|
| 60 |
+
for cur_pos in range(min(prompt_lens), total_len):
|
| 61 |
+
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
| 62 |
+
if temperature > 0:
|
| 63 |
+
next_token = sample(logits, temperature)
|
| 64 |
+
else:
|
| 65 |
+
next_token = logits.argmax(dim=-1)
|
| 66 |
+
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
| 67 |
+
tokens[:, cur_pos] = next_token
|
| 68 |
+
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
|
| 69 |
+
prev_pos = cur_pos
|
| 70 |
+
if finished.all():
|
| 71 |
+
break
|
| 72 |
+
completion_tokens = []
|
| 73 |
+
for i, toks in enumerate(tokens.tolist()):
|
| 74 |
+
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
|
| 75 |
+
if eos_id in toks:
|
| 76 |
+
toks = toks[:toks.index(eos_id)]
|
| 77 |
+
completion_tokens.append(toks)
|
| 78 |
+
return completion_tokens
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def main(
|
| 82 |
+
ckpt_path: str,
|
| 83 |
+
config: str,
|
| 84 |
+
input_file: str = "",
|
| 85 |
+
interactive: bool = True,
|
| 86 |
+
max_new_tokens: int = 100,
|
| 87 |
+
temperature: float = 1.0,
|
| 88 |
+
) -> None:
|
| 89 |
+
"""
|
| 90 |
+
Main function to load the model and perform interactive or batch text generation.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
ckpt_path (str): Path to the model checkpoint directory.
|
| 94 |
+
config (str): Path to the model configuration file.
|
| 95 |
+
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
|
| 96 |
+
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
|
| 97 |
+
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
|
| 98 |
+
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
|
| 99 |
+
"""
|
| 100 |
+
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
| 101 |
+
rank = int(os.getenv("RANK", "0"))
|
| 102 |
+
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 103 |
+
if world_size > 1:
|
| 104 |
+
dist.init_process_group("nccl")
|
| 105 |
+
global print
|
| 106 |
+
if rank != 0:
|
| 107 |
+
print = lambda *_, **__: None
|
| 108 |
+
torch.cuda.set_device(local_rank)
|
| 109 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 110 |
+
torch.set_num_threads(8)
|
| 111 |
+
torch.manual_seed(33377335)
|
| 112 |
+
with open(config) as f:
|
| 113 |
+
args = ModelArgs(**json.load(f))
|
| 114 |
+
print(args)
|
| 115 |
+
with torch.device("cuda"):
|
| 116 |
+
model = Transformer(args)
|
| 117 |
+
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
|
| 118 |
+
print("load model")
|
| 119 |
+
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
|
| 120 |
+
print("I'm DeepSeek 👋")
|
| 121 |
+
|
| 122 |
+
if interactive:
|
| 123 |
+
messages = []
|
| 124 |
+
while True:
|
| 125 |
+
if world_size == 1:
|
| 126 |
+
prompt = input(">>> ")
|
| 127 |
+
elif rank == 0:
|
| 128 |
+
prompt = input(">>> ")
|
| 129 |
+
objects = [prompt]
|
| 130 |
+
dist.broadcast_object_list(objects, 0)
|
| 131 |
+
else:
|
| 132 |
+
objects = [None]
|
| 133 |
+
dist.broadcast_object_list(objects, 0)
|
| 134 |
+
prompt = objects[0]
|
| 135 |
+
if prompt == "/exit":
|
| 136 |
+
break
|
| 137 |
+
elif prompt == "/clear":
|
| 138 |
+
messages.clear()
|
| 139 |
+
continue
|
| 140 |
+
messages.append({"role": "user", "content": prompt})
|
| 141 |
+
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
| 142 |
+
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
|
| 143 |
+
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
|
| 144 |
+
print(completion)
|
| 145 |
+
messages.append({"role": "assistant", "content": completion})
|
| 146 |
+
else:
|
| 147 |
+
with open(input_file) as f:
|
| 148 |
+
prompts = f.read().split("\n\n")
|
| 149 |
+
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
|
| 150 |
+
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
|
| 151 |
+
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
|
| 152 |
+
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
|
| 153 |
+
for prompt, completion in zip(prompts, completions):
|
| 154 |
+
print("Prompt:", prompt)
|
| 155 |
+
print("Completion:", completion)
|
| 156 |
+
print()
|
| 157 |
+
|
| 158 |
+
if world_size > 1:
|
| 159 |
+
dist.destroy_process_group()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
"""
|
| 164 |
+
Command-line interface for distributed text generation.
|
| 165 |
+
|
| 166 |
+
Arguments:
|
| 167 |
+
--ckpt-path (str): Path to the model checkpoint directory.
|
| 168 |
+
--config (str): Path to the model configuration file.
|
| 169 |
+
--input-file (str, optional): File containing prompts for batch processing.
|
| 170 |
+
--interactive (bool, optional): Enable interactive mode for generating text.
|
| 171 |
+
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
|
| 172 |
+
--temperature (float, optional): Temperature for sampling. Defaults to 0.2.
|
| 173 |
+
|
| 174 |
+
Raises:
|
| 175 |
+
AssertionError: If neither input-file nor interactive mode is specified.
|
| 176 |
+
"""
|
| 177 |
+
parser = ArgumentParser()
|
| 178 |
+
parser.add_argument("--ckpt-path", type=str, required=True)
|
| 179 |
+
parser.add_argument("--config", type=str, required=True)
|
| 180 |
+
parser.add_argument("--input-file", type=str, default="")
|
| 181 |
+
parser.add_argument("--interactive", action="store_true")
|
| 182 |
+
parser.add_argument("--max-new-tokens", type=int, default=200)
|
| 183 |
+
parser.add_argument("--temperature", type=float, default=0.6)
|
| 184 |
+
args = parser.parse_args()
|
| 185 |
+
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
|
| 186 |
+
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
|
inference/kernel.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import tilelang
|
| 3 |
+
import tilelang.language as T
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
tilelang.set_log_level("WARNING")
|
| 8 |
+
|
| 9 |
+
pass_configs = {
|
| 10 |
+
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
| 11 |
+
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
| 12 |
+
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
FP8 = "float8_e4m3"
|
| 16 |
+
BF16 = "bfloat16"
|
| 17 |
+
FP32 = "float32"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def fast_log2_ceil(x):
|
| 21 |
+
bits_x = T.reinterpret("uint32", x)
|
| 22 |
+
exp_x = (bits_x >> 23) & 0xFF
|
| 23 |
+
man_bits = bits_x & ((1 << 23) - 1)
|
| 24 |
+
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def fast_pow2(x):
|
| 28 |
+
bits_x = (x + 127) << 23
|
| 29 |
+
return T.reinterpret("float32", bits_x)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def fast_round_scale(amax, fp8_max_inv):
|
| 33 |
+
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 37 |
+
def act_quant_kernel(
|
| 38 |
+
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
|
| 39 |
+
):
|
| 40 |
+
M = T.symbolic("M")
|
| 41 |
+
fp8_min = -448.0
|
| 42 |
+
fp8_max = 448.0
|
| 43 |
+
fp8_max_inv = 1 / fp8_max
|
| 44 |
+
num_stages = 0 if round_scale else 2
|
| 45 |
+
blk_m = 32
|
| 46 |
+
group_size = 128
|
| 47 |
+
|
| 48 |
+
@T.prim_func
|
| 49 |
+
def act_quant_kernel_(
|
| 50 |
+
X: T.Tensor[(M, N), in_dtype],
|
| 51 |
+
Y: T.Tensor[(M, N), out_dtype],
|
| 52 |
+
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
| 53 |
+
):
|
| 54 |
+
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
| 55 |
+
pid_m,
|
| 56 |
+
pid_n,
|
| 57 |
+
):
|
| 58 |
+
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
| 59 |
+
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
| 60 |
+
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
|
| 61 |
+
s_local = T.alloc_fragment((blk_m,), scale_dtype)
|
| 62 |
+
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
| 63 |
+
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
| 64 |
+
|
| 65 |
+
for _ in T.Pipelined(1, num_stages=num_stages):
|
| 66 |
+
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
| 67 |
+
T.copy(x_shared, x_local)
|
| 68 |
+
T.reduce_absmax(x_local, amax_local, dim=1)
|
| 69 |
+
for i in T.Parallel(blk_m):
|
| 70 |
+
amax_local[i] = T.max(amax_local[i], 1e-4)
|
| 71 |
+
if round_scale:
|
| 72 |
+
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
| 73 |
+
else:
|
| 74 |
+
s_local[i] = amax_local[i] * fp8_max_inv
|
| 75 |
+
for i, j in T.Parallel(blk_m, group_size):
|
| 76 |
+
y_local[i, j] = T.clamp(
|
| 77 |
+
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
| 78 |
+
)
|
| 79 |
+
for i in T.Parallel(blk_m):
|
| 80 |
+
S[pid_m * blk_m + i, pid_n] = s_local[i]
|
| 81 |
+
T.copy(y_local, y_shared)
|
| 82 |
+
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
| 83 |
+
|
| 84 |
+
return act_quant_kernel_
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def act_quant(
|
| 88 |
+
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
|
| 89 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 90 |
+
"""
|
| 91 |
+
Quantizes the input tensor `x` using block-wise quantization.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
|
| 95 |
+
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
| 96 |
+
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
|
| 97 |
+
Returns:
|
| 98 |
+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
| 99 |
+
- The quantized tensor with dtype `torch.float8_e4m3fn`.
|
| 100 |
+
- A tensor of scaling factors with dtype `torch.float32`.
|
| 101 |
+
"""
|
| 102 |
+
assert x.is_contiguous(), "Input tensor must be contiguous"
|
| 103 |
+
assert x.size(-1) % block_size == 0, (
|
| 104 |
+
f"Last dimension size must be divisible by block_size (block_size={block_size})"
|
| 105 |
+
)
|
| 106 |
+
N = x.size(-1)
|
| 107 |
+
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
| 108 |
+
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
|
| 109 |
+
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
|
| 110 |
+
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
| 111 |
+
return y, s
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@tilelang.jit(pass_configs=pass_configs)
|
| 115 |
+
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
|
| 116 |
+
assert out_dtype in [BF16, "float32"]
|
| 117 |
+
|
| 118 |
+
M = T.symbolic("M")
|
| 119 |
+
group_size = 128
|
| 120 |
+
block_M = 32
|
| 121 |
+
block_N = 128
|
| 122 |
+
block_K = 128
|
| 123 |
+
|
| 124 |
+
@T.prim_func
|
| 125 |
+
def fp8_gemm_kernel_(
|
| 126 |
+
A: T.Tensor[(M, K), FP8],
|
| 127 |
+
B: T.Tensor[(N, K), FP8],
|
| 128 |
+
C: T.Tensor[(M, N), out_dtype],
|
| 129 |
+
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
|
| 130 |
+
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
|
| 131 |
+
):
|
| 132 |
+
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
|
| 133 |
+
bx,
|
| 134 |
+
by,
|
| 135 |
+
):
|
| 136 |
+
A_shared = T.alloc_shared((block_M, block_K), FP8)
|
| 137 |
+
B_shared = T.alloc_shared((block_N, block_K), FP8)
|
| 138 |
+
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
|
| 139 |
+
Scale_C_shared = T.alloc_shared((block_M), FP32)
|
| 140 |
+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 141 |
+
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
|
| 142 |
+
|
| 143 |
+
# Improve L2 Cache
|
| 144 |
+
T.use_swizzle(panel_size=10)
|
| 145 |
+
|
| 146 |
+
T.clear(C_local)
|
| 147 |
+
T.clear(C_local_accum)
|
| 148 |
+
K_iters = T.ceildiv(K, block_K)
|
| 149 |
+
for k in T.Pipelined(K_iters, num_stages=4):
|
| 150 |
+
# Load A into shared memory
|
| 151 |
+
T.copy(A[by * block_M, k * block_K], A_shared)
|
| 152 |
+
# Load B into shared memory
|
| 153 |
+
T.copy(B[bx * block_N, k * block_K], B_shared)
|
| 154 |
+
# Load scale into shared memory
|
| 155 |
+
Scale_B = scales_b[bx * block_N // group_size, k]
|
| 156 |
+
for i in T.Parallel(block_M):
|
| 157 |
+
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
|
| 158 |
+
|
| 159 |
+
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
|
| 160 |
+
# Promote to enable 2xAcc
|
| 161 |
+
for i, j in T.Parallel(block_M, block_N):
|
| 162 |
+
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
|
| 163 |
+
T.clear(C_local)
|
| 164 |
+
# TMA store
|
| 165 |
+
T.copy(C_local_accum, C_shared)
|
| 166 |
+
T.copy(C_shared, C[by * block_M, bx * block_N])
|
| 167 |
+
|
| 168 |
+
return fp8_gemm_kernel_
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def fp8_gemm(
|
| 172 |
+
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
|
| 173 |
+
) -> torch.Tensor:
|
| 174 |
+
"""
|
| 175 |
+
Perform a matrix multiplication using FP8 precision.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
a (torch.Tensor): The first input matrix, must be contiguous.
|
| 179 |
+
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
|
| 180 |
+
b (torch.Tensor): The second input matrix, must be contiguous.
|
| 181 |
+
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
torch.Tensor: The result of the matrix multiplication.
|
| 185 |
+
"""
|
| 186 |
+
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
|
| 187 |
+
assert a_s.is_contiguous() and b_s.is_contiguous(), (
|
| 188 |
+
"Scaling factor tensors must be contiguous"
|
| 189 |
+
)
|
| 190 |
+
K = a.size(-1)
|
| 191 |
+
M = a.numel() // K
|
| 192 |
+
N = b.size(0)
|
| 193 |
+
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
| 194 |
+
kernel = fp8_gemm_kernel(N, K)
|
| 195 |
+
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
|
| 196 |
+
return c
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
|
| 200 |
+
def fp8_index_kernel(h: int, d: int):
|
| 201 |
+
b = T.symbolic("b")
|
| 202 |
+
m = T.symbolic("m")
|
| 203 |
+
n = T.symbolic("n")
|
| 204 |
+
|
| 205 |
+
blk_n1 = 512
|
| 206 |
+
blk_n2 = 128
|
| 207 |
+
|
| 208 |
+
@T.prim_func
|
| 209 |
+
def fp8_index_kernel_(
|
| 210 |
+
q: T.Tensor[(b, m, h, d), FP8],
|
| 211 |
+
q_s: T.Tensor[(b, m, h), FP32],
|
| 212 |
+
k: T.Tensor[(b, n, d), FP8],
|
| 213 |
+
k_s: T.Tensor[(b, n), FP32],
|
| 214 |
+
o: T.Tensor[(b, m, n), FP32],
|
| 215 |
+
) -> None:
|
| 216 |
+
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
|
| 217 |
+
q_smem = T.alloc_shared((h, d), FP8)
|
| 218 |
+
T.copy(q[i_b, i_m, 0, 0], q_smem)
|
| 219 |
+
|
| 220 |
+
q_s_frag = T.alloc_fragment(h, FP32)
|
| 221 |
+
T.copy(q_s[i_b, i_m, 0], q_s_frag)
|
| 222 |
+
|
| 223 |
+
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
|
| 224 |
+
k_smem = T.alloc_shared((blk_n2, d), FP8)
|
| 225 |
+
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
|
| 226 |
+
|
| 227 |
+
k_s_frag = T.alloc_fragment(blk_n2, FP32)
|
| 228 |
+
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
|
| 229 |
+
|
| 230 |
+
logits = T.alloc_fragment((blk_n2, h), FP32)
|
| 231 |
+
T.gemm(
|
| 232 |
+
k_smem,
|
| 233 |
+
q_smem,
|
| 234 |
+
logits,
|
| 235 |
+
transpose_A=False,
|
| 236 |
+
transpose_B=True,
|
| 237 |
+
clear_accum=True,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
for i_h, i3_n in T.Parallel(h, blk_n2):
|
| 241 |
+
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
|
| 242 |
+
|
| 243 |
+
logits_sum = T.alloc_fragment(blk_n2, FP32)
|
| 244 |
+
T.reduce_sum(logits, logits_sum, dim=1)
|
| 245 |
+
|
| 246 |
+
for i3_n in T.Parallel(blk_n2):
|
| 247 |
+
logits_sum[i3_n] *= k_s_frag[i3_n]
|
| 248 |
+
|
| 249 |
+
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
|
| 250 |
+
|
| 251 |
+
return fp8_index_kernel_
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def fp8_index(
|
| 255 |
+
q: torch.Tensor,
|
| 256 |
+
q_s: torch.Tensor,
|
| 257 |
+
k: torch.Tensor,
|
| 258 |
+
k_s: torch.Tensor,
|
| 259 |
+
) -> torch.Tensor:
|
| 260 |
+
"""
|
| 261 |
+
Perform index score using FP8 precision.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
q (torch.Tensor): The Q tensor, must be contiguous.
|
| 265 |
+
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
|
| 266 |
+
k (torch.Tensor): The K tensor, must be contiguous.
|
| 267 |
+
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
|
| 268 |
+
|
| 269 |
+
fp8 q @ fp8 k -> fp32 logits
|
| 270 |
+
relu(fp32 logits) * q_s (weights) -> fp32 logits
|
| 271 |
+
fp32 logits -> fp32 logits_sum
|
| 272 |
+
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
|
| 273 |
+
"""
|
| 274 |
+
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
|
inference/model.py
ADDED
|
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Optional, Literal
|
| 4 |
+
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
|
| 11 |
+
from kernel import act_quant, fp8_gemm, fp8_index
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
world_size = 1
|
| 15 |
+
rank = 0
|
| 16 |
+
block_size = 128
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class ModelArgs:
|
| 20 |
+
"""
|
| 21 |
+
Data class for defining model arguments and hyperparameters.
|
| 22 |
+
|
| 23 |
+
Attributes:
|
| 24 |
+
max_batch_size (int): Maximum batch size.
|
| 25 |
+
max_seq_len (int): Maximum sequence length.
|
| 26 |
+
dtype (Literal["bf16", "fp8"]): Data type for computations.
|
| 27 |
+
scale_fmt (Optional[str]): Format for quantization scale.
|
| 28 |
+
vocab_size (int): Vocabulary size.
|
| 29 |
+
dim (int): Model dimension.
|
| 30 |
+
inter_dim (int): Intermediate dimension for MLP layers.
|
| 31 |
+
moe_inter_dim (int): Intermediate dimension for MoE layers.
|
| 32 |
+
n_layers (int): Number of transformer layers.
|
| 33 |
+
n_dense_layers (int): Number of dense layers in the model.
|
| 34 |
+
n_heads (int): Number of attention heads.
|
| 35 |
+
n_routed_experts (int): Number of routed experts for MoE layers.
|
| 36 |
+
n_shared_experts (int): Number of shared experts for MoE layers.
|
| 37 |
+
n_activated_experts (int): Number of activated experts in MoE layers.
|
| 38 |
+
n_expert_groups (int): Number of expert groups.
|
| 39 |
+
n_limited_groups (int): Number of limited groups for MoE routing.
|
| 40 |
+
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
|
| 41 |
+
route_scale (float): Scaling factor for routing scores.
|
| 42 |
+
q_lora_rank (int): LoRA rank for query projections.
|
| 43 |
+
kv_lora_rank (int): LoRA rank for key-value projections.
|
| 44 |
+
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
|
| 45 |
+
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
|
| 46 |
+
v_head_dim (int): Dimension for value projections.
|
| 47 |
+
original_seq_len (int): Original sequence length.
|
| 48 |
+
rope_theta (float): Base for rotary positional encoding.
|
| 49 |
+
rope_factor (float): Scaling factor for extended sequence lengths.
|
| 50 |
+
beta_fast (int): Fast beta correction factor.
|
| 51 |
+
beta_slow (int): Slow beta correction factor.
|
| 52 |
+
mscale (float): Scaling factor for extended attention.
|
| 53 |
+
index_head_dim (int): Dimension for index head.
|
| 54 |
+
index_topk (int): Top-k for index head.
|
| 55 |
+
"""
|
| 56 |
+
max_batch_size: int = 8
|
| 57 |
+
max_seq_len: int = 4096 * 4
|
| 58 |
+
dtype: Literal["bf16", "fp8"] = "bf16"
|
| 59 |
+
scale_fmt: Optional[str] = None
|
| 60 |
+
vocab_size: int = 102400
|
| 61 |
+
dim: int = 2048
|
| 62 |
+
inter_dim: int = 10944
|
| 63 |
+
moe_inter_dim: int = 1408
|
| 64 |
+
n_layers: int = 27
|
| 65 |
+
n_dense_layers: int = 1
|
| 66 |
+
n_heads: int = 16
|
| 67 |
+
# moe
|
| 68 |
+
n_routed_experts: int = 64
|
| 69 |
+
n_shared_experts: int = 2
|
| 70 |
+
n_activated_experts: int = 6
|
| 71 |
+
n_expert_groups: int = 1
|
| 72 |
+
n_limited_groups: int = 1
|
| 73 |
+
score_func: Literal["softmax", "sigmoid"] = "softmax"
|
| 74 |
+
route_scale: float = 1.
|
| 75 |
+
# mla
|
| 76 |
+
q_lora_rank: int = 0
|
| 77 |
+
kv_lora_rank: int = 512
|
| 78 |
+
qk_nope_head_dim: int = 128
|
| 79 |
+
qk_rope_head_dim: int = 64
|
| 80 |
+
v_head_dim: int = 128
|
| 81 |
+
# yarn
|
| 82 |
+
original_seq_len: int = 4096
|
| 83 |
+
rope_theta: float = 10000.0
|
| 84 |
+
rope_factor: float = 40
|
| 85 |
+
beta_fast: int = 32
|
| 86 |
+
beta_slow: int = 1
|
| 87 |
+
mscale: float = 1.
|
| 88 |
+
# index
|
| 89 |
+
index_n_heads: int = 64
|
| 90 |
+
index_head_dim: int = 128
|
| 91 |
+
index_topk: int = 2048
|
| 92 |
+
|
| 93 |
+
class ParallelEmbedding(nn.Module):
|
| 94 |
+
"""
|
| 95 |
+
Embedding layer with parallelism support across distributed processes.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
vocab_size (int): Vocabulary size.
|
| 99 |
+
dim (int): Embedding dimension.
|
| 100 |
+
"""
|
| 101 |
+
def __init__(self, vocab_size: int, dim: int):
|
| 102 |
+
super().__init__()
|
| 103 |
+
self.vocab_size = vocab_size
|
| 104 |
+
self.dim = dim
|
| 105 |
+
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
|
| 106 |
+
self.part_vocab_size = (vocab_size // world_size)
|
| 107 |
+
self.vocab_start_idx = rank * self.part_vocab_size
|
| 108 |
+
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
|
| 109 |
+
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
|
| 110 |
+
|
| 111 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 112 |
+
"""
|
| 113 |
+
Forward pass for parallel embedding layer.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
x (torch.Tensor): Input tensor containing token indices.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
torch.Tensor: Embedded representations.
|
| 120 |
+
|
| 121 |
+
Raises:
|
| 122 |
+
ValueError: If `world_size` is not defined.
|
| 123 |
+
"""
|
| 124 |
+
if world_size > 1:
|
| 125 |
+
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
| 126 |
+
x = x - self.vocab_start_idx
|
| 127 |
+
x[mask] = 0
|
| 128 |
+
y = F.embedding(x, self.weight)
|
| 129 |
+
if world_size > 1:
|
| 130 |
+
y[mask] = 0
|
| 131 |
+
dist.all_reduce(y)
|
| 132 |
+
return y
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
|
| 136 |
+
scale_fmt: Optional[str] = None) -> torch.Tensor:
|
| 137 |
+
"""
|
| 138 |
+
Applies a linear transformation to the incoming data: y = xA^T + b.
|
| 139 |
+
This function supports specialized implementations based on quantization
|
| 140 |
+
and tensor formats.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
x (torch.Tensor): The input tensor.
|
| 144 |
+
weight (torch.Tensor): The weight tensor. It may be quantized and
|
| 145 |
+
requires dequantization for certain cases.
|
| 146 |
+
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
|
| 147 |
+
scale_fmt (Optional[str]): The format of scaling factors.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
torch.Tensor: The result of the linear transformation, which may involve
|
| 151 |
+
quantization-aware computations depending on the input parameters.
|
| 152 |
+
|
| 153 |
+
Notes:
|
| 154 |
+
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
|
| 155 |
+
is used for computation.
|
| 156 |
+
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
|
| 157 |
+
"""
|
| 158 |
+
assert bias is None
|
| 159 |
+
|
| 160 |
+
if weight.dtype != torch.float8_e4m3fn:
|
| 161 |
+
return F.linear(x, weight)
|
| 162 |
+
else:
|
| 163 |
+
x, scale = act_quant(x, block_size, scale_fmt)
|
| 164 |
+
return fp8_gemm(x, scale, weight, weight.scale)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Linear(nn.Module):
|
| 168 |
+
"""
|
| 169 |
+
Custom linear layer with support for quantized weights and optional bias.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
in_features (int): Number of input features.
|
| 173 |
+
out_features (int): Number of output features.
|
| 174 |
+
bias (bool): Whether to include a bias term. Defaults to False.
|
| 175 |
+
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
| 176 |
+
"""
|
| 177 |
+
dtype = torch.bfloat16
|
| 178 |
+
scale_fmt: Optional[str] = None
|
| 179 |
+
|
| 180 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.in_features = in_features
|
| 183 |
+
self.out_features = out_features
|
| 184 |
+
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
|
| 185 |
+
if self.weight.element_size() == 1:
|
| 186 |
+
scale_out_features = (out_features + block_size - 1) // block_size
|
| 187 |
+
scale_in_features = (in_features + block_size - 1) // block_size
|
| 188 |
+
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
|
| 189 |
+
else:
|
| 190 |
+
self.register_parameter("scale", None)
|
| 191 |
+
if bias:
|
| 192 |
+
self.bias = nn.Parameter(torch.empty(out_features))
|
| 193 |
+
else:
|
| 194 |
+
self.register_parameter("bias", None)
|
| 195 |
+
|
| 196 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 197 |
+
"""
|
| 198 |
+
Forward pass for the custom linear layer.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
x (torch.Tensor): Input tensor.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
torch.Tensor: Transformed tensor after linear computation.
|
| 205 |
+
"""
|
| 206 |
+
return linear(x, self.weight, self.bias, self.scale_fmt)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class ColumnParallelLinear(Linear):
|
| 210 |
+
"""
|
| 211 |
+
Linear layer with column parallelism, splitting output features across distributed processes.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
in_features (int): Number of input features.
|
| 215 |
+
out_features (int): Total number of output features.
|
| 216 |
+
bias (bool): Whether to include a bias term. Defaults to False.
|
| 217 |
+
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
| 218 |
+
"""
|
| 219 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
| 220 |
+
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
|
| 221 |
+
self.part_out_features = out_features // world_size
|
| 222 |
+
super().__init__(in_features, self.part_out_features, bias, dtype)
|
| 223 |
+
|
| 224 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 225 |
+
"""
|
| 226 |
+
Forward pass for column parallel linear layer.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
x (torch.Tensor): Input tensor.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
torch.Tensor: Transformed tensor with column-parallel computation.
|
| 233 |
+
"""
|
| 234 |
+
y = linear(x, self.weight, self.bias, self.scale_fmt)
|
| 235 |
+
return y
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class RowParallelLinear(Linear):
|
| 239 |
+
"""
|
| 240 |
+
Linear layer with row parallelism, splitting input features across distributed processes.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
in_features (int): Total number of input features.
|
| 244 |
+
out_features (int): Number of output features.
|
| 245 |
+
bias (bool): Whether to include a bias term. Defaults to False.
|
| 246 |
+
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
| 247 |
+
"""
|
| 248 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
|
| 249 |
+
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
|
| 250 |
+
self.part_in_features = in_features // world_size
|
| 251 |
+
self.reduce_output = reduce_output
|
| 252 |
+
super().__init__(self.part_in_features, out_features, bias, dtype)
|
| 253 |
+
|
| 254 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 255 |
+
"""
|
| 256 |
+
Forward pass for row parallel linear layer.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
x (torch.Tensor): Input tensor.
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
torch.Tensor: Transformed tensor with row-parallel computation.
|
| 263 |
+
"""
|
| 264 |
+
y = linear(x, self.weight, None, self.scale_fmt)
|
| 265 |
+
if self.reduce_output and world_size > 1:
|
| 266 |
+
y = y.float()
|
| 267 |
+
dist.all_reduce(y)
|
| 268 |
+
if self.bias is not None:
|
| 269 |
+
y += self.bias
|
| 270 |
+
return y.type_as(x)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class RMSNorm(nn.Module):
|
| 274 |
+
"""
|
| 275 |
+
Root Mean Square Layer Normalization (RMSNorm).
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
dim (int): Dimension of the input tensor.
|
| 279 |
+
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
|
| 280 |
+
"""
|
| 281 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.dim = dim
|
| 284 |
+
self.eps = eps
|
| 285 |
+
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
| 286 |
+
|
| 287 |
+
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
|
| 288 |
+
"""
|
| 289 |
+
Forward pass for RMSNorm.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
x (torch.Tensor): Input tensor.
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
torch.Tensor: Normalized tensor with the same shape as input.
|
| 296 |
+
"""
|
| 297 |
+
dtype = x.dtype
|
| 298 |
+
if residual is None:
|
| 299 |
+
x = x.float()
|
| 300 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
| 301 |
+
x = x * torch.rsqrt(var + self.eps)
|
| 302 |
+
return (self.weight * x).to(dtype)
|
| 303 |
+
else:
|
| 304 |
+
x = residual = x.float() + residual.float()
|
| 305 |
+
var = x.pow(2).mean(-1, keepdim=True)
|
| 306 |
+
x = x * torch.rsqrt(var + self.eps)
|
| 307 |
+
return (self.weight * x).to(dtype), residual.to(dtype)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class LayerNorm(nn.Module):
|
| 311 |
+
"""
|
| 312 |
+
Layer Normalization.
|
| 313 |
+
"""
|
| 314 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 315 |
+
super().__init__()
|
| 316 |
+
self.dim = dim
|
| 317 |
+
self.eps = eps
|
| 318 |
+
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
| 319 |
+
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
|
| 320 |
+
|
| 321 |
+
def forward(self, x: torch.Tensor):
|
| 322 |
+
return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
| 326 |
+
"""
|
| 327 |
+
Precomputes frequency-based complex exponential values for rotary positional embeddings.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
args (ModelArgs): Model arguments containing positional embedding parameters.
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
torch.Tensor: Precomputed complex exponential values for positional embeddings.
|
| 334 |
+
"""
|
| 335 |
+
dim = args.qk_rope_head_dim
|
| 336 |
+
seqlen = args.max_seq_len
|
| 337 |
+
beta_fast = args.beta_fast
|
| 338 |
+
beta_slow = args.beta_slow
|
| 339 |
+
base = args.rope_theta
|
| 340 |
+
factor = args.rope_factor
|
| 341 |
+
|
| 342 |
+
def find_correction_dim(num_rotations, dim, base, max_seq_len):
|
| 343 |
+
"""
|
| 344 |
+
Computes the correction dimension for a given number of rotations in the rotary positional embedding.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
num_rotations (float): Number of rotations to compute the correction for.
|
| 348 |
+
dim (int): Dimensionality of the embedding space.
|
| 349 |
+
base (float): Base value for the exponential computation.
|
| 350 |
+
max_seq_len (int): Maximum sequence length.
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
float: The correction dimension based on the input parameters.
|
| 354 |
+
"""
|
| 355 |
+
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
|
| 356 |
+
|
| 357 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
|
| 358 |
+
"""
|
| 359 |
+
Computes the range of correction dimensions for rotary positional embeddings.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
low_rot (float): Lower bound for the number of rotations.
|
| 363 |
+
high_rot (float): Upper bound for the number of rotations.
|
| 364 |
+
dim (int): Dimensionality of the embedding space.
|
| 365 |
+
base (float): Base value for the exponential computation.
|
| 366 |
+
max_seq_len (int): Maximum sequence length.
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
|
| 370 |
+
"""
|
| 371 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
|
| 372 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
|
| 373 |
+
return max(low, 0), min(high, dim-1)
|
| 374 |
+
|
| 375 |
+
def linear_ramp_factor(min, max, dim):
|
| 376 |
+
"""
|
| 377 |
+
Computes a linear ramp function used to smooth values between a minimum and maximum range.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
min (float): Minimum value for the ramp function.
|
| 381 |
+
max (float): Maximum value for the ramp function.
|
| 382 |
+
dim (int): Dimensionality of the ramp tensor.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
|
| 386 |
+
clamped to the range [0, 1].
|
| 387 |
+
"""
|
| 388 |
+
if min == max:
|
| 389 |
+
max += 0.001
|
| 390 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
| 391 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
| 392 |
+
return ramp_func
|
| 393 |
+
|
| 394 |
+
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
| 395 |
+
if seqlen > args.original_seq_len:
|
| 396 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
|
| 397 |
+
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
|
| 398 |
+
freqs = freqs / factor * (1 - smooth) + freqs * smooth
|
| 399 |
+
|
| 400 |
+
t = torch.arange(seqlen)
|
| 401 |
+
freqs = torch.outer(t, freqs)
|
| 402 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 403 |
+
return freqs_cis
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 407 |
+
"""
|
| 408 |
+
Applies rotary positional embeddings to the input tensor.
|
| 409 |
+
|
| 410 |
+
Args:
|
| 411 |
+
x (torch.Tensor): Input tensor with positional embeddings to be applied.
|
| 412 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
|
| 413 |
+
|
| 414 |
+
Returns:
|
| 415 |
+
torch.Tensor: Tensor with rotary embeddings applied.
|
| 416 |
+
"""
|
| 417 |
+
dtype = x.dtype
|
| 418 |
+
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
|
| 419 |
+
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
|
| 420 |
+
y = torch.view_as_real(x * freqs_cis).flatten(3)
|
| 421 |
+
return y.to(dtype)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
|
| 425 |
+
assert x.dtype == torch.bfloat16
|
| 426 |
+
from fast_hadamard_transform import hadamard_transform
|
| 427 |
+
hidden_size = x.size(-1)
|
| 428 |
+
return hadamard_transform(x, scale=hidden_size ** -0.5)
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class Indexer(torch.nn.Module):
|
| 432 |
+
def __init__(self, args: ModelArgs):
|
| 433 |
+
super().__init__()
|
| 434 |
+
self.dim: int = args.dim
|
| 435 |
+
self.n_heads: int = args.index_n_heads
|
| 436 |
+
self.n_local_heads = args.index_n_heads // world_size
|
| 437 |
+
self.head_dim: int = args.index_head_dim
|
| 438 |
+
self.rope_head_dim: int = args.qk_rope_head_dim
|
| 439 |
+
self.index_topk: int = args.index_topk
|
| 440 |
+
self.q_lora_rank: int = args.q_lora_rank
|
| 441 |
+
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
|
| 442 |
+
self.wk = Linear(self.dim, self.head_dim)
|
| 443 |
+
self.k_norm = LayerNorm(self.head_dim)
|
| 444 |
+
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
|
| 445 |
+
self.softmax_scale = self.head_dim ** -0.5
|
| 446 |
+
self.scale_fmt = args.scale_fmt
|
| 447 |
+
|
| 448 |
+
self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn), persistent=False)
|
| 449 |
+
self.register_buffer("k_scale_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.head_dim // block_size, dtype=torch.float32), persistent=False)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
| 453 |
+
bsz, seqlen, _ = x.size()
|
| 454 |
+
end_pos = start_pos + seqlen
|
| 455 |
+
q = self.wq_b(qr)
|
| 456 |
+
q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
|
| 457 |
+
q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
|
| 458 |
+
q_pe = apply_rotary_emb(q_pe, freqs_cis)
|
| 459 |
+
q = torch.cat([q_pe, q_nope], dim=-1)
|
| 460 |
+
k = self.wk(x)
|
| 461 |
+
k = self.k_norm(k)
|
| 462 |
+
k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
|
| 463 |
+
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
|
| 464 |
+
k = torch.cat([k_pe, k_nope], dim=-1)
|
| 465 |
+
q = rotate_activation(q)
|
| 466 |
+
k = rotate_activation(k)
|
| 467 |
+
q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
|
| 468 |
+
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
|
| 469 |
+
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
|
| 470 |
+
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
|
| 471 |
+
weights = self.weights_proj(x) * self.n_heads ** -0.5
|
| 472 |
+
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
|
| 473 |
+
index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
|
| 474 |
+
if mask is not None:
|
| 475 |
+
index_score += mask
|
| 476 |
+
topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
|
| 477 |
+
topk_indices_ = topk_indices.clone()
|
| 478 |
+
dist.broadcast(topk_indices_, src=0)
|
| 479 |
+
assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
|
| 480 |
+
return topk_indices
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def weight_dequant(weight, scale):
|
| 484 |
+
shape = weight.shape
|
| 485 |
+
assert weight.dim() == 2
|
| 486 |
+
weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
|
| 487 |
+
weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(shape[0] // block_size, shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(shape)
|
| 488 |
+
return weight
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class MLA(nn.Module):
|
| 492 |
+
"""
|
| 493 |
+
Multi-Head Latent Attention (MLA) Layer.
|
| 494 |
+
|
| 495 |
+
Attributes:
|
| 496 |
+
dim (int): Dimensionality of the input features.
|
| 497 |
+
n_heads (int): Number of attention heads.
|
| 498 |
+
n_local_heads (int): Number of local attention heads for distributed systems.
|
| 499 |
+
q_lora_rank (int): Rank for low-rank query projection.
|
| 500 |
+
kv_lora_rank (int): Rank for low-rank key/value projection.
|
| 501 |
+
qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
|
| 502 |
+
qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
|
| 503 |
+
qk_head_dim (int): Total dimensionality of query/key projections.
|
| 504 |
+
v_head_dim (int): Dimensionality of value projections.
|
| 505 |
+
softmax_scale (float): Scaling factor for softmax in attention computation.
|
| 506 |
+
"""
|
| 507 |
+
def __init__(self, args: ModelArgs):
|
| 508 |
+
super().__init__()
|
| 509 |
+
self.dim = args.dim
|
| 510 |
+
self.n_heads = args.n_heads
|
| 511 |
+
self.n_local_heads = args.n_heads // world_size
|
| 512 |
+
self.q_lora_rank = args.q_lora_rank
|
| 513 |
+
self.kv_lora_rank = args.kv_lora_rank
|
| 514 |
+
self.qk_nope_head_dim = args.qk_nope_head_dim
|
| 515 |
+
self.qk_rope_head_dim = args.qk_rope_head_dim
|
| 516 |
+
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
|
| 517 |
+
self.v_head_dim = args.v_head_dim
|
| 518 |
+
|
| 519 |
+
self.wq_a = Linear(self.dim, self.q_lora_rank)
|
| 520 |
+
self.q_norm = RMSNorm(self.q_lora_rank)
|
| 521 |
+
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
|
| 522 |
+
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 523 |
+
self.kv_norm = RMSNorm(self.kv_lora_rank)
|
| 524 |
+
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
|
| 525 |
+
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
|
| 526 |
+
self.softmax_scale = self.qk_head_dim ** -0.5
|
| 527 |
+
if args.max_seq_len > args.original_seq_len:
|
| 528 |
+
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
|
| 529 |
+
self.softmax_scale = self.softmax_scale * mscale * mscale
|
| 530 |
+
|
| 531 |
+
self.indexer = Indexer(args)
|
| 532 |
+
|
| 533 |
+
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
|
| 534 |
+
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
|
| 535 |
+
self.dequant_wkv_b = None
|
| 536 |
+
|
| 537 |
+
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
| 538 |
+
"""
|
| 539 |
+
Forward pass for the Multi-Head Latent Attention (MLA) Layer.
|
| 540 |
+
|
| 541 |
+
Args:
|
| 542 |
+
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
|
| 543 |
+
start_pos (int): Starting position in the sequence for caching.
|
| 544 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
| 545 |
+
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
|
| 546 |
+
|
| 547 |
+
Returns:
|
| 548 |
+
torch.Tensor: Output tensor with the same shape as the input.
|
| 549 |
+
"""
|
| 550 |
+
bsz, seqlen, _ = x.size()
|
| 551 |
+
end_pos = start_pos + seqlen
|
| 552 |
+
qr = self.q_norm(self.wq_a(x))
|
| 553 |
+
q = self.wq_b(qr)
|
| 554 |
+
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
|
| 555 |
+
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 556 |
+
q_pe = apply_rotary_emb(q_pe, freqs_cis)
|
| 557 |
+
kv = self.wkv_a(x)
|
| 558 |
+
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 559 |
+
kv = self.kv_norm(kv)
|
| 560 |
+
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
|
| 561 |
+
self.kv_cache[:bsz, start_pos:end_pos] = kv
|
| 562 |
+
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
|
| 563 |
+
if mask is not None: # MHA prefill
|
| 564 |
+
q = torch.cat([q_nope, q_pe], dim=-1)
|
| 565 |
+
kv = self.wkv_b(kv)
|
| 566 |
+
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
|
| 567 |
+
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 568 |
+
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
|
| 569 |
+
scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
|
| 570 |
+
|
| 571 |
+
# indexer
|
| 572 |
+
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
|
| 573 |
+
index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
|
| 574 |
+
index_mask += mask
|
| 575 |
+
scores += index_mask.unsqueeze(2)
|
| 576 |
+
|
| 577 |
+
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
| 578 |
+
x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
|
| 579 |
+
else: # MHA decode
|
| 580 |
+
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
|
| 581 |
+
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
|
| 582 |
+
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
|
| 583 |
+
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
|
| 584 |
+
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
|
| 585 |
+
scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float()) +
|
| 586 |
+
torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
|
| 587 |
+
|
| 588 |
+
# indexer
|
| 589 |
+
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
|
| 590 |
+
index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
|
| 591 |
+
scores += index_mask.unsqueeze(2)
|
| 592 |
+
|
| 593 |
+
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
| 594 |
+
x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
|
| 595 |
+
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
|
| 596 |
+
x = self.wo(x.flatten(2))
|
| 597 |
+
return x
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class MLP(nn.Module):
|
| 601 |
+
"""
|
| 602 |
+
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
|
| 603 |
+
|
| 604 |
+
Attributes:
|
| 605 |
+
w1 (nn.Module): Linear layer for input-to-hidden transformation.
|
| 606 |
+
w2 (nn.Module): Linear layer for hidden-to-output transformation.
|
| 607 |
+
w3 (nn.Module): Additional linear layer for feature transformation.
|
| 608 |
+
"""
|
| 609 |
+
def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
|
| 610 |
+
"""
|
| 611 |
+
Initializes the MLP layer.
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
dim (int): Input and output dimensionality.
|
| 615 |
+
inter_dim (int): Hidden layer dimensionality.
|
| 616 |
+
"""
|
| 617 |
+
super().__init__()
|
| 618 |
+
self.w1 = ColumnParallelLinear(dim, inter_dim)
|
| 619 |
+
self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
|
| 620 |
+
self.w3 = ColumnParallelLinear(dim, inter_dim)
|
| 621 |
+
|
| 622 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 623 |
+
"""
|
| 624 |
+
Forward pass for the MLP layer.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
x (torch.Tensor): Input tensor.
|
| 628 |
+
|
| 629 |
+
Returns:
|
| 630 |
+
torch.Tensor: Output tensor after MLP computation.
|
| 631 |
+
"""
|
| 632 |
+
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class Gate(nn.Module):
|
| 636 |
+
"""
|
| 637 |
+
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
|
| 638 |
+
|
| 639 |
+
Attributes:
|
| 640 |
+
dim (int): Dimensionality of input features.
|
| 641 |
+
topk (int): Number of top experts activated for each input.
|
| 642 |
+
n_groups (int): Number of groups for routing.
|
| 643 |
+
topk_groups (int): Number of groups to route inputs to.
|
| 644 |
+
score_func (str): Scoring function ('softmax' or 'sigmoid').
|
| 645 |
+
route_scale (float): Scaling factor for routing weights.
|
| 646 |
+
weight (torch.nn.Parameter): Learnable weights for the gate.
|
| 647 |
+
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
|
| 648 |
+
"""
|
| 649 |
+
def __init__(self, args: ModelArgs):
|
| 650 |
+
"""
|
| 651 |
+
Initializes the Gate module.
|
| 652 |
+
|
| 653 |
+
Args:
|
| 654 |
+
args (ModelArgs): Model arguments containing gating parameters.
|
| 655 |
+
"""
|
| 656 |
+
super().__init__()
|
| 657 |
+
self.dim = args.dim
|
| 658 |
+
self.topk = args.n_activated_experts
|
| 659 |
+
self.n_groups = args.n_expert_groups
|
| 660 |
+
self.topk_groups = args.n_limited_groups
|
| 661 |
+
self.score_func = args.score_func
|
| 662 |
+
self.route_scale = args.route_scale
|
| 663 |
+
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
|
| 664 |
+
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
|
| 665 |
+
|
| 666 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 667 |
+
"""
|
| 668 |
+
Forward pass for the gating mechanism.
|
| 669 |
+
|
| 670 |
+
Args:
|
| 671 |
+
x (torch.Tensor): Input tensor.
|
| 672 |
+
|
| 673 |
+
Returns:
|
| 674 |
+
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
|
| 675 |
+
"""
|
| 676 |
+
scores = linear(x.float(), self.weight.float())
|
| 677 |
+
if self.score_func == "softmax":
|
| 678 |
+
scores = scores.softmax(dim=-1)
|
| 679 |
+
else:
|
| 680 |
+
scores = scores.sigmoid()
|
| 681 |
+
original_scores = scores
|
| 682 |
+
if self.bias is not None:
|
| 683 |
+
scores = scores + self.bias
|
| 684 |
+
if self.n_groups > 1:
|
| 685 |
+
scores = scores.view(x.size(0), self.n_groups, -1)
|
| 686 |
+
if self.bias is None:
|
| 687 |
+
group_scores = scores.amax(dim=-1)
|
| 688 |
+
else:
|
| 689 |
+
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
|
| 690 |
+
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
|
| 691 |
+
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
|
| 692 |
+
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
|
| 693 |
+
indices = scores.topk(self.topk, dim=-1)[1]
|
| 694 |
+
weights = original_scores.gather(1, indices)
|
| 695 |
+
if self.score_func == "sigmoid":
|
| 696 |
+
weights /= weights.sum(dim=-1, keepdim=True)
|
| 697 |
+
weights *= self.route_scale
|
| 698 |
+
return weights, indices
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
class Expert(nn.Module):
|
| 702 |
+
"""
|
| 703 |
+
Expert layer for Mixture-of-Experts (MoE) models.
|
| 704 |
+
|
| 705 |
+
Attributes:
|
| 706 |
+
w1 (nn.Module): Linear layer for input-to-hidden transformation.
|
| 707 |
+
w2 (nn.Module): Linear layer for hidden-to-output transformation.
|
| 708 |
+
w3 (nn.Module): Additional linear layer for feature transformation.
|
| 709 |
+
"""
|
| 710 |
+
def __init__(self, dim: int, inter_dim: int):
|
| 711 |
+
"""
|
| 712 |
+
Initializes the Expert layer.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
dim (int): Input and output dimensionality.
|
| 716 |
+
inter_dim (int): Hidden layer dimensionality.
|
| 717 |
+
"""
|
| 718 |
+
super().__init__()
|
| 719 |
+
self.w1 = Linear(dim, inter_dim)
|
| 720 |
+
self.w2 = Linear(inter_dim, dim)
|
| 721 |
+
self.w3 = Linear(dim, inter_dim)
|
| 722 |
+
|
| 723 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 724 |
+
"""
|
| 725 |
+
Forward pass for the Expert layer.
|
| 726 |
+
|
| 727 |
+
Args:
|
| 728 |
+
x (torch.Tensor): Input tensor.
|
| 729 |
+
|
| 730 |
+
Returns:
|
| 731 |
+
torch.Tensor: Output tensor after expert computation.
|
| 732 |
+
"""
|
| 733 |
+
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
class MoE(nn.Module):
|
| 737 |
+
"""
|
| 738 |
+
Mixture-of-Experts (MoE) module.
|
| 739 |
+
|
| 740 |
+
Attributes:
|
| 741 |
+
dim (int): Dimensionality of input features.
|
| 742 |
+
n_routed_experts (int): Total number of experts in the model.
|
| 743 |
+
n_local_experts (int): Number of experts handled locally in distributed systems.
|
| 744 |
+
n_activated_experts (int): Number of experts activated for each input.
|
| 745 |
+
gate (nn.Module): Gating mechanism to route inputs to experts.
|
| 746 |
+
experts (nn.ModuleList): List of expert modules.
|
| 747 |
+
shared_experts (nn.Module): Shared experts applied to all inputs.
|
| 748 |
+
"""
|
| 749 |
+
def __init__(self, args: ModelArgs):
|
| 750 |
+
"""
|
| 751 |
+
Initializes the MoE module.
|
| 752 |
+
|
| 753 |
+
Args:
|
| 754 |
+
args (ModelArgs): Model arguments containing MoE parameters.
|
| 755 |
+
"""
|
| 756 |
+
super().__init__()
|
| 757 |
+
self.dim = args.dim
|
| 758 |
+
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
|
| 759 |
+
self.n_routed_experts = args.n_routed_experts
|
| 760 |
+
self.n_local_experts = args.n_routed_experts // world_size
|
| 761 |
+
self.n_activated_experts = args.n_activated_experts
|
| 762 |
+
self.experts_start_idx = rank * self.n_local_experts
|
| 763 |
+
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
|
| 764 |
+
self.gate = Gate(args)
|
| 765 |
+
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
|
| 766 |
+
for i in range(self.n_routed_experts)])
|
| 767 |
+
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)
|
| 768 |
+
|
| 769 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 770 |
+
"""
|
| 771 |
+
Forward pass for the MoE module.
|
| 772 |
+
|
| 773 |
+
Args:
|
| 774 |
+
x (torch.Tensor): Input tensor.
|
| 775 |
+
|
| 776 |
+
Returns:
|
| 777 |
+
torch.Tensor: Output tensor after expert routing and computation.
|
| 778 |
+
"""
|
| 779 |
+
shape = x.size()
|
| 780 |
+
x = x.view(-1, self.dim)
|
| 781 |
+
weights, indices = self.gate(x)
|
| 782 |
+
y = torch.zeros_like(x, dtype=torch.float32)
|
| 783 |
+
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
|
| 784 |
+
for i in range(self.experts_start_idx, self.experts_end_idx):
|
| 785 |
+
if counts[i] == 0:
|
| 786 |
+
continue
|
| 787 |
+
expert = self.experts[i]
|
| 788 |
+
idx, top = torch.where(indices == i)
|
| 789 |
+
y[idx] += expert(x[idx]) * weights[idx, top, None]
|
| 790 |
+
y += self.shared_experts(x)
|
| 791 |
+
if world_size > 1:
|
| 792 |
+
dist.all_reduce(y)
|
| 793 |
+
return y.type_as(x).view(shape)
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
class Block(nn.Module):
|
| 797 |
+
"""
|
| 798 |
+
Transformer block combining attention and feed-forward layers.
|
| 799 |
+
|
| 800 |
+
Attributes:
|
| 801 |
+
attn (nn.Module): Attention layer (MLA).
|
| 802 |
+
ffn (nn.Module): Feed-forward network (MLP or MoE).
|
| 803 |
+
attn_norm (nn.Module): Layer normalization for attention.
|
| 804 |
+
ffn_norm (nn.Module): Layer normalization for feed-forward network.
|
| 805 |
+
"""
|
| 806 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
| 807 |
+
"""
|
| 808 |
+
Initializes the Transformer block.
|
| 809 |
+
|
| 810 |
+
Args:
|
| 811 |
+
layer_id (int): Layer index in the transformer.
|
| 812 |
+
args (ModelArgs): Model arguments containing block parameters.
|
| 813 |
+
"""
|
| 814 |
+
super().__init__()
|
| 815 |
+
self.attn = MLA(args)
|
| 816 |
+
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
|
| 817 |
+
self.attn_norm = RMSNorm(args.dim)
|
| 818 |
+
self.ffn_norm = RMSNorm(args.dim)
|
| 819 |
+
|
| 820 |
+
def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
|
| 821 |
+
"""
|
| 822 |
+
Forward pass for the Transformer block.
|
| 823 |
+
|
| 824 |
+
Args:
|
| 825 |
+
x (torch.Tensor): Input tensor.
|
| 826 |
+
start_pos (int): Starting position in the sequence.
|
| 827 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
| 828 |
+
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
|
| 829 |
+
|
| 830 |
+
Returns:
|
| 831 |
+
torch.Tensor: Output tensor after block computation.
|
| 832 |
+
"""
|
| 833 |
+
if residual is None:
|
| 834 |
+
x, residual = self.attn_norm(x), x
|
| 835 |
+
else:
|
| 836 |
+
x, residual = self.attn_norm(x, residual)
|
| 837 |
+
x = self.attn(x, start_pos, freqs_cis, mask)
|
| 838 |
+
x, residual = self.ffn_norm(x, residual)
|
| 839 |
+
x = self.ffn(x)
|
| 840 |
+
return x, residual
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
class Transformer(nn.Module):
|
| 844 |
+
"""
|
| 845 |
+
Transformer model with positional embeddings, multiple layers, and output projection.
|
| 846 |
+
|
| 847 |
+
Attributes:
|
| 848 |
+
max_seq_len (int): Maximum sequence length for the transformer.
|
| 849 |
+
embed (nn.Module): Embedding layer for input tokens.
|
| 850 |
+
layers (torch.nn.ModuleList): List of transformer blocks.
|
| 851 |
+
norm (nn.Module): Layer normalization applied after all blocks.
|
| 852 |
+
head (nn.Module): Output projection layer mapping to vocabulary size.
|
| 853 |
+
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
| 854 |
+
"""
|
| 855 |
+
def __init__(self, args: ModelArgs):
|
| 856 |
+
"""
|
| 857 |
+
Initializes the Transformer model.
|
| 858 |
+
|
| 859 |
+
Args:
|
| 860 |
+
args (ModelArgs): Model arguments containing transformer parameters.
|
| 861 |
+
"""
|
| 862 |
+
global world_size, rank
|
| 863 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
| 864 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
| 865 |
+
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
|
| 866 |
+
Linear.scale_fmt = args.scale_fmt
|
| 867 |
+
super().__init__()
|
| 868 |
+
self.max_seq_len = args.max_seq_len
|
| 869 |
+
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
|
| 870 |
+
self.layers = torch.nn.ModuleList()
|
| 871 |
+
for layer_id in range(args.n_layers):
|
| 872 |
+
self.layers.append(Block(layer_id, args))
|
| 873 |
+
self.norm = RMSNorm(args.dim)
|
| 874 |
+
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
|
| 875 |
+
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
|
| 876 |
+
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
|
| 877 |
+
|
| 878 |
+
@torch.inference_mode()
|
| 879 |
+
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
|
| 880 |
+
"""
|
| 881 |
+
Forward pass for the Transformer model.
|
| 882 |
+
|
| 883 |
+
Args:
|
| 884 |
+
tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
|
| 885 |
+
start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
|
| 886 |
+
|
| 887 |
+
Returns:
|
| 888 |
+
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
|
| 889 |
+
"""
|
| 890 |
+
seqlen = tokens.size(1)
|
| 891 |
+
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
| 892 |
+
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
|
| 893 |
+
h, residual = self.embed(tokens), None
|
| 894 |
+
for layer in self.layers:
|
| 895 |
+
h, residual = layer(h, residual, start_pos, freqs_cis, mask)
|
| 896 |
+
h, _ = self.norm(h, residual)
|
| 897 |
+
logits = self.head(h[:, -1].float())
|
| 898 |
+
if world_size > 1:
|
| 899 |
+
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
|
| 900 |
+
dist.all_gather(all_logits, logits)
|
| 901 |
+
logits = torch.cat(all_logits, dim=-1)
|
| 902 |
+
return logits
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
if __name__ == "__main__":
|
| 906 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 907 |
+
torch.set_default_device("cuda")
|
| 908 |
+
torch.manual_seed(0)
|
| 909 |
+
args = ModelArgs()
|
| 910 |
+
x = torch.randint(0, args.vocab_size, (2, 128))
|
| 911 |
+
model = Transformer(args)
|
| 912 |
+
print(model(x).size())
|
inference/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
safetensors
|
| 4 |
+
fast_hadamard_transform
|
| 5 |
+
tilelang==0.1.6
|