Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 SeungBaek Hong <sb92.hong@samsung.com>
4 : *
5 : * @file gather_layer.h
6 : * @date 02 April 2025
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author SeungBaek Hong <sb92.hong@samsung.com>
9 : * @bug It's not implemented operation yet. Just a draft for compilation.
10 : * @brief This is gather layer class (operation layer)
11 : */
12 :
13 : #ifndef __GATHER_LAYER_H__
14 : #define __GATHER_LAYER_H__
15 : #ifdef __cplusplus
16 :
17 : #include <common_properties.h>
18 : #include <layer_devel.h>
19 : #include <operation_layer.h>
20 :
21 : namespace nntrainer {
22 :
23 : /**
24 : * @class Gather Layer
25 : * @brief Gather Layer
26 : */
27 : class GatherLayer : public BinaryOperationLayer {
28 : public:
29 : /**
30 : * @brief Constructor of Gather Layer
31 : */
32 0 : GatherLayer() : support_backwarding(true) {}
33 :
34 : /**
35 : * @brief Destructor of Gather Layer
36 : */
37 0 : ~GatherLayer(){};
38 :
39 : /**
40 : * @brief Move constructor of Gather Layer.
41 : * @param[in] GatherLayer &&
42 : */
43 : GatherLayer(GatherLayer &&rhs) noexcept = default;
44 :
45 : /**
46 : * @brief Move assignment operator.
47 : * @parma[in] rhs GatherLayer to be moved.
48 : */
49 : GatherLayer &operator=(GatherLayer &&rhs) = default;
50 :
51 : /**
52 : * @copydoc Layer::finalize(InitLayerContext &context)
53 : */
54 : void finalize(InitLayerContext &context) final;
55 :
56 : /**
57 : * @brief forwarding operation for gather
58 : *
59 : * @param input tensor to be gathered from
60 : * @param indices tensor containing the indices of elements to gather
61 : * @param hidden tensor to store the result value
62 : */
63 : void forwarding_operation(const Tensor &input, const Tensor &indices,
64 : Tensor &hidden) final;
65 :
66 : /**
67 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
68 : */
69 : void calcDerivative(RunLayerContext &context) final;
70 :
71 : /**
72 : * @copydoc bool supportBackwarding() const
73 : */
74 0 : bool supportBackwarding() const final { return support_backwarding; };
75 :
76 : /**
77 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
78 : * method)
79 : */
80 0 : void exportTo(Exporter &exporter,
81 0 : const ml::train::ExportMethods &method) const final {}
82 :
83 : /**
84 : * @copydoc Layer::setProperty(const std::vector<std::string> &values)
85 : */
86 : void setProperty(const std::vector<std::string> &values) final;
87 :
88 : /**
89 : * @copydoc Layer::getType()
90 : */
91 0 : const std::string getType() const final { return GatherLayer::type; };
92 :
93 : std::tuple<props::Print, props::Axis> gather_props;
94 : unsigned int axis = 0;
95 : bool support_backwarding;
96 :
97 : inline static const std::string type = "gather";
98 : };
99 :
100 : } // namespace nntrainer
101 :
102 : #endif /* __cplusplus */
103 : #endif /* __GATHER_LAYER_H__ */
|