Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file model_loader.h
6 : * @date 5 August 2020
7 : * @brief This is model loader class for the Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __MODEL_LOADER_H__
15 : #define __MODEL_LOADER_H__
16 : #ifdef __cplusplus
17 :
18 : #include <memory>
19 :
20 : #include <engine.h>
21 : #include <iniparser.h>
22 : #include <neuralnet.h>
23 :
24 : namespace nntrainer {
25 :
26 : class OptimizerWrapped;
27 :
28 : /**
29 : * @class ModelLoader
30 : * @brief Model Loader class to load model from various config files
31 : */
32 : class ModelLoader {
33 : public:
34 : /**
35 : * @brief Constructor of the model loader
36 : */
37 : ModelLoader(const Engine *ct_eng_ = &Engine::Global()) :
38 666 : ct_engine(ct_eng_), model_file_engine(nullptr) {}
39 :
40 : /**
41 : * @brief Destructor of the model loader
42 : */
43 666 : ~ModelLoader() {}
44 :
45 : /**
46 : * @brief load all properties from context
47 : * @param[in/out] model model to be loaded
48 : */
49 : int loadFromContext(NeuralNetwork &model);
50 :
51 : /**
52 : * @brief load all of model and dataset from given config file
53 : * @param[in] config config file path
54 : * @param[in/out] model model to be loaded
55 : */
56 : int loadFromConfig(std::string config, NeuralNetwork &model);
57 :
58 : private:
59 : /**
60 : * @brief load all of model from given config file
61 : * @param[in] config config file path
62 : * @param[in/out] model model to be loaded
63 : * @param[in] bare_layers load only the layers as backbone if enabled
64 : */
65 : int loadFromConfig(std::string config, NeuralNetwork &model,
66 : bool bare_layers);
67 :
68 : /**
69 : * @brief load all of model and dataset from ini
70 : * @param[in] ini_file config file path
71 : * @param[in/out] model model to be loaded
72 : */
73 : int loadFromIni(std::string ini_file, NeuralNetwork &model, bool bare_layers);
74 :
75 : /**
76 : * @brief load model from ONNX file
77 : * @param[in] onnx_file onnx config file path
78 : * @param[in/out] model model to be loaded
79 : */
80 : int loadFromONNX(std::string onnx_file, NeuralNetwork &model);
81 :
82 : /**
83 : * @brief load dataset config from ini
84 : * @param[in] ini dictionary containing the config
85 : * @param[in] model model to be loaded
86 : */
87 : int loadDatasetConfigIni(dictionary *ini, NeuralNetwork &model);
88 :
89 : /**
90 : * @brief load model config from ini
91 : * @param[in] ini dictionary containing the config
92 : * @param[in/out] model model to be loaded
93 : */
94 : int loadModelConfigIni(dictionary *ini, NeuralNetwork &model);
95 :
96 : /**
97 : * @brief load optimizer config from ini
98 : * @param[in] ini dictionary containing the config
99 : * @param[in/out] model model to be loaded
100 : */
101 : int loadOptimizerConfigIni(dictionary *ini, NeuralNetwork &model);
102 :
103 : /**
104 : * @brief load learning rate scheduler config from ini
105 : * @param[in] ini dictionary containing the config
106 : * @param[in/out] optimizer to contain the lr scheduler
107 : */
108 : int loadLearningRateSchedulerConfigIni(
109 : dictionary *ini, std::shared_ptr<ml::train::Optimizer> &optimizer);
110 :
111 : /**
112 : * @brief Check if the file extension is the given @a ext
113 : * @param[in] filename full name of the file
114 : * @param[in] ext extension to match with
115 : * @retval true if @a ext, else false
116 : */
117 : static bool fileExt(const std::string &filename, const std::string &ext);
118 :
119 : /**
120 : * @brief Check if the file extension is ini
121 : * @param[in] filename full name of the file
122 : * @retval true if ini, else false
123 : */
124 : static bool isIniFile(const std::string &filename);
125 :
126 : /**
127 : * @brief Check if the file extension is tflite
128 : * @param[in] filename full name of the file
129 : * @retval true if tflite, else false
130 : */
131 : static bool isTfLiteFile(const std::string &filename);
132 :
133 : /**
134 : * @brief Check if the file extension is ONNX
135 : * @param[in] filename full name of the ONNX file
136 : * @retval true if ONNX, else false
137 : */
138 : static bool isONNXFile(const std::string &filename);
139 :
140 : /**
141 : * @brief resolvePath to absolute path written in a model description
142 : *
143 : * @note if path is absolute path, return path.
144 : * if app_context has working directory set, resolve from app_context
145 : * if not, resolve path assuming model_path is the current directory.
146 : * The behavior relies on the semantics of getWorkingPath();
147 : * @param path path to resolve
148 : * @return const std::string resolved path.
149 : */
150 3115 : const std::string resolvePath(const std::string &path) {
151 3115 : auto path_ = ct_engine->getWorkingPath(path);
152 6230 : return model_file_engine->getWorkingPath(path_);
153 : }
154 :
155 : /**
156 : * @brief parse all the properties for a given section
157 : * @param[in] ini dictionary containing the config
158 : * @param[in] section_name name of the section for properties to parse
159 : * @param[in] filter_props the properties to be filtered out
160 : */
161 : std::vector<std::string>
162 : parseProperties(dictionary *ini, const std::string §ion_name,
163 : const std::vector<std::string> &filter_props = {});
164 :
165 : const char *unknown = "Unknown";
166 :
167 : const Engine *ct_engine = nullptr;
168 : std::unique_ptr<Engine> model_file_engine;
169 : /**< model_file specific context which is
170 : // referred to as if app_context cannot
171 : // resolve some given configuration */
172 : };
173 :
174 : } /* namespace nntrainer */
175 :
176 : #endif /* __cplusplus */
177 : #endif /* __MODEL_LOADER_H__ */
|