DraconicDragon commited on
Commit
486419f
·
1 Parent(s): 25ee244

Create save_ema.py

Browse files
Files changed (1) hide show
  1. save_ema.py +14 -0
save_ema.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ input_path = "best_checkpoint.pth"
3
+ output_path = "best_checkpoint_ema.pth"
4
+
5
+ state = torch.load(input_path, map_location="cpu", weights_only=False)
6
+
7
+ ema_state = state["model_ema"]
8
+
9
+ if hasattr(ema_state, "state_dict"):
10
+ ema_state = ema_state.state_dict()
11
+
12
+ torch.save(ema_state, output_path)
13
+
14
+ print(f"saved EMA weights to {output_path}")