Commit 0c18aca
committed by nanushio
Parent(s): c45a2ea
- [MINOR] [SOURCE] [UPDATE] 1. update app.py
Files changed:
- app.py  +3 -1
- cover/datasets/cover_datasets.py  +32 -17
app.py
CHANGED
@@ -66,8 +66,10 @@ def inference_one_video(input_video):
     """
     TESTING
     """
+    # Convert input video to tensor and adjust dimensions
+    input_video_tensor = torch.from_numpy(input_video).permute(0, 3, 1, 2)
     views, _ = spatial_temporal_view_decomposition(
-
+        input_video_tensor, dopt["sample_types"], temporal_samplers
     )

     for k, v in views.items():
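For context, a minimal sketch of what the added lines do with the in-memory frames before they reach spatial_temporal_view_decomposition. The shape of input_video here is an assumption for illustration; dopt and temporal_samplers come from elsewhere in app.py and are not shown in this diff.

import numpy as np
import torch

# Hypothetical stand-in for the frames handed to inference_one_video:
# a (T, H, W, C) uint8 array of decoded video frames.
input_video = np.random.randint(0, 256, size=(32, 720, 1280, 3), dtype=np.uint8)

# The added lines reorder the axes to (T, C, H, W) so the decomposition
# can index frames directly as a tensor instead of re-decoding a file.
input_video_tensor = torch.from_numpy(input_video).permute(0, 3, 1, 2)
print(input_video_tensor.shape)  # torch.Size([32, 3, 720, 1280])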
cover/datasets/cover_datasets.py
CHANGED
@@ -232,34 +232,49 @@ def spatial_temporal_view_decomposition(
     video_path, sample_types, samplers, is_train=False, augment=False,
 ):
     video = {}
-    if video_path.endswith(".yuv"):
-        print("This part will be deprecated due to large memory cost.")
-        ## This is only an adaptation to LIVE-Qualcomm
-        ovideo = skvideo.io.vread(
-            video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
-        )
-        for stype in samplers:
-            frame_inds = samplers[stype](ovideo.shape[0], is_train)
-            imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
-            video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
-        del ovideo
-    else:
-        decord.bridge.set_bridge("torch")
-        vreader = VideoReader(video_path)
-        ### Avoid duplicated video decoding!!! Important!!!!
+    if torch.is_tensor(video_path):
         all_frame_inds = []
         frame_inds = {}
         for stype in samplers:
-            frame_inds[stype] = samplers[stype](len(vreader), is_train)
+            frame_inds[stype] = samplers[stype](video_path.shape[0], is_train)
             all_frame_inds.append(frame_inds[stype])

         ### Each frame is only decoded one time!!!
         all_frame_inds = np.concatenate(all_frame_inds, 0)
-        frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+        frame_dict = {idx: video_path[idx].permute(1, 2, 0) for idx in np.unique(all_frame_inds)}

         for stype in samplers:
             imgs = [frame_dict[idx] for idx in frame_inds[stype]]
             video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+    else:
+        if video_path.endswith(".yuv"):
+            print("This part will be deprecated due to large memory cost.")
+            ## This is only an adaptation to LIVE-Qualcomm
+            ovideo = skvideo.io.vread(
+                video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
+            )
+            for stype in samplers:
+                frame_inds = samplers[stype](ovideo.shape[0], is_train)
+                imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds]
+                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)
+            del ovideo
+        else:
+            decord.bridge.set_bridge("torch")
+            vreader = VideoReader(video_path)
+            ### Avoid duplicated video decoding!!! Important!!!!
+            all_frame_inds = []
+            frame_inds = {}
+            for stype in samplers:
+                frame_inds[stype] = samplers[stype](len(vreader), is_train)
+                all_frame_inds.append(frame_inds[stype])
+
+            ### Each frame is only decoded one time!!!
+            all_frame_inds = np.concatenate(all_frame_inds, 0)
+            frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)}
+
+            for stype in samplers:
+                imgs = [frame_dict[idx] for idx in frame_inds[stype]]
+                video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2)

     sampled_video = {}
     for stype, sopt in sample_types.items():