import argparse
import json
import shutil
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file

# Command-line interface: the source and destination directories are both mandatory.
parser = argparse.ArgumentParser(
    description="Convert original dbrx model into quantizable model",
)
for flag, help_text in (
    ("--model-dir", "directory to the original dbrx model"),
    ("--output-dir", "directory for the converted dbrx model"),
):
    parser.add_argument(flag, type=str, required=True, help=help_text)
args = parser.parse_args()

model_dir = Path(args.model_dir)
output_dir = Path(args.output_dir)
# Create the destination (and any missing parents); an existing directory is reused.
output_dir.mkdir(parents=True, exist_ok=True)

# Hard-coded dimensions of the checkpoint being converted.
# NOTE(review): presumably these mirror the DBRX config.json values — confirm
# before pointing this script at a different checkpoint.

# Attention: the fused Wqkv weight is split into HIDDEN_SIZE rows for Q and
# NUM_KV_HEAD * HEAD_DIM rows each for K and V.
HIDDEN_SIZE = 6144
NUM_KV_HEAD = 8
HEAD_DIM = 128

# Experts: each fused expert tensor is reshaped to
# (NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE) and unpacked per expert.
NUM_EXPERTS = 16
FFN_HIDDEN_SIZE = 10752


def change_tensor_attn(tensor):
    """Split a fused QKV weight into separate Q, K and V tensors.

    The tensor is cut along its first dimension into chunks of
    ``HIDDEN_SIZE``, ``NUM_KV_HEAD * HEAD_DIM`` and ``NUM_KV_HEAD * HEAD_DIM``
    rows, in that order.

    Args:
        tensor: The fused weight. Assumed to be a torch tensor whose first
            dimension equals the sum of the three chunk sizes — TODO confirm
            against the checkpoint.

    Returns:
        A list ``[q, k, v]`` of contiguous tensors.
    """
    kv_rows = NUM_KV_HEAD * HEAD_DIM
    pieces = tensor.split([HIDDEN_SIZE, kv_rows, kv_rows])
    # split() returns views; make each piece contiguous on its own.
    return [piece.contiguous() for piece in pieces]


def change_attn(tensors):
    """Replace every fused ``Wqkv`` entry with q_proj/k_proj/v_proj entries.

    The mapping is modified in place: each key containing ``'Wqkv'`` is removed
    and three ``<prefix>.{q,k,v}_proj.weight`` keys are added, where
    ``<prefix>`` is the key with a trailing ``'.Wqkv.weight'`` stripped.

    Args:
        tensors: Dict mapping tensor names to tensors.

    Returns:
        The same dict object that was passed in.
    """
    proj_names = ('q_proj', 'k_proj', 'v_proj')
    # Iterate over a snapshot because the dict is mutated inside the loop.
    for name in list(tensors):
        if 'Wqkv' not in name:
            continue
        stem = name.removesuffix('.Wqkv.weight')
        parts = change_tensor_attn(tensors.pop(name))
        for proj, part in zip(proj_names, parts):
            tensors[f'{stem}.{proj}.weight'] = part
    return tensors


def change_tensor_mlp(tensor, reverse=False):
    """Unpack a fused expert weight into one contiguous matrix per expert.

    The tensor is reshaped to ``(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)``
    and iterated along the first dimension.

    Args:
        tensor: The fused expert weight; must hold exactly
            ``NUM_EXPERTS * FFN_HIDDEN_SIZE * HIDDEN_SIZE`` elements for the
            reshape to succeed.
        reverse: When true, each per-expert matrix is transposed before being
            made contiguous.

    Returns:
        A list of ``NUM_EXPERTS`` contiguous tensors.
    """
    per_expert = tensor.reshape(NUM_EXPERTS, FFN_HIDDEN_SIZE, HIDDEN_SIZE)
    if reverse:
        return [expert.t().contiguous() for expert in per_expert]
    return [expert.contiguous() for expert in per_expert]


def change_mlp(tensors):
    """Replace every fused expert entry with one entry per expert.

    The mapping is modified in place. A key is treated as a fused expert
    weight when it contains ``'w1'``, ``'v1'`` or ``'w2'``. It is split at its
    last ``'.'`` into ``<prefix>`` and ``<kind>``, and the unpacked matrices
    are stored as ``<prefix>.<i>.<kind>.weight``. Matrices are transposed only
    when ``<kind>`` is exactly ``'w2'``.

    Args:
        tensors: Dict mapping tensor names to tensors.

    Returns:
        The same dict object that was passed in.
    """
    expert_tags = ('w1', 'v1', 'w2')
    # Iterate over a snapshot because the dict is mutated inside the loop.
    for name in list(tensors):
        if not any(tag in name for tag in expert_tags):
            continue
        stem, kind = name.rsplit('.', 1)
        matrices = change_tensor_mlp(tensors.pop(name), kind == 'w2')
        for idx, matrix in enumerate(matrices):
            tensors[f'{stem}.{idx}.{kind}.weight'] = matrix
    return tensors


# Rewrite every shard found in the model directory, keeping each shard's file
# name and metadata so the shard index only needs its tensor names updated.
for shard_path in sorted(model_dir.glob('*.safetensors')):
    print(shard_path)
    with safe_open(shard_path, 'pt') as shard:
        metadata = shard.metadata()
        tensors = {name: shard.get_tensor(name) for name in shard.keys()}
    # Split fused attention weights first, then the fused expert weights.
    tensors = change_mlp(change_attn(tensors))
    save_file(tensors, (output_dir / shard_path.name).as_posix(), metadata)

# Mirror the tensor renames in the shard index: every new name points at the
# shard file that held the fused tensor it was derived from.
with open(model_dir / 'model.safetensors.index.json') as f:
    weight_map = json.load(f)

name_to_shard = weight_map['weight_map']
# Iterate over a snapshot because entries are removed and added in the loop.
for name in list(name_to_shard):
    if any(tag in name for tag in ('w1', 'v1', 'w2')):
        stem, kind = name.rsplit('.', 1)
        shard_file = name_to_shard.pop(name)
        for idx in range(NUM_EXPERTS):
            name_to_shard[f'{stem}.{idx}.{kind}.weight'] = shard_file
    elif 'Wqkv' in name:
        stem = name.removesuffix('.Wqkv.weight')
        shard_file = name_to_shard.pop(name)
        for proj in ('q_proj', 'k_proj', 'v_proj'):
            name_to_shard[f'{stem}.{proj}.weight'] = shard_file

# Store the entries sorted by tensor name.
weight_map['weight_map'] = dict(sorted(name_to_shard.items()))

with open(output_dir / 'model.safetensors.index.json', 'w') as f:
    json.dump(weight_map, f, indent=4)

# Copy every other regular file from the model directory into the output
# directory. Weight shards and the shard index are skipped because they were
# already written above in converted form. Subdirectories are not copied.
# pathlib is used for listing and type checks, consistent with the rest of the
# script; shutil.copy2 also preserves file metadata where the platform allows.
for src in model_dir.iterdir():
    if src.name.endswith(".safetensors") or src.name == "model.safetensors.index.json":
        continue
    if src.is_file():
        shutil.copy2(src, output_dir / src.name)