Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file optimizer_devel.h
6 : * @date 08 April 2020
7 : * @brief This is Optimizer internal interface class
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : *
13 : */
14 :
15 : #ifndef __OPTIMIZER_DEVEL_H__
16 : #define __OPTIMIZER_DEVEL_H__
17 : #ifdef __cplusplus
18 :
19 : #include <memory>
20 :
21 : #include <optimizer.h>
22 : #include <optimizer_context.h>
23 : #include <tensor.h>
24 :
25 : namespace nntrainer {
26 :
27 : class Exporter;
28 :
29 : /**
30 : * @class Optimizer Base class for optimizers
31 : * @brief Base class for all optimizers
32 : */
33 : class Optimizer {
34 :
35 : public:
36 : /**
37 : * @brief Destructor of Optimizer Class
38 : */
39 : virtual ~Optimizer() = default;
40 :
41 : /**
42 : * @brief get Learning Rate
43 : * @retval Learning rate in float
44 : */
45 : virtual double getDefaultLearningRate() const = 0;
46 :
47 : /**
48 : * @brief apply gradient to weight
49 : * @param[in] context Optimizer context
50 : */
51 : virtual void applyGradient(RunOptimizerContext &context) = 0;
52 :
53 : /**
54 : * @brief set Optimizer Parameters
55 : * @param[in] values Optimizer Parameter list
56 : */
57 : virtual void setProperty(const std::vector<std::string> &values);
58 :
59 : /**
60 : * @brief this function helps exporting the optimizer in a predefined format,
61 : * while workarounding issue caused by templated function type eraser
62 : *
63 : * @param exporter exporter that contains exporting logic
64 : * @param method enum value to identify how it should be exported to
65 : */
66 235 : virtual void exportTo(Exporter &exporter,
67 235 : const ml::train::ExportMethods &method) const {}
68 :
69 : /**
70 : * @brief finalize optimizer.
71 : */
72 616 : virtual void finalize(){};
73 :
74 : /**
75 : * @brief Read Training optimizer parameters from file
76 : * @param[in] file input stream file
77 : */
78 : virtual void read(std::ifstream &file);
79 :
80 : /**
81 : * @brief Save Training optimizer parameters from file
82 : * @param[in] file output stream file
83 : */
84 : virtual void save(std::ofstream &file);
85 :
86 : /**
87 : * @brief Get dimension of extra variables if the optimizer needs any.
88 : * @param dim Dimension of tensor to be added as a optimizer variable
89 : * @return Vector of dimensions
90 : */
91 : virtual std::vector<TensorDim>
92 : getOptimizerVariableDim(const TensorDim &dim) = 0;
93 :
94 : /**
95 : * @brief get Optimizer Type
96 : * @retval Optimizer type
97 : */
98 : virtual const std::string getType() const = 0;
99 : };
100 :
101 : using CreateOptimizerFunc = nntrainer::Optimizer *(*)();
102 : using DestroyOptimizerFunc = void (*)(nntrainer::Optimizer *);
103 :
104 : /**
105 : * @brief General Optimizer Factory function to register Optimizer
106 : *
107 : * @param props property representation
108 : * @return std::unique_ptr<nntrainer::Optimizer> created object
109 : */
110 : template <typename T,
111 : std::enable_if_t<std::is_base_of<Optimizer, T>::value, T> * = nullptr>
112 : std::unique_ptr<Optimizer>
113 778 : createOptimizer(const std::vector<std::string> &props = {}) {
114 229 : std::unique_ptr<Optimizer> ptr = std::make_unique<T>();
115 778 : ptr->setProperty(props);
116 770 : return ptr;
117 : }
118 :
119 : /**
120 : * @brief Optimizer Pluggable struct that enables pluggable layer
121 : *
122 : */
123 : typedef struct {
124 : CreateOptimizerFunc createfunc; /**< create function */
125 : DestroyOptimizerFunc destroyfunc; /**< destroy function */
126 : } OptimizerPluggable;
127 :
128 : /**
129 : * @brief pluggable optimizer must have this structure defined
130 : */
131 : extern "C" OptimizerPluggable ml_train_optimizer_pluggable;
132 :
133 : } /* namespace nntrainer */
134 :
135 : #endif /* __cplusplus */
136 : #endif /* __OPTIMIZER_DEVEL_H__ */
|