Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2024 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file context.h
6 : * @date 10 Dec 2024
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief This file contains app context related functions and classes that
11 : * manages the global configuration of the current environment.
12 : */
13 :
14 : #ifndef __CONTEXT_H__
15 : #define __CONTEXT_H__
16 :
17 : #include <algorithm>
18 : #include <functional>
19 : #include <memory>
20 : #include <mutex>
21 : #include <sstream>
22 : #include <stdexcept>
23 : #include <string>
24 : #include <type_traits>
25 : #include <unordered_map>
26 : #include <vector>
27 :
28 : #include <context.h>
29 : #include <layer.h>
30 : #include <layer_devel.h>
31 : #include <mem_allocator.h>
32 : #include <optimizer.h>
33 : #include <optimizer_devel.h>
34 :
35 : #include <nntrainer_log.h>
36 :
37 : namespace nntrainer {
38 :
39 : /**
40 : * @class ContextData contains Data which generated by Context to use in Layers
41 : * created by Context
42 : * @brief Container to hold context data generated by Context. Eventaully
43 : * RunLayerContext in layer_node will hold the ContextData, so that, Layer can
44 : * access this Context Data.
45 : */
46 : class ContextData {
47 : public:
48 57 : ContextData() = default;
49 57 : virtual ~ContextData() = default;
50 :
51 : std::shared_ptr<MemAllocator> getMemAllocator() { return mem_allocator; }
52 :
53 : void setMemAllocator(std::shared_ptr<MemAllocator> m) { mem_allocator = m; }
54 :
55 : private:
56 : std::shared_ptr<MemAllocator> mem_allocator = nullptr;
57 : };
58 :
59 : /**
60 : * @class Context contains user-dependent configuration for support
61 : * @brief support for app context
62 : */
63 :
64 : class Context {
65 : public:
66 : using PropsType = std::vector<std::string>;
67 :
68 : template <typename T> using PtrType = std::unique_ptr<T>;
69 :
70 : template <typename T>
71 : using FactoryType = std::function<PtrType<T>(const PropsType &)>;
72 :
73 : template <typename T>
74 : using PtrFactoryType = PtrType<T> (*)(const PropsType &);
75 :
76 : template <typename T>
77 : using StrIndexType = std::unordered_map<std::string, FactoryType<T>>;
78 :
79 : /** integer to string key */
80 : using IntIndexType = std::unordered_map<int, std::string>;
81 :
82 : /**
83 : * This type contains tuple of
84 : * 1) integer -> string index
85 : * 2) string -> factory index
86 : */
87 : template <typename T>
88 : using IndexType = std::tuple<StrIndexType<T>, IntIndexType>;
89 :
90 : template <typename... Ts> using FactoryMap = std::tuple<IndexType<Ts>...>;
91 :
92 : /**
93 : * @brief Default constructor
94 : */
95 57 : Context(std::shared_ptr<ContextData> data_ = nullptr) : data(data_) {}
96 :
97 : /**
98 : * @brief Destructor
99 : */
100 57 : virtual ~Context() = default;
101 :
102 : /**
103 : *
104 : * @brief Initialization of Context.
105 : *
106 : * @return status &
107 : */
108 0 : virtual int init() { return 0; };
109 :
110 : /**
111 : * @brief Create an Layer Object from the type (string)
112 : *
113 : * @param type type of layer
114 : * @param props property
115 : * @return PtrType<nntrainer::Layer> unique pointer to the object
116 : */
117 : virtual PtrType<nntrainer::Layer>
118 0 : createLayerObject(const std::string &type,
119 : const std::vector<std::string> &props = {}) {
120 0 : ml_logw(
121 : "[Warning] Implement createLayerObject for the concrete context class to "
122 : "properly create the layer");
123 0 : return nullptr;
124 : };
125 :
126 : /**
127 : * @brief Create an Layer Object from the integer key
128 : *
129 : * @param int_key integer key
130 : * @param props property
131 : * @return PtrType<nntrainer::Layer> unique pointer to the object
132 : */
133 : virtual PtrType<nntrainer::Layer>
134 0 : createLayerObject(const int int_key,
135 : const std::vector<std::string> &props = {}) {
136 0 : ml_logw(
137 : "[Warning] Implement createLayerObject for the concrete context class to "
138 : "properly create the layer");
139 0 : return nullptr;
140 : };
141 :
142 : /**
143 : * @brief Create an Optimizer Object from the type (string)
144 : *
145 : * @param type type of optimizer
146 : * @param props property
147 : * @return PtrType<nntrainer::Optimizer> unique pointer to the object
148 : */
149 : virtual PtrType<nntrainer::Optimizer>
150 0 : createOptimizerObject(const std::string &type,
151 : const std::vector<std::string> &props = {}) {
152 0 : return nullptr;
153 : };
154 :
155 : /**
156 : * @brief Create an Layer Object from the integer key
157 : *
158 : * @param int_key integer key
159 : * @param props property
160 : * @return PtrType<nntrainer::Optimizer> unique pointer to the object
161 : */
162 : virtual PtrType<nntrainer::Optimizer>
163 0 : createOptimizerObject(const int int_key,
164 : const std::vector<std::string> &properties = {}) {
165 0 : return nullptr;
166 : };
167 :
168 : /**
169 : * @brief Create an LearningRateScheduler Object from the type (stirng)
170 : *
171 : * @param type type of optimizer
172 : * @param props property
173 : * @return PtrType<ml::train::LearningRateScheduler> unique pointer to the
174 : * object
175 : */
176 : virtual PtrType<ml::train::LearningRateScheduler>
177 0 : createLearningRateSchedulerObject(
178 : const std::string &type, const std::vector<std::string> &propeties = {}) {
179 0 : return nullptr;
180 : }
181 :
182 : /**
183 : * @brief Create an LearningRateScheduler Object from the integer key
184 : *
185 : * @param int_key integer key
186 : * @param props property
187 : * @return PtrType<ml::train::LearningRateScheduler> unique pointer to the
188 : * object
189 : */
190 : virtual std::unique_ptr<ml::train::LearningRateScheduler>
191 0 : createLearningRateSchedulerObject(
192 : const int int_key, const std::vector<std::string> &propeties = {}) {
193 0 : return nullptr;
194 : }
195 :
196 : /**
197 : * @brief getter of context name
198 : *
199 : * @return string name of the context
200 : */
201 : virtual std::string getName() = 0;
202 :
203 : std::shared_ptr<ContextData> getContextData() { return data; }
204 :
205 27 : std::shared_ptr<MemAllocator> getMemAllocator() {
206 27 : return getContextData()->getMemAllocator();
207 : };
208 :
209 : /**
210 : * @brief load weight and graph for the specific context
211 : *
212 : * @return return 0 for success
213 : */
214 0 : virtual int load(const std::string &file_path) { return 0; };
215 :
216 : private:
217 : /**
218 : * @brief map of context
219 : */
220 : static inline std::unordered_map<std::string, Context *> ContextMap;
221 :
222 : std::shared_ptr<ContextData> data = nullptr;
223 : };
224 :
225 : using CreateContextFunc = nntrainer::Context *(*)();
226 : using DestroyContextFunc = void (*)(nntrainer::Context *);
227 :
228 : /**
229 : * @brief Context Pluggable struct that enables pluggable layer
230 : *
231 : */
232 : typedef struct {
233 : CreateContextFunc createfunc; /**< create layer function */
234 : DestroyContextFunc destroyfunc; /**< destory function */
235 : } ContextPluggable;
236 :
237 : /**
238 : * @brief pluggable Context must have this structure defined
239 : */
240 : extern "C" ContextPluggable ml_train_context_pluggable;
241 :
242 : } // namespace nntrainer
243 :
244 : #endif /* __CONTEXT_H__ */
|