Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 seongwoo <mhs4670go@naver.com>
4 : *
5 : * @file tflite_export_realizer.h
6 : * @date 18 July 2025
7 : * @brief NNTrainer graph realizer which removes loss layers for inference
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author seongwoo <mhs4670go@naver.com>
10 : * @author donghak park <donghak.park@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : */
13 : #ifndef __TFLITE_EXPORT_REALIZER_H__
14 : #define __TFLITE_EXPORT_REALIZER_H__
15 :
16 : #include <vector>
17 :
18 : #include <realizer.h>
19 :
20 : namespace nntrainer {
21 :
22 : /**
23 : * @brief Graph realizer for TFLite export which removes loss (and, via realize_dropout, dropout) layers from the graph
24 : * @note This assumes the number of input / output connections of the loss layer == 1
25 : *
26 : */
27 : class TfliteExportRealizer final : public GraphRealizer {
28 : public:
29 : /**
30 : * @brief Construct a new TfliteExportRealizer object
31 : *
32 : */
33 5 : TfliteExportRealizer() = default;
34 :
35 : /**
36 : * @brief Destroy the TfliteExportRealizer object
37 : *
38 : */
39 6 : ~TfliteExportRealizer() = default;
40 :
41 : /**
42 : * @brief graph realizer creates a shallow copied graph based on the reference
43 : * @note this realizer removes loss layers from GraphRepresentation
44 : * @param reference GraphRepresentation to be realized
45 : * @return realized GraphRepresentation
46 : */
47 : GraphRepresentation realize(const GraphRepresentation &reference) override;
48 :
49 : /**
50 : * @brief graph realizer creates a shallow copied graph based on the reference
51 : * @note this realizer pass removes dropout layers from GraphRepresentation
52 : * @param reference GraphRepresentation to be realized
53 : * @return realized GraphRepresentation
54 : */
55 : GraphRepresentation realize_dropout(const GraphRepresentation &reference);
56 : };
57 :
58 : } // namespace nntrainer
59 :
60 : #endif // __TFLITE_EXPORT_REALIZER_H__
|