fix NormHead eval
#8
by
kuaizhirui
- opened
- modeling_baichuan.py +2 -1
modeling_baichuan.py
CHANGED
|
@@ -502,9 +502,10 @@ class NormHead(nn.Module):
|
|
| 502 |
def forward(self, hidden_states):
|
| 503 |
if self.training:
|
| 504 |
norm_weight = nn.functional.normalize(self.weight)
|
|
|
|
| 505 |
elif self.first_flag:
|
| 506 |
self.first_flag = False
|
| 507 |
-
self.weight = nn.
|
| 508 |
norm_weight = self.weight
|
| 509 |
else:
|
| 510 |
norm_weight = self.weight
|
|
|
|
| 502 |
def forward(self, hidden_states):
|
| 503 |
if self.training:
|
| 504 |
norm_weight = nn.functional.normalize(self.weight)
|
| 505 |
+
self.first_flag = False
|
| 506 |
elif self.first_flag:
|
| 507 |
self.first_flag = False
|
| 508 |
+
self.weight.data = nn.functional.normalize(self.weight)
|
| 509 |
norm_weight = self.weight
|
| 510 |
else:
|
| 511 |
norm_weight = self.weight
|