Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file centroid_knn.h
6 : * @date 09 Jan 2021
7 : * @details This file contains the simple nearest neighbor layer, this layer
8 : * takes centroid and calculate l2 distance
9 : * @see https://github.com/nnstreamer/nntrainer
10 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : *
13 : */
14 : #ifndef __CENTROID_KNN_H__
15 : #define __CENTROID_KNN_H__
16 : #include <string>
17 :
18 : #include <common_properties.h>
19 : #include <layer_devel.h>
20 :
21 : namespace nntrainer {
22 :
23 : /**
24 : * @brief Centroid KNN layer which takes centroid and do k-nearest neighbor
25 : * classification
26 : */
27 : class CentroidKNN : public Layer {
28 : public:
29 : /**
30 : * @brief Construct a new NearestNeighbors Layer object that does elementwise
31 : * subtraction from mean feature vector
32 : */
33 : CentroidKNN();
34 :
35 : /**
36 : * @brief Move constructor.
37 : * @param[in] CentroidKNN &&
38 : */
39 : CentroidKNN(CentroidKNN &&rhs) noexcept = default;
40 :
41 : /**
42 : * @brief Move assignment operator.
43 : * @parma[in] rhs CentroidKNN to be moved.
44 : */
45 : CentroidKNN &operator=(CentroidKNN &&rhs) noexcept = default;
46 :
47 : /**
48 : * @brief Destroy the NearestNeighbors Layer object
49 : *
50 : */
51 : ~CentroidKNN();
52 :
53 : /**
54 : * @copydoc Layer::requireLabel()
55 : */
56 280 : bool requireLabel() const override { return true; }
57 :
58 : /**
59 : * @copydoc Layer::finalize(InitLayerContext &context)
60 : */
61 : void finalize(nntrainer::InitLayerContext &context) override;
62 :
63 : /**
64 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
65 : */
66 : void forwarding(nntrainer::RunLayerContext &context, bool training) override;
67 :
68 : /**
69 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
70 : */
71 : void calcDerivative(nntrainer::RunLayerContext &context) override;
72 :
73 : /**
74 : * @copydoc bool supportBackwarding() const
75 : */
76 1 : bool supportBackwarding() const override { return false; };
77 :
78 : /**
79 : * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
80 : */
81 : void exportTo(nntrainer::Exporter &exporter,
82 : const ml::train::ExportMethods &method) const override;
83 :
84 : /**
85 : * @copydoc Layer::getType()
86 : */
87 25 : const std::string getType() const override { return CentroidKNN::type; };
88 :
89 : /**
90 : * @copydoc Layer::setProperty(const std::vector<std::string> &values)
91 : */
92 : void setProperty(const std::vector<std::string> &values) override;
93 :
94 : static constexpr const char *type = "centroid_knn";
95 :
96 : private:
97 : std::tuple<props::NumClass> centroid_knn_props;
98 : std::array<unsigned int, 2> weight_idx; /**< indices of the weights */
99 : };
100 : } // namespace nntrainer
101 :
102 : #endif /** __CENTROID_KNN_H__ */
|