| import math | |
| import argparse | |
| import torch | |
| import random | |
| from eval_utils import get_test_dataset | |
| from .modeling_bitnet import BitnetForCausalLM | |
| from .tokenization_bitnet import BitnetTokenizer | |
| from tqdm import tqdm | |
# This script only runs forward passes, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Command-line interface: one (flag, default, type) row per option.
parser = argparse.ArgumentParser()
for _flag, _default, _type in (
    ('--seed', 0, int),
    ('--hf_path', '1bitLLM/bitnet_b1_58-3B', str),
    ('--seqlen', 2048, int),
):
    parser.add_argument(_flag, default=_default, type=_type)
# Do not leave the loop variables behind in the module namespace.
del _flag, _default, _type
def calulate_loss(model, input, loss_fct):
    """Return the next-token prediction loss of ``model`` on ``input``.

    Args:
        model: Callable invoked as ``model(input, use_cache=False,
            output_hidden_states=False, output_attentions=False)``; element
            ``[0]`` of its result is used as the logits and is indexed as
            ``(batch, position, vocab)``.
        input: 2-D tensor of token ids, indexed as ``(batch, position)``.
        loss_fct: Callable taking ``(logits_2d, labels_1d)`` and returning the
            loss, e.g. ``torch.nn.CrossEntropyLoss(reduction="sum")``.

    Returns:
        Whatever ``loss_fct`` returns for the shifted logits/labels.
    """
    logits = model(
        input,
        use_cache=False,
        output_hidden_states=False,
        output_attentions=False,
    )[0]

    # The logits at position t score the token at position t + 1, so drop the
    # last logit and the first label to line the two up.
    # Upcast to float32: the model may emit half-precision logits, and summing
    # many per-token losses in float16 loses precision.
    shift_logits = logits[:, :-1, :].float()

    # ``input[:, 1:]`` is a non-contiguous slice; ``view(-1)`` on it fails for
    # batch sizes above one, ``reshape`` copies only when it has to.
    shift_labels = input[:, 1:].reshape(-1)

    vocab_size = shift_logits.size(-1)
    return loss_fct(shift_logits.reshape(-1, vocab_size), shift_labels)
def main(args):
    """Evaluate the model's perplexity on each test set and print the results.

    Loads the model and tokenizer from ``args.hf_path``, then for every dataset
    name in the hard-coded list sums the next-token loss over all test
    sequences, divides by the number of predicted tokens, and prints the
    resulting perplexity. Finally prints the list of perplexities and their
    arithmetic mean.

    Args:
        args: Parsed command-line namespace; ``hf_path`` and ``seqlen`` are
            read here.

    Raises:
        ValueError: If a dataset yields no tokens to score, so an average
            loss cannot be computed.
    """
    datasets = ['c4', 'wikitext2']
    model = BitnetForCausalLM.from_pretrained(
        args.hf_path,
        device_map='auto',
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,
        torch_dtype=torch.float16,
    ).half()
    tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
    # Sum (not mean) per sequence so the average can be taken over all tokens
    # of the dataset at the end.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
    # Losses are accumulated in nats; dividing by ln(2) reports them in bits.
    ln2 = math.log(2)

    ppl = []
    for dataset in datasets:
        testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
        acc_loss, count = 0.0, 0
        progress = tqdm(range(len(testdata)))
        for ii in progress:
            # Build the id tensor directly as int64 rather than round-tripping
            # through the float32 legacy ``torch.Tensor`` constructor.
            input_ids = (
                torch.as_tensor(testdata[ii], dtype=torch.long)
                .cuda()
                .reshape(1, -1)
            )
            loss = calulate_loss(model, input_ids, loss_fct)
            # Each sequence of length n contributes n - 1 predicted tokens.
            count += input_ids.size(-1) - 1
            acc_loss += loss.item()
            if count:
                progress.set_description(f"avg_loss = {acc_loss / count / ln2}")

        if count == 0:
            raise ValueError(
                f"dataset {dataset!r} produced no tokens to score"
            )

        avg_loss = acc_loss / count / ln2
        # avg_loss is in bits per token, hence base 2.
        ppl.append(2 ** avg_loss)
        print("{} PPL: {}".format(dataset, ppl[-1]))

    print(ppl)
    print("Avg PPL:", sum(ppl) / len(ppl))
if __name__ == '__main__':
    # NOTE(review): this file uses relative imports (``from .modeling_bitnet
    # import ...``), so it presumably has to be launched as a module of its
    # package rather than as a bare script - confirm.
    args = parser.parse_args()

    # Seed PyTorch's and Python's RNGs from the same --seed value.
    torch.random.manual_seed(args.seed)
    random.seed(args.seed)

    # Inference only: make sure autograd is off before running the model.
    torch.set_grad_enabled(False)
    main(args)