Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 Donghoon Kang <dhkang01@snu.ac.kr>
4 : *
5 : * @file channel_shuffle.cpp
6 : * @date 23 April 2025
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Donghoon Kang <dhkang01@snu.ac.kr>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is Channel Shuffle Layer Class for Neural Network
11 : *
12 : */
13 :
14 : #include <algorithm>
15 : #include <cstring>
16 : #include <limits>
17 : #include <string>
18 :
19 : #include <channel_shuffle.h>
20 : #include <cpu_backend.h>
21 : #include <layer_context.h>
22 : #include <lazy_tensor.h>
23 : #include <nntr_threads.h>
24 : #include <nntrainer_error.h>
25 : #include <nntrainer_log.h>
26 : #include <node_exporter.h>
27 : #include <profiler.h>
28 : #include <tensor_dim.h>
29 : #include <thread>
30 : #include <util_func.h>
31 :
32 : namespace nntrainer {
33 :
34 : static constexpr size_t SINGLE_INOUT_IDX = 0;
35 :
36 : /**
37 : * @brief Helper function to perform channel shuffle transpose operation
38 : * @param input Input tensor to transpose
39 : * @param output Output tensor to store transposed result
42 : */
43 9 : static void channel_shuffle_transpose(const Tensor &input, Tensor &output) {
44 9 : const TensorDim &dim = input.getDim();
45 :
46 9 : if (dim.getFormat() == ml::train::TensorDim::Format::NHWC) {
47 : // For NHWC format: [N,HW,G,C/G] -> [N,HW,C/G,G]
48 0 : for (unsigned int n = 0; n < dim.batch(); ++n) {
49 0 : for (unsigned int hw = 0; hw < dim.height(); ++hw) {
50 0 : for (unsigned int g = 0; g < dim.width(); ++g) {
51 0 : for (unsigned int c = 0; c < dim.channel(); ++c) {
52 0 : float val = input.getValue<float>(n, hw, g, c);
53 0 : output.setValue(n, hw, c, g, val);
54 : }
55 : }
56 : }
57 : }
58 : } else {
59 : // For NCHW format: [N,G,C/G,H*W] -> [N,C/G,G,H*W]
60 18 : for (unsigned int n = 0; n < dim.batch(); ++n) {
61 39 : for (unsigned int g = 0; g < dim.channel(); ++g) {
62 102 : for (unsigned int c = 0; c < dim.height(); ++c) {
63 1224 : for (unsigned int hw = 0; hw < dim.width(); ++hw) {
64 1152 : float val = input.getValue<float>(n, g, c, hw);
65 1152 : output.setValue(n, c, g, hw, val);
66 : }
67 : }
68 : }
69 : }
70 : }
71 9 : }
72 :
73 4 : ChannelShuffle::ChannelShuffle() :
74 4 : channel_shuffle_props(props::SplitNumber()) {}
75 :
76 4 : void ChannelShuffle::finalize(InitLayerContext &context) {
77 4 : NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
78 : << "Channel Shuffle layer takes only one input";
79 :
80 : const TensorDim &in_dim = context.getInputDimensions()[0];
81 :
82 : unsigned int num_groups =
83 4 : std::get<props::SplitNumber>(channel_shuffle_props).get();
84 :
85 : // If split_number is 0, find the smallest divisor of the channel count that is
86 : // greater than 1 and less than the channel count (e.g., 12 channels -> 2 groups)
87 4 : if (num_groups == 0) {
88 0 : unsigned int channel_count = in_dim.channel();
89 0 : for (unsigned int i = 2; i < channel_count; i++) {
90 0 : if (channel_count % i == 0) {
91 0 : num_groups = i;
92 0 : break;
93 : }
94 : }
95 :
96 0 : NNTR_THROW_IF(num_groups == 0, std::invalid_argument)
97 : << "Input split_number is 0, and the channel count is a prime number";
98 :
99 0 : std::get<props::SplitNumber>(channel_shuffle_props).set(num_groups);
100 : }
101 :
102 : // Validate split_number
103 4 : NNTR_THROW_IF(num_groups <= 1, std::invalid_argument)
104 : << "Number of groups must be greater than 1";
105 :
106 4 : NNTR_THROW_IF(num_groups >= in_dim.channel(), std::invalid_argument)
107 : << "Number of groups must be less than number of channels";
108 :
109 4 : NNTR_THROW_IF(in_dim.channel() % num_groups != 0, std::invalid_argument)
110 : << "Number of channels must be divisible by number of groups";
111 :
112 : // Output dimensions are the same as input dimensions
113 4 : context.setOutputDimensions({in_dim});
114 4 : }
115 :
116 6 : void ChannelShuffle::forwarding(RunLayerContext &context, bool training) {
117 : /**
118 : * Channel Shuffle Operation:
119 : *
120 : * Input Tensor: [N, C, H, W] where:
121 : * - N: batch size
122 : * - C: number of channels
123 : * - H: height
124 : * - W: width
125 : *
126 : * Let's say we have:
127 : * - C = 12 channels
128 : * - G = 3 groups (specified by SplitNumber)
129 : *
130 : * Step 1: Reshape into groups
131 : * [N, C, H, W] -> [N, G, C/G, H, W]
132 : * Example
133 : * Before: [N, 12, H, W]
134 : * After: [N, 3, 4, H, W]
135 : *
136 : * Step 2: Transpose groups
137 : * [N, G, C/G, H, W] -> [N, C/G, G, H, W]
138 : * Example
139 : * Before: [N, 3, 4, H, W]
140 : * After: [N, 4, 3, H, W]
141 : *
142 : * Step 3: Reshape back
143 : * [N, C/G, G, H, W] -> [N, C, H, W]
144 : * Example
145 : * Before: [N, 4, 3, H, W]
146 : * After: [N, 12, H, W]
147 : *
148 : * Visualization:
149 : * Original: [1,2,3,4,5,6,7,8,9,10,11,12]
150 : * After Step1: [[1, 2, 3, 4],
151 : * [5, 6, 7, 8],
152 : * [9,10,11,12]]
153 : * After Step2: [[1, 5, 9],
154 : * [2, 6,10],
155 : * [3, 7,11],
156 : * [4, 8,12]]
157 : * After Step3: [1,5,9,2,6,10,3,7,11,4,8,12]
158 : */
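/*
 * Note: the three reshape/transpose steps above reduce to a per-channel
 * index map: output channel k is read from input channel
 * (k % G) * (C / G) + (k / G). e.g. with C = 12, G = 3, output channels
 * 0,1,2,3,... read input channels 0,4,8,1,..., matching the visualization.
 */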
159 6 : Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
160 6 : Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
161 :
162 6 : unsigned int num_groups = std::get<props::SplitNumber>(channel_shuffle_props);
163 6 : unsigned int channels_per_group = input_.channel() / num_groups;
164 :
165 : // Calculate dimensions once before parallel section
166 :
167 6 : TensorDim in_dim = input_.getDim();
168 6 : TensorDim group_dim;
169 6 : TensorDim transposed_dim;
170 6 : TensorDim original_dim;
171 :
172 6 : if (in_dim.getFormat() == ml::train::TensorDim::Format::NHWC) {
173 0 : group_dim = input_.getDim(); // [N,H,W,C]
174 0 : group_dim.height(group_dim.height() * group_dim.width());
175 0 : group_dim.width(num_groups);
176 0 : group_dim.channel(channels_per_group);
177 0 : group_dim.batch(1); // [1,HW,G,C/G]
178 :
179 0 : transposed_dim = group_dim; // [1,HW,G,C/G]
180 0 : transposed_dim.channel(num_groups);
181 0 : transposed_dim.width(channels_per_group); // [1,HW,C/G,G]
182 :
183 : // Original dimension for final reshape
184 0 : original_dim = hidden_.getDim(); // [1,H,W,C]
185 0 : original_dim.batch(1); // For batch slice
186 : }
187 :
188 6 : else if (in_dim.getFormat() == ml::train::TensorDim::Format::NCHW) {
189 6 : group_dim = input_.getDim(); // [N, C, H, W]
190 6 : group_dim.width(group_dim.width() * group_dim.height());
191 6 : group_dim.height(channels_per_group);
192 6 : group_dim.channel(num_groups); // [N, G, C/G, H*W]
193 6 : group_dim.batch(1); // For batch slice
194 :
195 6 : transposed_dim = group_dim; // [1, G, C/G, H*W]
196 6 : transposed_dim.channel(channels_per_group);
197 6 : transposed_dim.height(num_groups); // [1, C/G, G, H*W]
198 :
199 6 : original_dim = hidden_.getDim(); // [N, C, H, W]
200 6 : original_dim.batch(1); // For batch slice
201 : }
202 :
203 : else {
204 0 : NNTR_THROW_IF(true, std::invalid_argument)
205 : << "Channel Shuffle layer only supports NHWC and NCHW format";
206 : }
207 :
208 6 : auto forwarding_job = [&](unsigned int s, unsigned int e, unsigned int pid,
209 : void *user_data) {
210 12 : for (unsigned int b = s; b < e; ++b) {
211 6 : Tensor out = hidden_.getBatchSlice(b, 1);
212 6 : Tensor in_sub = input_.getBatchSlice(b, 1);
213 :
214 : // Step 1: Reshape into groups
215 : // [1, C, H, W] -> [1, G, C/G, H*W]
216 : // [1, H, W, C] -> [1, HW, G, C/G] in NHWC format
217 6 : in_sub.reshape(group_dim);
218 :
219 : // Step 2: Transpose groups
220 : // [1, G, C/G, H*W] -> [1, C/G, G, H*W]
221 : // [1, HW, G, C/G] -> [1, HW, C/G, G] in NHWC format
222 6 : out.reshape(transposed_dim);
223 6 : channel_shuffle_transpose(in_sub, out);
224 :
225 : // Step 3: Reshape back to original dimensions
226 : // [1, C/G, G, H*W] -> [1, C, H, W]
227 : // [1, HW, C/G, G] -> [1, H, W, C] in NHWC format
228 6 : out.reshape(original_dim);
229 6 : }
230 6 : };
231 :
232 12 : auto workers = ParallelBatch(forwarding_job, input_.batch(), nullptr);
233 :
234 6 : if (workers.getNumWorkers() > 1) {
235 0 : workers.run();
236 : } else {
237 6 : forwarding_job(0, input_.batch(), 0, nullptr);
238 : }
239 6 : }
240 :
241 3 : void ChannelShuffle::calcDerivative(RunLayerContext &context) {
242 3 : const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
243 3 : Tensor &input_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
244 :
245 3 : unsigned int num_groups = std::get<props::SplitNumber>(channel_shuffle_props);
246 3 : unsigned int channels_per_group = derivative.channel() / num_groups;
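// Note: the backward pass is itself a channel shuffle. Shuffling C channels
// into C/G groups undoes a shuffle into G groups, so the same transpose
// helper is reused below with the group and per-group channel dims swapped.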
247 :
248 : // Calculate dimensions once before parallel section
249 3 : TensorDim group_dim = derivative.getDim();
250 3 : TensorDim transposed_dim;
251 3 : TensorDim original_dim;
252 :
253 3 : if (derivative.getFormat() == ml::train::TensorDim::Format::NHWC) {
254 : // First reshape to [N,HW,C/G,G]
255 0 : group_dim.height(group_dim.height() * group_dim.width());
256 0 : group_dim.width(channels_per_group);
257 0 : group_dim.channel(num_groups);
258 0 : group_dim.batch(1); // [1,HW,C/G,G]
259 :
260 : // For transposed dimension, we'll reshape to [1,HW,G,C/G]
261 0 : transposed_dim = group_dim; // [1,HW,C/G,G]
262 0 : transposed_dim.channel(channels_per_group);
263 0 : transposed_dim.width(num_groups); // [1,HW,G,C/G]
264 :
265 : // Original dimension for final reshape
266 0 : original_dim = input_derivative.getDim(); // [N,H,W,C]
267 0 : original_dim.batch(1); // For batch slice
268 : } else {
269 3 : group_dim.width(group_dim.width() * group_dim.height());
270 3 : group_dim.height(num_groups);
271 3 : group_dim.channel(channels_per_group); // [N,C/G,G,H*W]
272 3 : group_dim.batch(1); // For batch slice
273 :
274 3 : transposed_dim = group_dim; // [1,C/G,G,H*W]
275 3 : transposed_dim.channel(num_groups);
276 3 : transposed_dim.height(channels_per_group); // [1,G,C/G,H*W]
277 :
278 3 : original_dim = input_derivative.getDim(); // [N,C,H,W]
279 3 : original_dim.batch(1); // For batch slice
280 : }
281 :
282 3 : auto compute_derivative = [&](unsigned int s, unsigned int e,
283 : unsigned int pid, void *user_data) {
284 6 : for (unsigned int b = s; b < e; ++b) {
285 3 : Tensor deriv_sub = derivative.getBatchSlice(b, 1);
286 3 : Tensor in_deriv_sub = input_derivative.getBatchSlice(b, 1);
287 :
288 : // Step 1: Reshape into groups
289 : // [1, C, H, W] -> [1, C/G, G, H*W]
290 : // [1, H, W, C] -> [1, HW, C/G, G] in NHWC format
291 3 : deriv_sub.reshape(group_dim);
292 :
293 : // Step 2: Transpose groups (inverse of forward operation)
294 : // [1, C/G, G, H*W] -> [1, G, C/G, H*W]
295 : // [1, HW, C/G, G] -> [1, HW, G, C/G] in NHWC format
296 3 : in_deriv_sub.reshape(transposed_dim);
297 3 : channel_shuffle_transpose(deriv_sub, in_deriv_sub);
298 :
299 : // Step 3: Reshape back to original dimensions
300 : // [1, G, C/G, H*W] -> [1, C, H, W]
301 : // [1, HW, G, C/G] -> [1, H, W, C] in NHWC format
302 3 : in_deriv_sub.reshape(original_dim);
303 3 : }
304 3 : };
305 :
306 6 : auto workers = ParallelBatch(compute_derivative, derivative.batch(), nullptr);
307 :
308 3 : if (workers.getNumWorkers() > 1) {
309 0 : workers.run();
310 : } else {
311 3 : compute_derivative(0, derivative.batch(), 0, nullptr);
312 : }
313 3 : }
314 :
315 0 : void ChannelShuffle::calcGradient(RunLayerContext &context) {
316 : // Channel Shuffle layer has no weights to update
317 : // No gradient calculation needed
318 0 : }
319 :
320 2 : void ChannelShuffle::exportTo(Exporter &exporter,
321 : const ml::train::ExportMethods &method) const {
322 2 : LayerImpl::exportTo(exporter, method);
323 2 : exporter.saveResult(channel_shuffle_props, method, this);
324 2 : }
325 :
326 22 : void ChannelShuffle::setProperty(const std::vector<std::string> &values) {
327 22 : auto remain_props = loadProperties(values, channel_shuffle_props);
328 22 : LayerImpl::setProperty(remain_props);
329 22 : }
330 :
331 : } // namespace nntrainer