# qep-1bit-extreme / onebit_linear.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
def pack_sign_bits(sign_tensor: torch.Tensor) -> torch.Tensor:
    """Pack a {+1, -1} sign tensor into uint8 bytes, 8 signs per byte.

    Convention matches the unpackers in this module: bit 1 encodes -1 and
    bit 0 encodes +1, least-significant bit first within each byte, so that
    sign = -2 * bit + 1 on unpack.
    """
    sign_flat = sign_tensor.flatten()
    # Encode -1 as bit 1 and +1 as bit 0 (inverse of unpack_sign_bits).
    sign_uint8 = (sign_flat == -1).to(torch.uint8)
    remainder = sign_uint8.numel() % 8
    if remainder != 0:
        # Zero-pad to a multiple of 8; padding decodes to +1 and is dropped
        # on unpack via the original shape.
        padding = 8 - remainder
        sign_uint8 = torch.cat([
            sign_uint8,
            torch.zeros(padding, dtype=torch.uint8, device=sign_uint8.device)
        ])
    sign_uint8 = sign_uint8.reshape(-1, 8)
    # LSB-first bit order, mirroring the shifts used when unpacking.
    shifts = torch.arange(8, device=sign_uint8.device, dtype=torch.uint8)
    packed = (sign_uint8 << shifts.unsqueeze(0)).sum(dim=1).to(torch.uint8)
    return packed
def unpack_sign_bits_ultra_fast(packed: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
    """Unpack uint8-packed sign bits back into a float16 {+1, -1} tensor."""
    device = packed.device
    dtype = torch.float16
    int8_tensor = packed.to(torch.int8)
    # Extract all 8 bits of every byte in one vectorized shift-and-mask.
    shifts = torch.arange(8, device=device).view(1, 8)
    expanded_int8 = int8_tensor.unsqueeze(-1)
    unpacked_bits = ((expanded_int8 >> shifts) & 1).to(dtype)
    unpacked_bits = unpacked_bits.view(int8_tensor.shape[0], -1)
    # Map bit 1 -> -1.0 and bit 0 -> +1.0.
    fp16_tensor = -2 * unpacked_bits + 1
    original_shape = torch.Size(original_shape)
    total_elements = original_shape.numel()
    # Drop any padding bits that were added during packing.
    return fp16_tensor.flatten()[:total_elements].reshape(original_shape)
def unpack_sign_bits(packed: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
    """Public alias for the vectorized unpacker."""
    return unpack_sign_bits_ultra_fast(packed, original_shape)
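
# Illustrative sanity check (not part of the original module): verifies that
# pack_sign_bits and unpack_sign_bits round-trip a random sign tensor. The
# helper name and the use of torch.testing.assert_close are this sketch's
# choices, not an upstream API.
def _check_pack_unpack_roundtrip() -> None:
    # Random {+1, -1} matrix whose element count is a multiple of 8.
    signs = (torch.randint(0, 2, (4, 24)) * 2 - 1).to(torch.float16)
    packed = pack_sign_bits(signs)
    restored = unpack_sign_bits(packed, signs.shape)
    torch.testing.assert_close(restored, signs)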
class OneBitLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 a_scale: Optional[torch.Tensor] = None,
                 b_scale: Optional[torch.Tensor] = None,
                 weight_packed: Optional[torch.Tensor] = None,
                 bias: Optional[torch.Tensor] = None,
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Packed sign weights: one bit per weight, 8 weights per int8 byte.
        # Note: these buffers have a fixed dtype, so only the device is
        # forwarded here (passing **factory_kwargs alongside an explicit
        # dtype= would raise a duplicate-keyword TypeError).
        if weight_packed is not None:
            expected_size = out_features * in_features // 8
            if weight_packed.numel() == expected_size:
                weight_2d = weight_packed.view(out_features, in_features // 8).to(torch.int8)
            else:
                # Size mismatch: fall back to zero-initialized packed weights.
                weight_2d = torch.zeros((out_features, in_features // 8), dtype=torch.int8, device=device)
            self.register_buffer("weight", weight_2d, persistent=False)
        else:
            self.register_buffer("weight", torch.zeros((out_features, in_features // 8),
                                                       dtype=torch.int8, device=device), persistent=False)
        # Per-input-channel scale applied to activations before the matmul.
        if a_scale is not None:
            self.register_buffer("input_factor", a_scale.to(torch.float16))
        else:
            self.register_buffer("input_factor", torch.ones(in_features, dtype=torch.float16, device=device))
        # Per-output-channel scale applied after the matmul.
        if b_scale is not None:
            self.register_buffer("weight_scale", b_scale.to(torch.float16))
        else:
            self.register_buffer("weight_scale", torch.ones(out_features, dtype=torch.float16, device=device))
        # Bias is stored as a buffer when present.
        if bias is not None:
            self.register_buffer("bias", bias.to(torch.float16))
        else:
            self.bias = None
        self.layernorm = nn.LayerNorm(out_features, elementwise_affine=False, **factory_kwargs)
        # Cache for the unpacked fp16 weight matrix; the packed weights never
        # change after construction in this inference-only module.
        self._weight_cache = None
    def int8_to_fp16(self, int8_tensor):
        """Unpack the (out_features, in_features // 8) packed weight into a
        full (out_features, in_features) {+1, -1} matrix."""
        dtype = self.weight_scale.dtype
        shifts = torch.arange(8, device=int8_tensor.device).view(1, 1, 8)
        expanded_int8 = int8_tensor.unsqueeze(-1)
        unpacked_bits = ((expanded_int8 >> shifts) & 1).to(dtype)
        unpacked_bits = unpacked_bits.view(int8_tensor.shape[0], -1)
        # Map bit 1 -> -1.0 and bit 0 -> +1.0, matching unpack_sign_bits.
        fp16_tensor = -2 * unpacked_bits + 1
        return fp16_tensor
    def forward(self, input):
        # Scale activations per input channel.
        input_factor_shape = [1] * len(input.shape)
        input_factor_shape[-1] = self.in_features
        input = input * self.input_factor.view(*input_factor_shape)
        # Unpack the sign weights once and reuse them on later calls.
        if self._weight_cache is not None:
            weight = self._weight_cache
        else:
            weight = self.int8_to_fp16(self.weight)
            self._weight_cache = weight
        output = F.linear(input, weight)
        # Scale per output channel, then normalize; bias (if any) is added
        # after the LayerNorm.
        weight_scale_shape = [1] * len(output.shape)
        weight_scale_shape[-1] = self.out_features
        output *= self.weight_scale.view(*weight_scale_shape)
        output = self.layernorm(output)
        if self.bias is not None:
            output += self.bias
        return output
    @classmethod
    def from_safetensors(cls, state_dict: dict, layer_idx: int, module_name: str):
        """Build a OneBitLinear from a loaded state dict, accepting both the
        current key names (input_factor / weight_scale / weight) and the
        legacy aliases (a_scale / b_scale / sign_packed). Returns None if the
        required scale tensors are missing."""
        prefix = f"model.layers.{layer_idx}.{module_name}"
        input_factor_key = f"{prefix}.input_factor"
        weight_scale_key = f"{prefix}.weight_scale"
        weight_key = f"{prefix}.weight"
        bias_key = f"{prefix}.bias"
        input_factor = None
        if input_factor_key in state_dict:
            input_factor = state_dict[input_factor_key]
        elif f"{prefix}.a_scale" in state_dict:
            input_factor = state_dict[f"{prefix}.a_scale"]
        weight_scale = None
        if weight_scale_key in state_dict:
            weight_scale = state_dict[weight_scale_key]
        elif f"{prefix}.b_scale" in state_dict:
            weight_scale = state_dict[f"{prefix}.b_scale"]
        weight_packed = None
        if weight_key in state_dict:
            weight_packed = state_dict[weight_key]
        elif f"{prefix}.sign_packed" in state_dict:
            weight_packed = state_dict[f"{prefix}.sign_packed"]
        bias = state_dict.get(bias_key)
        if input_factor is None or weight_scale is None:
            return None
        # The scale vectors determine the layer geometry.
        in_features = input_factor.shape[0]
        out_features = weight_scale.shape[0]
        return cls(
            in_features=in_features,
            out_features=out_features,
            a_scale=input_factor,
            b_scale=weight_scale,
            weight_packed=weight_packed,
            bias=bias
        )
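
# Minimal usage sketch (illustrative, not part of the original module): builds
# a small OneBitLinear from randomly packed sign weights and runs a forward
# pass. The shapes and scales here are arbitrary assumptions for the demo.
if __name__ == "__main__":
    # fp16 matmul on CPU requires a recent PyTorch; prefer GPU when available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    in_features, out_features = 64, 32
    # Random packed weights: one uint8 byte per 8 sign bits.
    packed = torch.randint(0, 256, (out_features * in_features // 8,), dtype=torch.uint8)
    layer = OneBitLinear(
        in_features=in_features,
        out_features=out_features,
        weight_packed=packed,
    ).to(device)
    x = torch.randn(2, in_features, dtype=torch.float16, device=device)
    y = layer(x)
    print(y.shape)  # torch.Size([2, 32])
    _check_pack_unpack_roundtrip()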