Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file tflite_layer.cpp
6 : * @date 26 October 2020
7 : * @brief This is a class to encapsulate a tflite model as a layer of a neural network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 :
13 : #include <base_properties.h>
14 : #include <layer_context.h>
15 : #include <nntrainer_error.h>
16 : #include <nntrainer_log.h>
17 : #include <node_exporter.h>
18 : #include <tflite_layer.h>
19 : #include <util_func.h>
20 :
21 : namespace nntrainer {
22 :
23 : /**
24 : * @brief TflModelPath property
25 : *
26 : */
27 22 : class PropsTflModelPath : public Property<std::string> {
28 : public:
29 : static constexpr const char *key = "model_path"; /**< unique key to access */
30 : using prop_tag = str_prop_tag; /**< property type */
31 :
32 : static constexpr const char ending[] = ".tflite";
33 : static constexpr unsigned int ending_len = 7;
34 : /**
35 : * @brief check is valid
36 : *
37 : * @param v value to check
38 : * @return bool true if valid
39 : */
40 : bool isValid(const std::string &v) const override;
41 : };
42 :
43 7 : bool PropsTflModelPath::isValid(const std::string &v) const {
44 7 : if (v.size() < ending_len) {
45 : return false;
46 : }
47 7 : std::string ext(v.end() - ending_len, v.end());
48 : std::for_each(ext.end() - ending_len, ext.end(),
49 49 : [](char &c) { c = ::tolower(c); });
50 :
51 : /// check if path ends with .tflite
52 14 : if (!endswith(ext, ending)) {
53 : return false;
54 : }
55 7 : std::ifstream file(v, std::ios::binary | std::ios::ate);
56 : return file.good();
57 7 : }
58 :
59 22 : TfLiteLayer::TfLiteLayer() :
60 : Layer(),
61 44 : tfl_layer_props(new PropsType(PropsTflModelPath())),
62 : interpreter(nullptr),
63 22 : model(nullptr) {}
64 :
65 44 : TfLiteLayer::~TfLiteLayer() {}
66 :
67 12 : void TfLiteLayer::setDimensions(const std::vector<int> &tensor_idx_list,
68 : std::vector<TensorDim> &dim, bool is_output) {
69 12 : unsigned int num_tensors = tensor_idx_list.size();
70 12 : dim.resize(num_tensors);
71 :
72 24 : for (unsigned int i = 0; i < num_tensors; i++) {
73 12 : unsigned int tensor_idx = tensor_idx_list[i];
74 18 : if (is_output && interpreter->tensor(tensor_idx)->type != kTfLiteFloat32)
75 : throw exception::not_supported(
76 0 : "Data type other than float32 not supported");
77 :
78 12 : unsigned int num_dims = interpreter->tensor(tensor_idx)->dims->size;
79 12 : if (num_dims > ml::train::TensorDim::MAXDIM)
80 0 : throw exception::not_supported("Number of dimensions exceed the support");
81 :
82 : /** This puts the unused dimensions to the outer dimensions */
83 28 : for (size_t dim_idx = 0; dim_idx < num_dims; dim_idx++)
84 16 : dim[i].setTensorDim(
85 : ml::train::TensorDim::MAXDIM - dim_idx - 1,
86 16 : interpreter->tensor(tensor_idx)->dims->data[num_dims - dim_idx - 1]);
87 : }
88 12 : }
89 :
90 6 : void TfLiteLayer::finalize(InitLayerContext &context) {
91 6 : tflite::ops::builtin::BuiltinOpResolver resolver;
92 18 : model = tflite::FlatBufferModel::BuildFromFile(
93 6 : std::get<PropsTflModelPath>(*tfl_layer_props).get().c_str());
94 6 : NNTR_THROW_IF(!model, std::invalid_argument)
95 : << "Failed to build tflite model";
96 :
97 6 : tflite::InterpreterBuilder(*model.get(), resolver)(&interpreter);
98 6 : NNTR_THROW_IF(!interpreter, std::invalid_argument)
99 : << "Failed to build tflite interpreter";
100 :
101 6 : NNTR_THROW_IF(interpreter->AllocateTensors() != kTfLiteOk, std::runtime_error)
102 : << "Failed to allocate tensors!";
103 :
104 : std::vector<TensorDim> dims;
105 6 : setDimensions(interpreter->inputs(), dims, false);
106 : const std::vector<TensorDim> &input_dims = context.getInputDimensions();
107 :
108 6 : NNTR_THROW_IF(dims.size() != input_dims.size(), std::invalid_argument)
109 : << "Provided number of input dimensions mismatch";
110 :
111 12 : for (size_t idx = 0; idx < dims.size(); idx++) {
112 6 : NNTR_THROW_IF(dims[idx] != input_dims[idx], std::invalid_argument)
113 : << "Input dimensions mismatch -> " << idx << ":" << dims[idx] << " "
114 : << input_dims[idx];
115 : }
116 :
117 : std::vector<TensorDim> output_dims;
118 6 : setDimensions(interpreter->outputs(), output_dims, true);
119 6 : context.setOutputDimensions(output_dims);
120 12 : }
121 :
122 47 : void TfLiteLayer::setProperty(const std::vector<std::string> &values) {
123 47 : auto left_values = loadProperties(values, *tfl_layer_props);
124 47 : NNTR_THROW_IF(!left_values.empty(), std::invalid_argument)
125 : << "There are unparsed properties, " << left_values.front();
126 45 : }
127 :
128 175 : void TfLiteLayer::forwarding(RunLayerContext &context, bool training) {
129 175 : auto in_indices = interpreter->inputs();
130 350 : for (size_t idx = 0; idx < in_indices.size(); idx++)
131 350 : interpreter->tensor(in_indices[idx])->data.raw =
132 175 : reinterpret_cast<char *>(context.getInput(idx).getData());
133 :
134 175 : auto out_indices = interpreter->outputs();
135 350 : for (size_t idx = 0; idx < out_indices.size(); idx++) {
136 350 : interpreter->tensor(out_indices[idx])->data.raw =
137 175 : reinterpret_cast<char *>(context.getOutput(idx).getData());
138 : }
139 :
140 175 : int status = interpreter->Invoke();
141 175 : if (status != kTfLiteOk)
142 0 : throw std::runtime_error("Invoke failed");
143 :
144 : #ifdef DEBUG
145 : std::vector<TensorDim> out_tf_dim;
146 : setDimensions(interpreter->outputs(), out_tf_dim, true);
147 : if (out_tf_dim.size() != context.getNumOutputs()) {
148 : throw std::invalid_argument(
149 : "[TfliteLayer::forward] number of output dimension does not match");
150 : }
151 :
152 : for (unsigned int i = 0; i < out_tf_dim.size(); ++i) {
153 : if (context.getOutput(i).getDim() != out_tf_dim[i]) {
154 : throw std::invalid_argument(
155 : "[TfliteLayer::forward] output dimension does not match");
156 : }
157 : }
158 : #endif
159 175 : }
160 :
161 0 : void TfLiteLayer::calcDerivative(RunLayerContext &context) {
162 : throw exception::not_supported(
163 0 : "calcDerivative is not supported for tflite layer");
164 : }
165 : } /* namespace nntrainer */
|