Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kpaoor <pk.kapoor@samsung.com>
4 : *
5 : * @file plugged_optimizer.h
6 : * @date 1 June 2021
7 : * @brief This file contains a wrapper for a plugged optimizer, INTERNAL USE
8 : * ONLY
9 : * @see https://github.com/nnstreamer/nntrainer
10 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
11 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
12 : * @bug No known bugs except for NYI items
13 : *
14 : */
15 :
16 : #ifndef __PLUGGED_OPTIMIZER_H__
17 : #define __PLUGGED_OPTIMIZER_H__
18 :
19 : #include <nntrainer_error.h>
20 : #include <optimizer.h>
21 : #include <optimizer_devel.h>
22 :
23 : namespace nntrainer {
24 : namespace internal {
25 :
26 : /**
27 : * @brief Plugged optimizer class
28 : */
29 : class PluggedOptimizer : public nntrainer::Optimizer {
30 : public:
31 : /**
32 : * @brief Construct a new Plugged Optimizer object
33 : *
34 : * @param pluggable OptimizerPluggable structure from the symbol
35 : */
36 2 : PluggedOptimizer(const nntrainer::OptimizerPluggable *pluggable) :
37 2 : optimizer_devel(
38 2 : dynamic_cast<nntrainer::Optimizer *>(pluggable->createfunc())),
39 2 : destroy_func(pluggable->destroyfunc) {
40 2 : NNTR_THROW_IF(optimizer_devel == nullptr, std::invalid_argument)
41 : << "create_func_ for plugged optimizer failed";
42 2 : }
43 :
44 : /**
45 : * @brief Destroy the Plugged Optimizer object
46 : *
47 : */
48 2 : ~PluggedOptimizer() override { destroy_func(optimizer_devel); }
49 :
50 : /**
51 : * @brief Move Construct Plugged Optimizer object
52 : *
53 : * @param rhs optimizer to move
54 : */
55 : PluggedOptimizer(PluggedOptimizer &&rhs) noexcept = default;
56 :
57 : /**
58 : * @brief Move assign Plugged Optimizer Object
59 : *
60 : * @param rhs optimizer to move
61 : * @return PluggedOptimizer& *this
62 : */
63 : PluggedOptimizer &operator=(PluggedOptimizer &&rhs) = default;
64 :
65 : /**
66 : * @copydoc Optimizer::getDefaultLearningRate()
67 : *
68 : */
69 0 : double getDefaultLearningRate() const override {
70 0 : return optimizer_devel->getDefaultLearningRate();
71 : }
72 : /**
73 : * @brief apply gradient to weight
74 : * @param[in] context Optimizer context
75 : */
76 0 : void applyGradient(RunOptimizerContext &context) override {
77 0 : optimizer_devel->applyGradient(context);
78 0 : }
79 :
80 : /**
81 : * @brief set Optimizer Parameters
82 : * @param[in] values Optimizer Parameter list
83 : * @retval #ML_ERROR_NONE Successful.
84 : * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
85 : */
86 0 : void setProperty(const std::vector<std::string> &values) override {
87 0 : optimizer_devel->setProperty(values);
88 0 : }
89 :
90 : /**
91 : * @brief finalize optimizer.
92 : */
93 0 : void finalize() override { optimizer_devel->finalize(); }
94 :
95 : /**
96 : * @brief Read Training optimizer parameters from file
97 : * @param[in] file input stream file
98 : */
99 0 : void read(std::ifstream &file) override { optimizer_devel->read(file); }
100 :
101 : /**
102 : * @brief Save Training optimizer parameters from file
103 : * @param[in] file output stream file
104 : */
105 0 : void save(std::ofstream &file) override { optimizer_devel->save(file); }
106 :
107 : /**
108 : * @brief Get dimension of extra variables if the optimizer needs any.
109 : * @param dim Dimension of tensor to be added as a optimizer variable
110 : * @return Vector of dimensions
111 : */
112 : virtual std::vector<TensorDim>
113 0 : getOptimizerVariableDim(const TensorDim &dim) override {
114 0 : return optimizer_devel->getOptimizerVariableDim(dim);
115 : }
116 :
117 : /**
118 : * @brief get Optimizer Type
119 : * @retval Optimizer type
120 : */
121 2 : const std::string getType() const override {
122 2 : return optimizer_devel->getType();
123 : }
124 :
125 : private:
126 : nntrainer::Optimizer *optimizer_devel;
127 : nntrainer::DestroyOptimizerFunc destroy_func;
128 : };
129 :
130 : } // namespace internal
131 : } // namespace nntrainer
132 :
133 : #endif // __PLUGGED_OPTIMIZER_H__
|