Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file loss_layer.h
6 : * @date 12 June 2020
7 : * @brief This is Loss Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __LOSS_LAYER_H__
15 : #define __LOSS_LAYER_H__
16 : #ifdef __cplusplus
17 :
18 : #include <layer_devel.h>
19 :
20 : #include <tensor.h>
21 :
22 : namespace nntrainer {
23 :
24 : /**
25 : * @class LossLayer
26 : * @brief loss layer
27 : */
28 : class LossLayer : public Layer {
29 : public:
30 : /**
31 : * @brief Destructor of Loss Layer
32 : */
33 760 : virtual ~LossLayer() = default;
34 :
35 : /**
36 : * @copydoc Layer::finalize(InitLayerContext &context)
37 : */
38 : virtual void finalize(InitLayerContext &context) override;
39 :
40 : /**
41 : * @copydoc Layer::setProperty(const std::vector<std::string> &values)
42 : */
43 : virtual void setProperty(const std::vector<std::string> &values) override;
44 :
45 : /**
46 : * @copydoc Layer::supportBackwarding()
47 : */
48 1502 : virtual bool supportBackwarding() const override { return true; }
49 :
50 : /**
51 : * @copydoc Layer::requireLabel()
52 : */
53 15073 : bool requireLabel() const override { return true; }
54 :
55 : protected:
56 : /**
57 : * @brief update loss
58 : * @param context Run context to update loss in
59 : * @param l Tensor data to calculate
60 : */
61 : void updateLoss(RunLayerContext &context, const Tensor &l);
62 :
63 : /**
64 : * @brief update return derivative with loss scale
65 : * @param context Run context to update
66 : * @param return_dev Tensor data to calculate
67 : */
68 : void applyLossScale(RunLayerContext &context, Tensor &l);
69 :
70 : Tensor
71 : l; /**< loss tensor to store intermediate value to calculate loss value */
72 : };
73 :
74 : } // namespace nntrainer
75 :
76 : #endif /* __cplusplus */
77 : #endif /* __LOSS_LAYER_H__ */
|