Update resampler.py (#3)
Browse files- Update resampler.py (cc31e963f29bb5e66f0d399572ba0650e7fc92d4)
Co-authored-by: ed <[email protected]>
- resampler.py +9 -0
resampler.py
CHANGED
|
@@ -160,6 +160,15 @@ class Resampler(nn.Module):
|
|
| 160 |
nn.init.constant_(m.bias, 0)
|
| 161 |
nn.init.constant_(m.weight, 1.0)
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
def forward(self, x, tgt_sizes=None, temporal_ids=None):
|
| 164 |
assert x.shape[0] == tgt_sizes.shape[0]
|
| 165 |
bs = x.shape[0]
|
|
|
|
| 160 |
nn.init.constant_(m.bias, 0)
|
| 161 |
nn.init.constant_(m.weight, 1.0)
|
| 162 |
|
| 163 |
+
def _initialize_weights(self, m):
|
| 164 |
+
if isinstance(m, nn.Linear):
|
| 165 |
+
trunc_normal_(m.weight, std=.02)
|
| 166 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 167 |
+
nn.init.constant_(m.bias, 0)
|
| 168 |
+
elif isinstance(m, nn.LayerNorm):
|
| 169 |
+
nn.init.constant_(m.bias, 0)
|
| 170 |
+
nn.init.constant_(m.weight, 1.0)
|
| 171 |
+
|
| 172 |
def forward(self, x, tgt_sizes=None, temporal_ids=None):
|
| 173 |
assert x.shape[0] == tgt_sizes.shape[0]
|
| 174 |
bs = x.shape[0]
|