Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file split_layer.h
6 : * @date 21 May 2021
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is Split Layer Class for Neural Network
11 : *
12 : * @todo Add support for uneven splits. For now, this can
13 : * be acheived with combination of split and concat layers.
14 : */
15 :
16 : #ifndef __SPLIT_LAYER_H__
17 : #define __SPLIT_LAYER_H__
18 : #ifdef __cplusplus
19 :
20 : #include <common_properties.h>
21 : #include <layer_devel.h>
22 : #include <tensor_dim.h>
23 :
24 : namespace nntrainer {
25 :
26 : /**
27 : * @class Split Layer
28 : * @brief Split Layer
29 : */
30 : class SplitLayer : public Layer {
31 : public:
32 : /**
33 : * @brief Constructor of Split Layer
34 : */
35 : SplitLayer();
36 :
37 : /**
38 : * @brief Destructor of Split Layer
39 : */
40 74 : ~SplitLayer() = default;
41 :
42 : /**
43 : * @brief Move constructor of SplitLayer.
44 : * @param[in] SplitLayer &&
45 : */
46 : SplitLayer(SplitLayer &&rhs) noexcept = default;
47 :
48 : /**
49 : * @brief Move assignment operator.
50 : * @parma[in] rhs SplitLayer to be moved.
51 : */
52 : SplitLayer &operator=(SplitLayer &&rhs) = default;
53 :
54 : /**
55 : * @copydoc Layer::finalize(InitLayerContext &context)
56 : */
57 : void finalize(InitLayerContext &context) override;
58 :
59 : /**
60 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
61 : */
62 : void forwarding(RunLayerContext &context, bool training) override;
63 :
64 : /**
65 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
66 : */
67 : void calcDerivative(RunLayerContext &context) override;
68 :
69 : /**
70 : * @copydoc bool supportBackwarding() const
71 : */
72 82 : bool supportBackwarding() const override { return true; };
73 :
74 : /**
75 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
76 : * method)
77 : */
78 : void exportTo(Exporter &exporter,
79 : const ml::train::ExportMethods &method) const override;
80 :
81 : /**
82 : * @copydoc Layer::setProperty(const std::vector<std::string> &values)
83 : */
84 : void setProperty(const std::vector<std::string> &values) override;
85 :
86 : /**
87 : * @copydoc Layer::getType()
88 : */
89 818 : const std::string getType() const override { return SplitLayer::type; };
90 :
91 : static constexpr const char *type = "split";
92 :
93 : /**
94 : * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
95 : */
96 33 : void setBatch(RunLayerContext &context, unsigned int batch) override {
97 33 : setBatch(batch);
98 33 : }
99 :
100 : private:
101 : unsigned int leading_helper_dim; /**< batch dimension of helper dimension not
102 : containing the actual batch */
103 : TensorDim input_reshape_helper; /** helper dimension to reshape input */
104 : TensorDim output_reshape_helper; /** helper dimension to reshape outputs */
105 : std::tuple<props::SplitDimension, props::SplitNumber> split_props;
106 :
107 : /**
108 : * @brief set batch for the internal variables
109 : *
110 : * @param batch update batch size
111 : */
112 70 : void setBatch(unsigned int batch) {
113 70 : input_reshape_helper.batch(batch * leading_helper_dim);
114 70 : output_reshape_helper.batch(batch * leading_helper_dim);
115 70 : }
116 : };
117 :
118 : } // namespace nntrainer
119 :
120 : #endif /* __cplusplus */
121 : #endif /* __SPLIT_LAYER_H__ */
|