Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file app_context.h
6 : * @date 10 November 2020
 * @brief This file contains app context related functions and classes that
 * manage the global configuration of the current environment
9 : * @see https://github.com/nnstreamer/nntrainer
10 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : *
13 : */
14 :
15 : #ifndef __APP_CONTEXT_H__
16 : #define __APP_CONTEXT_H__
17 :
#include <algorithm>
#include <cctype>
#include <functional>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
28 :
29 : #include <layer.h>
30 : #include <layer_devel.h>
31 : #include <optimizer.h>
32 : #include <optimizer_devel.h>
33 :
34 : #include <context.h>
35 : #include <mem_allocator.h>
36 : #include <nntrainer_error.h>
37 :
38 : #include "singleton.h"
39 :
40 : namespace nntrainer {
41 :
/**
 * @brief Mutex declared here and defined in another translation unit.
 * NOTE(review): presumably serializes access to the factory registry —
 * confirm against its definition and uses.
 */
extern std::mutex factory_mutex;
44 :
45 : /**
46 : * @class AppContext contains user-dependent configuration
47 : * @brief App
48 : */
49 : class AppContext : public Context, public Singleton<AppContext> {
50 : public:
51 : /**
52 : * @brief Default constructor
53 : */
54 114 : AppContext() : Context(std::make_shared<ContextData>()) {}
55 :
56 : /**
57 : * @brief Default destructor
58 : */
59 77 : ~AppContext() override = default;
60 :
61 : /**
62 : * @brief Set Working Directory for a relative path. working directory is set
63 : * canonically
64 : * @param[in] base base directory
65 : * @throw std::invalid_argument if path is not valid for current system
66 : */
67 : void setWorkingDirectory(const std::string &base);
68 :
69 : /**
70 : * @brief unset working directory
71 : *
72 : */
73 : void unsetWorkingDirectory() { working_path_base = ""; }
74 :
75 : /**
76 : * @brief query if the appcontext has working directory set
77 : *
78 : * @retval true working path base is set
79 : * @retval false working path base is not set
80 : */
81 : bool hasWorkingDirectory() { return !working_path_base.empty(); }
82 :
83 : /**
84 : * @brief register a layer factory from a shared library
85 : * plugin must have **extern "C" LayerPluggable *ml_train_layer_pluggable**
86 : * defined else error
87 : *
88 : * @param library_path a file name of the library
89 : * @param base_path base path to make a full path (optional)
90 : * @return int integer key to create the layer
91 : * @throws std::invalid_parameter if library_path is invalid or library is
92 : * invalid
93 : */
94 : int registerLayer(const std::string &library_path,
95 : const std::string &base_path = "");
96 :
97 : /**
98 : * @brief register a optimizer factory from a shared library
99 : * plugin must have **extern "C" OptimizerPluggable
100 : * *ml_train_optimizer_pluggable** defined else error
101 : *
102 : * @param library_path a file name of the library
103 : * @param base_path base path to make a full path (optional)
104 : * @return int integer key to create the optimizer
105 : * @throws std::invalid_parameter if library_path is invalid or library is
106 : * invalid
107 : */
108 : int registerOptimizer(const std::string &library_path,
109 : const std::string &base_path = "");
110 :
111 : /**
112 : * @brief register pluggables from a directory.
113 : * @note if you have a clashing type with already registered pluggable, it
114 : * will throw from `registerFactory` function
115 : *
116 : * @param base_path a directory path to search pluggables's
117 : * @return std::vector<int> list of integer key to create a pluggable
118 : */
119 : std::vector<int> registerPluggableFromDirectory(const std::string &base_path);
120 :
121 : /**
122 : * @brief Get Working Path from a relative or representation of a path
123 : * starting from @a working_path_base.
124 : * @param[in] path to make full path
125 : * @return If absolute path is given, returns @a path
126 : * If relative path is given and working_path_base is not set, return
127 : * relative path.
128 : * If relative path is given and working_path_base has set, return absolute
129 : * path from current working directory
130 : */
131 : const std::string getWorkingPath(const std::string &path = "");
132 :
133 : /**
134 : * @brief Get memory fsu file path from configuration file
135 : * @return memory fsu path.
136 : * If memory fsu path is not presented in configuration file, it returns
137 : * empty string
138 : */
139 : const std::vector<std::string> getProperties(void);
140 :
141 : /**
142 : * @brief Factory register function, use this function to register custom
143 : * object
144 : *
145 : * @tparam T object to create. Currently Optimizer, Layer is supported
146 : * @param factory factory function that creates std::unique_ptr<T>
147 : * @param key key to access the factory, if key is empty, try to find key by
148 : * calling factory({})->getType();
149 : * @param int_key key to access the factory by integer, if it is -1(default),
150 : * the function automatically unsigned the key and return
151 : * @return const int unique integer value to access the current factory
152 : * @throw invalid argument when key and/or int_key is already taken
153 : */
154 : template <typename T>
155 1881 : const int registerFactory(const PtrFactoryType<T> factory,
156 : const std::string &key = "",
157 : const int int_key = -1) {
158 : FactoryType<T> f = factory;
159 5640 : return registerFactory(f, key, int_key);
160 : }
161 :
162 : /**
163 : * @brief Factory register function, use this function to register custom
164 : * object
165 : *
166 : * @tparam T object to create. Currently Optimizer, Layer is supported
167 : * @param factory factory function that creates std::unique_ptr<T>
168 : * @param key key to access the factory, if key is empty, try to find key by
169 : * calling factory({})->getType();
170 : * @param int_key key to access the factory by integer, if it is -1(default),
171 : * the function automatically unsigned the key and return
172 : * @return const int unique integer value to access the current factory
173 : * @throw invalid argument when key and/or int_key is already taken
174 : */
175 : template <typename T>
176 : const int registerFactory(const FactoryType<T> factory,
177 : const std::string &key = "",
178 : const int int_key = -1);
179 :
180 : std::unique_ptr<nntrainer::Layer>
181 6126 : createLayerObject(const std::string &type,
182 : const std::vector<std::string> &properties = {}) override {
183 6126 : return createObject<nntrainer::Layer>(type, properties);
184 : }
185 :
186 755 : std::unique_ptr<nntrainer::Optimizer> createOptimizerObject(
187 : const std::string &type,
188 : const std::vector<std::string> &properties = {}) override {
189 755 : return createObject<nntrainer::Optimizer>(type, properties);
190 : }
191 :
192 : std::unique_ptr<ml::train::LearningRateScheduler>
193 54 : createLearningRateSchedulerObject(
194 : const std::string &type,
195 : const std::vector<std::string> &properties = {}) override {
196 54 : return createObject<ml::train::LearningRateScheduler>(type, properties);
197 : }
198 :
199 : std::unique_ptr<nntrainer::Layer>
200 91 : createLayerObject(const int int_key,
201 : const std::vector<std::string> &properties = {}) override {
202 91 : return createObject<nntrainer::Layer>(int_key, properties);
203 : }
204 :
205 29 : std::unique_ptr<nntrainer::Optimizer> createOptimizerObject(
206 : const int int_key,
207 : const std::vector<std::string> &properties = {}) override {
208 29 : return createObject<nntrainer::Optimizer>(int_key, properties);
209 : }
210 :
211 : std::unique_ptr<ml::train::LearningRateScheduler>
212 22 : createLearningRateSchedulerObject(
213 : const int int_key,
214 : const std::vector<std::string> &properties = {}) override {
215 22 : return createObject<ml::train::LearningRateScheduler>(int_key, properties);
216 : }
217 :
218 : /**
219 : * @brief Create an Object from the integer key
220 : *
221 : * @tparam T Type of Object, currently, Only optimizer is supported
222 : * @param int_key integer key
223 : * @param props property
224 : * @return PtrType<T> unique pointer to the object
225 : */
226 : template <typename T>
227 147 : PtrType<T> createObject(const int int_key,
228 : const PropsType &props = {}) const {
229 : static_assert(isSupported<T>::value,
230 : "given type is not supported for current app context");
231 : auto &index = std::get<IndexType<T>>(factory_map);
232 : auto &int_map = std::get<IntIndexType>(index);
233 :
234 : const auto &entry = int_map.find(int_key);
235 :
236 147 : if (entry == int_map.end()) {
237 2 : std::stringstream ss;
238 2 : ss << "Int Key is not found for the object. Key: " << int_key;
239 4 : throw exception::not_supported(ss.str().c_str());
240 2 : }
241 :
242 145 : return createObject<T>(entry->second, props);
243 : }
244 :
245 : /**
246 : * @brief Create an Object from the string key
247 : *
248 : * @tparam T Type of object, currently, only optimizer is supported
249 : * @param key integer key
250 : * @param props property
251 : * @return PtrType<T> unique pointer to the object
252 : */
253 : template <typename T>
254 7172 : PtrType<T> createObject(const std::string &key,
255 : const PropsType &props = {}) const {
256 : auto &index = std::get<IndexType<T>>(factory_map);
257 : auto &str_map = std::get<StrIndexType<T>>(index);
258 :
259 : std::string lower_key;
260 : lower_key.resize(key.size());
261 :
262 : std::transform(key.begin(), key.end(), lower_key.begin(),
263 60326 : [](unsigned char c) { return std::tolower(c); });
264 :
265 : const auto &entry = str_map.find(lower_key);
266 :
267 7172 : if (entry == str_map.end()) {
268 36 : std::stringstream ss;
269 : ss << "Key is not found for the object. Key: " << lower_key;
270 72 : throw exception::not_supported(ss.str().c_str());
271 36 : }
272 :
273 7123 : return entry->second(props);
274 : }
275 :
276 : /**
277 : * @brief special factory that throws for unknown
278 : *
279 : * @tparam T object to create
280 : * @param props props to pass, not used
281 : * @throw always throw runtime_error
282 : */
283 : template <typename T>
284 5 : static PtrType<T> unknownFactory(const PropsType &props) {
285 5 : throw std::invalid_argument("cannot create unknown object");
286 : }
287 :
288 0 : std::string getName() override { return "cpu"; }
289 :
290 27 : void setMemAllocator(std::shared_ptr<MemAllocator> mem) {
291 27 : getContextData()->setMemAllocator(mem);
292 27 : }
293 :
294 : private:
295 : /**
296 : * @brief Overriden initialization function
297 : */
298 : void initialize() noexcept override;
299 :
300 : void add_default_object();
301 :
302 : void add_extension_object();
303 :
304 : FactoryMap<nntrainer::Optimizer, nntrainer::Layer,
305 : ml::train::LearningRateScheduler>
306 : factory_map;
307 : std::string working_path_base;
308 :
309 : template <typename Args, typename T> struct isSupportedHelper;
310 :
311 : /**
312 : * @brief supportHelper to check if given type is supported within appcontext
313 : */
314 : template <typename T, typename... Args>
315 : struct isSupportedHelper<T, AppContext::FactoryMap<Args...>> {
316 : static constexpr bool value =
317 : (std::is_same_v<std::decay_t<T>, std::decay_t<Args>> || ...);
318 : };
319 :
320 : /**
321 : * @brief supportHelper to check if given type is supported within appcontext
322 : */
323 : template <typename T>
324 : struct isSupported : isSupportedHelper<T, decltype(factory_map)> {};
325 : };
326 :
327 : /**
328 : * @copydoc const int AppContext::registerFactory
329 : */
330 : extern template const int AppContext::registerFactory<nntrainer::Optimizer>(
331 : const FactoryType<nntrainer::Optimizer> factory, const std::string &key,
332 : const int int_key);
333 :
334 : /**
335 : * @copydoc const int AppContext::registerFactory
336 : */
337 : extern template const int AppContext::registerFactory<nntrainer::Layer>(
338 : const FactoryType<nntrainer::Layer> factory, const std::string &key,
339 : const int int_key);
340 :
341 : /**
342 : * @copydoc const int AppContext::registerFactory
343 : */
344 : extern template const int
345 : AppContext::registerFactory<ml::train::LearningRateScheduler>(
346 : const FactoryType<ml::train::LearningRateScheduler> factory,
347 : const std::string &key, const int int_key);
348 :
// Declares the nntrainer::plugin namespace; this header adds no members to it.
namespace plugin {
} // namespace plugin
350 :
351 : } // namespace nntrainer
352 :
353 : #endif /* __APP_CONTEXT_H__ */
|