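"""Distributed FSDP / LoRA fine-tuning tests.

Each test spawns one process per rank and checks checkpoint weight loading,
logits/loss parity between LoRA and non-LoRA models, gradient sharding across
ranks, optimizer updates that only touch trainable parameters, and
mixed-precision training followed by checkpoint saving.
"""
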
import tempfile
from pathlib import Path
from typing import Dict
import pytest
import torch
from finetune.args import LoraArgs
from finetune.checkpointing import Checkpointer
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
downcast_mixed_precision,
prepare_mixed_precision,
upcast_mixed_precision,
)
from finetune.utils import TrainState
from finetune.wrapped_model import load_model
from model.transformer import (
LoRALinear,
)
from tests.test_utils import (
MODEL_PATH,
get_dataloader,
is_float_equal,
setup_mp_test_dist,
)
from .test_utils import spawn_for_all_world_sizes

torch.backends.cudnn.deterministic = True  # use deterministic algorithms
torch.backends.cudnn.benchmark = False  # disable cuDNN benchmark


@pytest.mark.parametrize(
("world_size", "enable_lora", "dtype"),
[
(1, False, torch.float32),
(1, True, torch.float32),
(2, False, torch.float32),
(2, True, torch.float32),
(1, False, torch.bfloat16),
(1, True, torch.bfloat16),
(2, False, torch.bfloat16),
(2, True, torch.bfloat16),
],
)
def test_weights_loading(world_size, enable_lora, dtype):
spawn_for_all_world_sizes(
_check_weights_loading,
world_sizes=[world_size],
args=[enable_lora, dtype],
deterministic=True,
)


def _check_weights_loading(
rank: int,
world_size: int,
filename: str,
filename_rpc: str,
enable_lora: bool,
dtype: torch.dtype,
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=enable_lora),
checkpoint=True,
param_dtype=dtype,
)
    # add a state-dict hook so that LoRA weights are merged into the base weights:
def register_merge_lora_hook(m: torch.nn.Module):
def merge_lora(
m: torch.nn.Module, destination: Dict[str, torch.Tensor], prefix: str, *args
):
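            # store the merged weight (frozen base weight plus the low-rank update)
            # under the plain "weight" key so it matches a non-LoRA state dict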
weight = m.merge_weight()
destination[prefix + "weight"] = weight
if isinstance(m, LoRALinear):
m._merge_lora_handle = m._register_state_dict_hook(merge_lora)
model.apply(register_merge_lora_hook)
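    # under FSDP sharding (world_size > 1) full parameters must be gathered
    # before reading the state dict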
if world_size > 1:
with model.summon_full_params(model, writeback=True):
states = {
k: v
for k, v in model.state_dict().items()
if "lora" not in k and "frozen" not in k
}
else:
states = {
k: v
for k, v in model.state_dict().items()
if "lora" not in k and "frozen" not in k
}
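    # reference sum of all non-LoRA, non-frozen parameters (bfloat16 rounds to 308.0)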
EXP_PARAM_SUM = 308.9932 if dtype == torch.float32 else 308.0
params = sum([v.sum() for v in states.values()]).item()
    # with LoRA enabled the merged weights equal the non-LoRA weights,
    # since lora_B is zero-initialized
assert is_float_equal(params, EXP_PARAM_SUM), params
if enable_lora:
lora_B_params = [
v.float().abs().sum() for k, v in model.named_parameters() if "lora_B" in k
]
assert len(lora_B_params) > 0
        assert sum(lora_B_params) == 0, "lora_B should always be zero-initialized"
lora_A_params = [
v.float().abs().sum() for k, v in model.named_parameters() if "lora_A" in k
]
assert len(lora_A_params) > 0
        assert sum(lora_A_params) > 0, "lora_A should have a non-zero initialization"


@pytest.mark.parametrize(
("world_size", "enable_lora"), [(1, False), (1, True), (2, False), (2, True)]
)
def test_fsdp_logits_and_loss(world_size, enable_lora):
spawn_for_all_world_sizes(
_check_fsdp_logits_and_loss,
world_sizes=[world_size],
args=[enable_lora],
deterministic=True,
)


def _check_fsdp_logits_and_loss(
rank: int, world_size: int, filename: str, filename_rpc: str, enable_lora: bool
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
seq_len = 100
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=enable_lora),
checkpoint=True,
param_dtype=torch.bfloat16,
)
    # use a fixed dataloader rank/world_size so every rank sees the same data and
    # the logits/loss are comparable across world sizes
data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)
batch = next(data_loader)
x = torch.from_numpy(batch.x).cuda(non_blocking=True)
y = torch.from_numpy(batch.y).cuda(non_blocking=True)
y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
# forward / backward
output = model(
input_ids=x,
seqlens=batch.sizes,
)
# check logits
# logits should be the same for LoRA and non-LoRA
assert output.shape == (seq_len, model.args.vocab_size)
output_sum = output.abs().float().sum().item()
EXP_OUTPUT_WORLD_1 = 162617.625
assert is_float_equal(output_sum, EXP_OUTPUT_WORLD_1, precision=1e1), output_sum
    # check the loss is the same across world sizes
# loss should be the same for LoRA and non-LoRA
mb_loss = compute_loss_with_mask(output, y, y_mask)
EXPECTED_LOSS = 10.408413887023926
assert is_float_equal(mb_loss.item(), EXPECTED_LOSS), mb_loss.item()


@pytest.mark.parametrize(
("world_size", "dtype"),
[(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)],
)
def test_fsdp_grads_non_lora(world_size, dtype):
spawn_for_all_world_sizes(
_check_fsdp_grads_non_lora,
world_sizes=[world_size],
deterministic=True,
args=[dtype],
)


def _check_fsdp_grads_non_lora(
rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
seq_len = 2048
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=False),
checkpoint=True,
param_dtype=dtype,
)
    # use a fixed dataloader rank/world_size so all runs see the same data
data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)
batch = next(data_loader)
x = torch.from_numpy(batch.x).cuda(non_blocking=True)
y = torch.from_numpy(batch.y).cuda(non_blocking=True)
y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
# forward / backward
output = model(
input_ids=x,
seqlens=batch.sizes,
)
mb_loss = compute_loss_with_mask(output, y, y_mask)
mb_loss.backward()
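    # without LoRA every (sharded) parameter gets a gradient, so the local
    # gradient element count scales with 1 / world_size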
num_grad_params = sum([p.grad.numel() for p in model.parameters()])
assert (4301120 // world_size) == num_grad_params, num_grad_params
torch.distributed.barrier()
sharded_flat_grads = sum(
[p.grad.float().abs().sum().item() for p in model.parameters()]
)
print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}")
EXP_GRAD_WORLD_2_RANK_0 = 95.45827150344849
EXP_GRAD_WORLD_2_RANK_1 = 86.09188461303711
EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1
if world_size == 1:
assert is_float_equal(
sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1
), sharded_flat_grads
elif world_size == 2 and rank == 0:
assert is_float_equal(
sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1
), sharded_flat_grads
elif world_size == 2 and rank == 1:
assert is_float_equal(
sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1
), sharded_flat_grads


@pytest.mark.parametrize(
("world_size", "dtype"),
[(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)],
)
def test_fsdp_grads_lora(world_size, dtype):
spawn_for_all_world_sizes(
_check_fsdp_grads_lora,
world_sizes=[world_size],
deterministic=True,
args=[dtype],
)


def _check_fsdp_grads_lora(
rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
seq_len = 2048
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=True),
checkpoint=True,
param_dtype=dtype,
)
    # use a fixed dataloader rank/world_size so all runs see the same data
data_loader = get_dataloader(seq_len=seq_len, rank=0, world_size=2)
batch = next(data_loader)
x = torch.from_numpy(batch.x).cuda(non_blocking=True)
y = torch.from_numpy(batch.y).cuda(non_blocking=True)
y_mask = torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
# forward / backward
output = model(
input_ids=x,
seqlens=batch.sizes,
)
mb_loss = compute_loss_with_mask(output, y, y_mask)
mb_loss.backward()
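    # with LoRA only the trainable parameters receive gradients; frozen base
    # weights keep p.grad == None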
num_grad_params = sum(
[p.grad.numel() for p in model.parameters() if p.grad is not None]
)
assert (40960 // world_size) == num_grad_params, num_grad_params
torch.distributed.barrier()
sharded_flat_grads = sum(
[
p.grad.float().abs().sum().item()
for p in model.parameters()
if p.grad is not None
]
)
print(f"{rank}: {world_size}: {dtype} = {sharded_flat_grads}")
EXP_GRAD_WORLD_2_RANK_0 = 3.0742580661177635
EXP_GRAD_WORLD_2_RANK_1 = 3.074301045779139
EXP_GRAD_WORLD_1 = EXP_GRAD_WORLD_2_RANK_0 + EXP_GRAD_WORLD_2_RANK_1
if world_size == 1:
assert is_float_equal(
sharded_flat_grads, EXP_GRAD_WORLD_1, 2.0e-1
), sharded_flat_grads
elif world_size == 2 and rank == 0:
assert is_float_equal(
sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_0, 2.0e-1
), sharded_flat_grads
elif world_size == 2 and rank == 1:
assert is_float_equal(
sharded_flat_grads, EXP_GRAD_WORLD_2_RANK_1, 2.0e-1
), sharded_flat_grads


@pytest.mark.parametrize(
("world_size", "dtype"),
[(1, torch.bfloat16), (2, torch.bfloat16), (1, torch.float32), (2, torch.float32)],
)
def test_grad_update_lora(world_size, dtype):
spawn_for_all_world_sizes(
_check_grad_update_lora,
world_sizes=[world_size],
args=[dtype],
deterministic=True,
)


def _check_grad_update_lora(
rank: int, world_size: int, filename: str, filename_rpc: str, dtype: torch.dtype
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
seq_len = 1000
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=True),
checkpoint=True,
param_dtype=dtype,
)
optimizer = torch.optim.AdamW(model.parameters())
data_loader = get_dataloader(seq_len=seq_len)
batch = next(data_loader)
x = torch.from_numpy(batch.x).cuda(non_blocking=True)
y = torch.from_numpy(batch.y).cuda(non_blocking=True)
y_mask = (
torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
if batch.y_mask is not None
else None
)
# forward / backward
output = model(
input_ids=x,
seqlens=batch.sizes,
)
mb_loss = compute_loss_with_mask(output, y, y_mask)
mb_loss.backward()
lora_weight_sum = 0
non_lora_weight_sum = 0
for name, param in model.named_parameters():
if "lora" in name or "norm" in name:
assert param.grad is not None, name
lora_weight_sum += param.data.float().abs().sum()
else:
assert param.grad is None, name
non_lora_weight_sum += param.data.float().abs().sum()
# update weights
optimizer.step()
new_lora_weight_sum = 0
new_non_lora_weight_sum = 0
for name, param in model.named_parameters():
if "lora" in name or "norm" in name:
assert param.grad is not None, name
new_lora_weight_sum += param.data.float().abs().sum()
else:
assert param.grad is None, name
new_non_lora_weight_sum += param.data.float().abs().sum()
# make sure that LoRA weights changed, but non-LoRA weights stayed the same
assert not is_float_equal(
new_lora_weight_sum, lora_weight_sum, 1e-4
), f"New: {new_lora_weight_sum}, Old: {lora_weight_sum}"
assert is_float_equal(
new_non_lora_weight_sum, non_lora_weight_sum, 1e-4
), f"New: {new_non_lora_weight_sum}, Old: {non_lora_weight_sum}"


@pytest.mark.parametrize(
("enable_lora", "param_dtype"),
[
(False, torch.float32),
(True, torch.float32),
(False, torch.bfloat16),
(True, torch.bfloat16),
],
)
def test_grads_fsdp_mp(enable_lora, param_dtype):
with tempfile.TemporaryDirectory() as tmpdirname:
for world_size in [1, 2]:
spawn_for_all_world_sizes(
_check_grads_fsdp_mp,
world_sizes=[world_size],
deterministic=True,
args=[tmpdirname, enable_lora, param_dtype],
)
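        # the 1-rank and 2-rank runs should end up with (nearly) identical weights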
w1_sd = torch.load(Path(tmpdirname) / Path("params_w1.pt"), map_location="cpu")
w2_sd = torch.load(Path(tmpdirname) / Path("params_w2.pt"), map_location="cpu")
for k in w1_sd.keys():
assert w1_sd[k].shape == w2_sd[k].shape, k
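            # tolerance on the summed difference; bfloat16 rounds more aggressively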
atol = 10 if param_dtype == torch.float32 else 100
assert (w1_sd[k] - w2_sd[k]).sum().abs().item() < atol


def _check_grads_fsdp_mp(
rank: int,
world_size: int,
filename: str,
filename_rpc: str,
tmpdirname: str,
enable_lora: bool,
param_dtype: torch.dtype,
):
model_parallel = 1
setup_mp_test_dist(rank, world_size, filename, model_parallel, seed=0)
seq_len = 4096
optim_dtype = torch.float32
folder = Path(MODEL_PATH)
model = load_model(
folder=folder,
lora=LoraArgs(enable=enable_lora),
checkpoint=True,
param_dtype=param_dtype,
)
# high learning rate to show differences
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
    # mock a train state spanning the optimizer steps below
steps = 4
state = TrainState(max_steps=steps)
    # the Checkpointer never saves by itself here; only retrieve_save_states is used
run_dir = Path(tmpdirname)
checkpointer = Checkpointer(model, state, run_dir=run_dir, num_ckpt_keep=None)
    # make sure the same total data is seen: a single rank iterates both data
    # shards so its accumulated gradients match the 2-rank run
dataloaders = [
get_dataloader(seq_len=seq_len, rank=rank + i, world_size=2)
for i in range(2 - world_size + 1)
]
prepare_mixed_precision(
model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
)
for _ in range(steps):
state.start_step()
optimizer.zero_grad()
for data_loader in dataloaders:
torch.manual_seed(0)
batch = next(data_loader)
x = torch.from_numpy(batch.x).cuda()
y = torch.from_numpy(batch.y).cuda()
y_mask = (
torch.from_numpy(batch.y_mask).cuda(non_blocking=True)
if batch.y_mask is not None
else None
)
# forward / backward
output = model(
input_ids=x,
seqlens=batch.sizes,
)
mb_loss = compute_loss_with_mask(output, y, y_mask)
mb_loss.backward()
assert model.params[0].dtype == param_dtype
print(f"rank: {rank}, world_size: {world_size}, x: {x.abs().sum()}")
print(f"rank: {rank}, world_size: {world_size}, y: {y.abs().sum()}")
print(f"rank: {rank}, world_size: {world_size}, x shape: {x.shape}")
if y_mask is not None:
print(
f"rank: {rank}, world_size: {world_size}, y_mask: {y_mask.abs().sum()}"
)
print(f"rank: {rank}, world_size: {world_size}, loss: {mb_loss}")
for p in model.parameters():
if p.requires_grad:
assert p.grad is not None
p.grad.div_(len(dataloaders))
max_norm = 1.0
model.clip_grad_norm_(max_norm=max_norm)
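        # step in optim_dtype: upcast params before the optimizer step, then
        # downcast back to param_dtype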
upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)
optimizer.step()
downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)
save_dict = checkpointer.retrieve_save_states(
save_only_lora=enable_lora, save_dtype=torch.float32
)
path = "params_w1.pt" if world_size == 1 else "params_w2.pt"
torch.save(save_dict, Path(tmpdirname) / Path(path))