Reduce VRAM consumption by swapping `cuda()` and `to(torch.bfloat16)`

#2
by mingyi456 - opened

When I tested the code locally, converting the weights to bfloat16 only after moving them to the GPU appears to leave the excess VRAM unfreed (calling `torch.cuda.empty_cache()` might release it, but swapping the two calls is simpler).
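For reference, a minimal sketch of the ordering this PR changes, using a stand-in `nn.Linear` rather than the repo's actual model (which isn't shown here):

```python
import torch
import torch.nn as nn

# Stand-in module (an assumption for illustration; the PR applies the same
# reordering to the repo's real model-loading code).
model = nn.Linear(8192, 8192)  # fp32 weights on CPU, ~256 MiB

# Original order: move to the GPU in fp32, then cast. The freed fp32 buffers
# are returned to PyTorch's CUDA caching allocator rather than the driver,
# so tools like nvidia-smi keep reporting that VRAM as used unless
# torch.cuda.empty_cache() is called:
# model = model.cuda().to(torch.bfloat16)

# Swapped order (this PR): cast on the CPU first, then move only the
# half-size bf16 weights, so the fp32 copies never occupy VRAM.
model = model.to(torch.bfloat16).cuda()

print(torch.cuda.memory_allocated())  # ~128 MiB: just the bf16 weights
print(torch.cuda.memory_reserved())   # no cached fp32 buffers left behind
```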

