#define CUB_IGNORE_DEPRECATED_API

// Wrap CUB symbols in a project-specific namespace so this extension does
// not collide with other code in the same process that links a different
// CUB version.
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include "new_cumsum.h"
#include "gpu_backend.h"

#include <cstdint>

namespace megablocks {

// Tag types used to select the scan variant at compile time.
struct Inclusive {};
struct Exclusive {};

// Dispatches to cub::DeviceScan. The primary template performs an
// exclusive scan; the Inclusive specialization below overrides this.
template <typename Type> struct Cumsum {

  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
                  size_t & temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  megablocks::gpuStream_t stream = 0) {
    GPU_CALL(cubns::DeviceScan::ExclusiveSum(d_temp_storage,
                                             temp_storage_bytes,
                                             d_in,
                                             d_out,
                                             num_items,
                                             stream));
  }
};

// Inclusive variant: same interface, but routes to InclusiveSum.
template <> struct Cumsum<Inclusive> {
  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
                  size_t & temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  megablocks::gpuStream_t stream = 0) {
    GPU_CALL(cubns::DeviceScan::InclusiveSum(d_temp_storage,
                                             temp_storage_bytes,
                                             d_in,
                                             d_out,
                                             num_items,
                                             stream));
  }
};

// Runs the scan row-by-row: query CUB once for the temporary storage a
// single row needs, allocate one slab per row, then issue one DeviceScan
// per row on the current stream.
template <typename SumType, typename T>
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Get the temporary storage size. Following the CUB convention, passing a
  // null d_temp_storage pointer makes DeviceScan report the required byte
  // count in scratchpad_bytes without doing any work.
  size_t scratchpad_bytes = 0;
  Cumsum<SumType>::Run(nullptr,
                       scratchpad_bytes,
                       x.data_ptr<T>(),
                       out.data_ptr<T>(),
                       x.size(1),
                       megablocks::get_current_stream());

  // Allocate scratchpad.
  //
  // NOTE: We scale the allocation by the batch dimension so that each row
  // has its own slice of temporary storage and the per-row scans can be
  // issued independently.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt8)
    .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
                                          options);

  // Run the kernel.
  //
  // NOTE: Issuing each row's scan on its own stream does not appear to
  // yield performance gains for our problem set. The overhead of
  // event/stream synchronization appears to outweigh the benefits.
  // We could write a true batched cumsum, but this would require
  // significant code duplication from cub and we might move away
  // from this formulation anyway.
  for (int i = 0; i < x.size(0); ++i) {
    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
    Cumsum<SumType>::Run(scratchpad_ptr,
                         scratchpad_bytes,
                         x.data_ptr<T>() + x.size(1) * i,
                         out.data_ptr<T>() + x.size(1) * i,
                         x.size(1),
                         megablocks::get_current_stream());
  }
}

void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  if (x.scalar_type() == torch::kInt16) {
    cub_cumsum<Exclusive, short>(x, dim, out);
    return;
  }
  if (x.scalar_type() == torch::kInt32) {
    cub_cumsum<Exclusive, int>(x, dim, out);
    return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Exclusive, long>(x, dim, out);
}

void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  if (x.scalar_type() == torch::kInt16) {
    cub_cumsum<Inclusive, short>(x, dim, out);
    return;
  }
  if (x.scalar_type() == torch::kInt32) {
    cub_cumsum<Inclusive, int>(x, dim, out);
    return;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Inclusive, long>(x, dim, out);
}
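
// Example host-side usage (an illustrative sketch; the shapes and values
// below are hypothetical, not taken from the library's tests):
//
//   torch::Tensor x = torch::randint(
//       0, 8, {4, 128}, torch::dtype(torch::kInt32).device(torch::kCUDA));
//   torch::Tensor out = torch::empty_like(x);
//   megablocks::exclusive_cumsum(x, /*dim=*/1, out);  // out[:, 0] == 0
//   megablocks::inclusive_cumsum(x, /*dim=*/1, out);  // out[:, -1] == row sum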

} // namespace megablocks

#undef CUB_WRAPPED_NAMESPACE