import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


def pack_sign_bits(sign_tensor: torch.Tensor) -> torch.Tensor:
    """Pack a tensor of {-1, +1} signs into uint8 bytes, 8 signs per byte."""
    sign_flat = sign_tensor.flatten()
    # Bit value 1 encodes -1 and 0 encodes +1, matching the decoding
    # (bit -> 1 - 2*bit) used by the unpacking routines below.
    sign_uint8 = (sign_flat == -1).to(torch.uint8)

    # Pad with zero bits so the total is a multiple of 8.
    remainder = sign_uint8.numel() % 8
    if remainder != 0:
        padding = 8 - remainder
        sign_uint8 = torch.cat([
            sign_uint8,
            torch.zeros(padding, dtype=torch.uint8, device=sign_uint8.device)
        ])

    # Pack LSB-first (element i of each group of 8 lands in bit i), the same
    # bit order the unpacking routines read back with `(byte >> i) & 1`.
    sign_uint8 = sign_uint8.reshape(-1, 8)
    shifts = torch.arange(8, device=sign_uint8.device, dtype=torch.uint8)
    packed = (sign_uint8 << shifts.unsqueeze(0)).sum(dim=1).to(torch.uint8)
    return packed


def unpack_sign_bits_ultra_fast(packed: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
    """Unpack sign bytes back into a float16 tensor of {-1, +1} values."""
    device = packed.device
    dtype = torch.float16

    int8_tensor = packed.to(torch.int8)

    # Extract all 8 bits of every byte with one vectorized shift, LSB-first.
    shifts = torch.arange(8, device=device).view(1, 8)
    expanded_int8 = int8_tensor.unsqueeze(-1)
    unpacked_bits = ((expanded_int8 >> shifts) & 1).to(dtype)
    unpacked_bits = unpacked_bits.view(int8_tensor.shape[0], -1)

    # Map bit 0 -> +1.0 and bit 1 -> -1.0.
    fp16_tensor = -2 * unpacked_bits + 1

    # Drop any padding bits and restore the caller's shape.
    if isinstance(original_shape, (tuple, list)):
        total_elements = 1
        for dim in original_shape:
            total_elements *= dim
        original_shape = torch.Size(original_shape)
    else:
        total_elements = original_shape.numel()

    return fp16_tensor.flatten()[:total_elements].reshape(original_shape)


def unpack_sign_bits(packed: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
    """Thin alias kept for callers that use the short name."""
    return unpack_sign_bits_ultra_fast(packed, original_shape)
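

def _demo_pack_roundtrip():
    # Illustrative round-trip sanity check (not part of the original module).
    # It assumes the LSB-first, bit-1-means-negative convention used by
    # pack_sign_bits and the unpack routines above.
    signs = torch.sign(torch.randn(4, 5))  # 20 elements, so padding is exercised
    signs[signs == 0] = 1.0                # torch.sign may emit 0; force to +1
    packed = pack_sign_bits(signs)
    restored = unpack_sign_bits(packed, signs.shape)
    assert torch.equal(restored.to(signs.dtype), signs)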


class OneBitLinear(nn.Module):
    """Linear layer with 1-bit sign weights and per-channel fp16 scales,
    followed by a parameter-free LayerNorm."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 a_scale: Optional[torch.Tensor] = None,
                 b_scale: Optional[torch.Tensor] = None,
                 weight_packed: Optional[torch.Tensor] = None,
                 bias: Optional[torch.Tensor] = None,
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Packed sign weights: each int8 byte holds 8 sign bits, so the
        # stored matrix is (out_features, in_features // 8).
        if weight_packed is not None:
            expected_size = out_features * in_features // 8
            if weight_packed.numel() == expected_size:
                weight_2d = weight_packed.view(out_features, in_features // 8).to(torch.int8)
            else:
                # Unexpected size: fall back to zeros rather than crash.
                weight_2d = torch.zeros((out_features, in_features // 8), dtype=torch.int8, device=device)
            self.register_buffer("weight", weight_2d, persistent=False)
        else:
            self.register_buffer("weight", torch.zeros((out_features, in_features // 8),
                                                       dtype=torch.int8, device=device), persistent=False)

        # Per-input-channel scale applied to activations before the matmul.
        if a_scale is not None:
            self.register_buffer("input_factor", a_scale.to(torch.float16))
        else:
            self.register_buffer("input_factor", torch.ones(in_features, dtype=torch.float16, device=device))

        # Per-output-channel scale applied to the matmul result.
        if b_scale is not None:
            self.register_buffer("weight_scale", b_scale.to(torch.float16))
        else:
            self.register_buffer("weight_scale", torch.ones(out_features, dtype=torch.float16, device=device))

        if bias is not None:
            self.register_buffer("bias", bias.to(torch.float16))
        else:
            self.bias = None

        # Normalization after the scaled linear map (no learnable affine).
        self.layernorm = nn.LayerNorm(out_features, elementwise_affine=False, **factory_kwargs)

        # Lazily populated cache of the unpacked fp16 weight matrix.
        self._weight_cache = None

    def int8_to_fp16(self, int8_tensor):
        """Unpack the packed sign weight into a dense {-1, +1} fp16 matrix."""
        dtype = self.weight_scale.dtype
        # LSB-first bit extraction, matching unpack_sign_bits_ultra_fast.
        shifts = torch.arange(8, device=int8_tensor.device).view(1, 1, 8)
        expanded_int8 = int8_tensor.unsqueeze(-1)
        unpacked_bits = ((expanded_int8 >> shifts) & 1).to(dtype)
        unpacked_bits = unpacked_bits.view(int8_tensor.shape[0], -1)
        # Bit 0 -> +1.0, bit 1 -> -1.0.
        fp16_tensor = -2 * unpacked_bits + 1
        return fp16_tensor

    def forward(self, input):
        # Scale each input channel by its calibration factor.
        input_factor_shape = [1] * len(input.shape)
        input_factor_shape[-1] = self.in_features
        input = input * self.input_factor.view(*input_factor_shape)

        # Unpack the sign weights on the first call and reuse them afterwards.
        if self._weight_cache is not None:
            weight = self._weight_cache
        else:
            weight = self.int8_to_fp16(self.weight)
            self._weight_cache = weight

        output = F.linear(input, weight)

        # Rescale each output channel.
        weight_scale_shape = [1] * len(output.shape)
        weight_scale_shape[-1] = self.out_features
        output *= self.weight_scale.view(*weight_scale_shape)

        output = self.layernorm(output)

        if self.bias is not None:
            output += self.bias

        return output
    @classmethod
    def from_safetensors(cls, state_dict: dict, layer_idx: int, module_name: str):
        """Build a OneBitLinear from a loaded state dict. Accepts both the
        current key names (input_factor/weight_scale/weight) and the legacy
        ones (a_scale/b_scale/sign_packed); returns None if either scale is
        missing."""
        prefix = f"model.layers.{layer_idx}.{module_name}"

        input_factor_key = f"{prefix}.input_factor"
        weight_scale_key = f"{prefix}.weight_scale"
        weight_key = f"{prefix}.weight"
        bias_key = f"{prefix}.bias"

        input_factor = None
        if input_factor_key in state_dict:
            input_factor = state_dict[input_factor_key]
        elif f"{prefix}.a_scale" in state_dict:
            input_factor = state_dict[f"{prefix}.a_scale"]

        weight_scale = None
        if weight_scale_key in state_dict:
            weight_scale = state_dict[weight_scale_key]
        elif f"{prefix}.b_scale" in state_dict:
            weight_scale = state_dict[f"{prefix}.b_scale"]

        weight_packed = None
        if weight_key in state_dict:
            weight_packed = state_dict[weight_key]
        elif f"{prefix}.sign_packed" in state_dict:
            weight_packed = state_dict[f"{prefix}.sign_packed"]

        bias = state_dict.get(bias_key)

        if input_factor is None or weight_scale is None:
            return None

        # The per-channel scale vectors determine the layer dimensions.
        in_features = input_factor.shape[0]
        out_features = weight_scale.shape[0]

        return cls(
            in_features=in_features,
            out_features=out_features,
            a_scale=input_factor,
            b_scale=weight_scale,
            weight_packed=weight_packed,
            bias=bias
        )
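

def _demo_onebit_linear():
    # Illustrative usage sketch (not part of the original module). Random
    # signs and scales stand in for real calibrated values; fp16 matmul and
    # LayerNorm may require a recent PyTorch build or a CUDA device.
    in_features, out_features = 128, 64
    signs = torch.sign(torch.randn(out_features, in_features))
    signs[signs == 0] = 1.0
    layer = OneBitLinear(in_features, out_features,
                         a_scale=torch.rand(in_features),
                         b_scale=torch.rand(out_features),
                         weight_packed=pack_sign_bits(signs))
    x = torch.randn(2, in_features, dtype=torch.float16)
    print(layer(x).shape)  # torch.Size([2, 64])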
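

def _demo_from_safetensors(path: str = "model.safetensors"):
    # Illustrative loading sketch (not part of the original module). The file
    # name and the module name below are assumptions about the checkpoint
    # layout; load_file is the standard safetensors API.
    from safetensors.torch import load_file
    state_dict = load_file(path)
    layer = OneBitLinear.from_safetensors(state_dict, layer_idx=0,
                                          module_name="self_attn.q_proj")
    return layer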