Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 Donghoon Kang <dhkang01@snu.ac.kr>
4 : *
5 : * @file channel_shuffle.h
6 : * @date 23 April 2025
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Donghoon Kang <dhkang01@snu.ac.kr>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is Channel Shuffle Layer Class for Neural Network
11 : *
12 : */
13 :
14 : #ifndef __CHANNEL_SHUFFLE_H_
15 : #define __CHANNEL_SHUFFLE_H_
16 : #ifdef __cplusplus
17 :
18 : #include <memory.h>
19 :
20 : #include <common_properties.h>
21 : #include <layer_devel.h>
22 : #include <layer_impl.h>
23 : #include <tensor_dim.h>
24 :
25 : namespace nntrainer {
26 :
27 : /**
28 : * @class ChannelShuffle
29 : * @brief Channel Shuffle Layer
30 : */
31 : class ChannelShuffle : public LayerImpl {
32 : public:
33 : /**
34 : * @brief Constructor of Channel Shuffle Layer
35 : */
36 : ChannelShuffle();
37 :
38 : /**
39 : * @brief Destructor of Channel Shuffle Layer
40 : */
41 8 : ~ChannelShuffle() = default;
42 :
43 : /**
44 : * @brief Move constructor of Channel Shuffle Layer.
45 : * @param[in] ChannelShuffle &&
46 : */
47 : ChannelShuffle(ChannelShuffle &&rhs) noexcept = default;
48 :
49 : /**
50 : * @brief Move assignment operator.
51 : * @parma[in] rhs ChannelShuffle to be moved.
52 : */
53 : ChannelShuffle &operator=(ChannelShuffle &&rhs) = default;
54 :
55 : /**
56 : * @copydoc Layer::finalize(InitLayerContext &context)
57 : */
58 : void finalize(InitLayerContext &context) override;
59 :
60 : /**
61 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
62 : */
63 : void forwarding(RunLayerContext &context, bool training) override;
64 :
65 : /**
66 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
67 : */
68 : void calcDerivative(RunLayerContext &context) override;
69 :
70 : /**
71 : * @copydoc Layer::calcGradient(RunLayerContext &context)
72 : */
73 : void calcGradient(RunLayerContext &context) override;
74 :
75 : /**
76 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
77 : * method)
78 : */
79 : void exportTo(Exporter &exporter,
80 : const ml::train::ExportMethods &method) const override;
81 :
82 : /**
83 : * @copydoc Layer::getType()
84 : */
85 88 : const std::string getType() const override { return ChannelShuffle::type; };
86 :
87 : /**
88 : * @copydoc Layer::supportBackwarding()
89 : */
90 10 : bool supportBackwarding() const override { return true; };
91 :
92 : /**
93 : * @copydoc Layer::setProperty(const std::vector<std::string> &values)
94 : */
95 : void setProperty(const std::vector<std::string> &values) override;
96 :
97 : static constexpr const char *type = "channel_shuffle";
98 :
99 : private:
100 : std::tuple<props::SplitNumber> channel_shuffle_props;
101 : };
102 : } // namespace nntrainer
103 :
104 : #endif /* __cplusplus */
105 : #endif /* __CHANNEL_SHUFFLE_H__ */
|