mingyi456 committed
Commit 62ee1bb · verified · 1 Parent(s): 20b35e8

Reduce VRAM consumption by swapping `cuda()` and `to(torch.bfloat16)`


When I test the code locally, it appears that converting the weights to bfloat16 only after moving the model to the GPU leaves the excess VRAM from the fp32 copy unfreed (unless perhaps `torch.cuda.empty_cache()` is called, but swapping the order is simpler).
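A minimal sketch of the effect for comparison, using a toy `nn.Linear` layer as a stand-in for the actual model in app.py (the `report` and `fresh_model` helpers are mine, not from the Space):

import torch

def report(tag: str) -> None:
    # memory_allocated: bytes held by live tensors; memory_reserved: bytes
    # held by PyTorch's caching allocator, which is what shows up as used VRAM.
    print(f"{tag}: allocated={torch.cuda.memory_allocated() / 2**20:.0f} MiB, "
          f"reserved={torch.cuda.memory_reserved() / 2**20:.0f} MiB")

def fresh_model() -> torch.nn.Module:
    # Toy stand-in; the real model is far larger, so the gap is too.
    return torch.nn.Linear(8192, 8192)

# Old order: the fp32 weights land on the GPU first, and the bf16 copy is
# allocated alongside them, so reserved VRAM peaks at fp32 + bf16.
m1 = fresh_model().cuda().to(torch.bfloat16)
report("cuda() then to(bfloat16)")

del m1
torch.cuda.empty_cache()  # reset between the two runs

# New order: the cast happens on the CPU, so only bf16 weights reach the GPU.
m2 = fresh_model().to(torch.bfloat16).cuda()
report("to(bfloat16) then cuda()")

If the observation above holds, the first `report` should show noticeably more reserved VRAM than the second, because the fp32 copy stays parked in the caching allocator after the cast.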

Files changed (1):
  1. app.py +1 -1
app.py CHANGED
@@ -43,7 +43,7 @@ def process_ocr_task(image, model_size, task_type, ref_text):
         return "Please upload an image first.", None
 
     print("🚀 Moving model to GPU...")
-    model_gpu = model.cuda().to(torch.bfloat16)
+    model_gpu = model.to(torch.bfloat16).cuda()
     print("✅ Model is on GPU.")
 
     with tempfile.TemporaryDirectory() as output_path: