Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file tflite_layer.h
6 : * @date 3 November 2020
7 : * @brief This is class to encapsulate tflite as a layer of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __TENSORFLOW_LITE_H__
15 : #define __TENSORFLOW_LITE_H__
16 : #ifdef __cplusplus
17 :
18 : #include <layer_devel.h>
19 : #include <vector>
20 :
21 : #include <tensorflow/lite/interpreter.h>
22 : #include <tensorflow/lite/kernels/register.h>
23 : #include <tensorflow/lite/model.h>
24 :
25 : namespace ml::train {
26 : class TensorDim;
27 : }
28 :
29 : namespace nntrainer {
30 :
31 : class PropsTflModelPath;
32 :
33 : /**
34 : * @class TfLiteLayer
35 : * @brief Tensorflow Lite layer
36 : */
37 : class TfLiteLayer : public Layer {
38 : public:
39 : /**
40 : * @brief Constructor of NNStreamer Layer
41 : */
42 : TfLiteLayer();
43 :
44 : /**
45 : * @brief Destructor of NNStreamer Layer
46 : */
47 : ~TfLiteLayer();
48 :
49 : /**
50 : * @copydoc Layer::finalize(InitLayerContext &context)
51 : */
52 : void finalize(InitLayerContext &context) override;
53 :
54 : /**
55 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
56 : */
57 : void forwarding(RunLayerContext &context, bool training) override;
58 :
59 : /**
60 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
61 : */
62 : void calcDerivative(RunLayerContext &context) override;
63 :
64 : /**
65 : * @copydoc Layer::getType()
66 : */
67 68 : const std::string getType() const override { return TfLiteLayer::type; };
68 :
69 : /**
70 : * @copydoc Layer::supportBackwarding()
71 : */
72 2 : bool supportBackwarding() const override { return false; }
73 :
74 : /**
75 : * @copydoc Layer::setProperty(const PropertyType type, const std::string
76 : * &value)
77 : */
78 : void setProperty(const std::vector<std::string> &values) override;
79 :
80 : static constexpr const char *type = "backbone_tflite";
81 :
82 : private:
83 : using PropsType = std::tuple<PropsTflModelPath>;
84 : std::unique_ptr<PropsType> tfl_layer_props;
85 : std::unique_ptr<tflite::Interpreter> interpreter;
86 : std::unique_ptr<tflite::FlatBufferModel> model;
87 :
88 : /**
89 : * @brief Set the Dimensions object
90 : *
91 : * @param tensor_idx_list tensor index list
92 : * @param dim dimension
93 : * @param is_output check if output
94 : */
95 : void setDimensions(const std::vector<int> &tensor_idx_list,
96 : std::vector<ml::train::TensorDim> &dim, bool is_output);
97 : };
98 :
99 : } // namespace nntrainer
100 :
101 : #endif /* __cplusplus */
102 : #endif /* __TENSORFLOW_LITE_H__ */
|