Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 SeungBaek Hong <sb92.hong@samsung.com>
4 : *
5 : * @file gather_layer.cpp
6 : * @date 02 April 2025
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author SeungBaek Hong <sb92.hong@samsung.com>
9 : * @bug The operation is not implemented yet; this is only a draft so that the layer compiles.
10 : * @brief This is the gather layer class (an operation layer).
11 : */
12 :
13 : #include "common_properties.h"
14 : #include <gather_layer.h>
15 : #include <nntrainer_error.h>
16 : #include <nntrainer_log.h>
17 : #include <node_exporter.h>
18 : #include <stdexcept>
19 : #include <util_func.h>
20 :
21 : #include <layer_context.h>
22 :
23 : namespace nntrainer {
24 :
25 0 : void GatherLayer::finalize(InitLayerContext &context) {
26 0 : axis = std::get<props::Axis>(gather_props).get();
27 0 : TensorDim inputDim = context.getInputDimensions()[0];
28 0 : TensorDim indexDim = context.getInputDimensions()[1];
29 :
30 0 : if (axis < 1 || axis > 3) {
31 : throw std::invalid_argument(
32 0 : "The axis property of GatherLayer should be between 1 and 3.");
33 : }
34 :
35 0 : if (inputDim[0] != indexDim[0]) {
36 : throw std::invalid_argument(
37 0 : "The batch size of the input and index should be same.");
38 : }
39 :
40 0 : TensorDim outputDim = TensorDim(inputDim);
41 0 : outputDim.setTensorDim(axis, indexDim[axis]);
42 0 : context.setOutputDimensions({outputDim});
43 0 : }
44 :
45 0 : void GatherLayer::forwarding_operation(const Tensor &input, const Tensor &index,
46 : Tensor &output) {
47 : // TODO: implement forwarding operation
48 0 : throw std::runtime_error("forwarding operation is not implemented yet");
49 : }
50 :
51 0 : void GatherLayer::calcDerivative(RunLayerContext &context) {
52 : // TODO: implement derivative calculation
53 0 : throw std::runtime_error("derivative calculation is not implemented yet");
54 : }
55 :
56 0 : void GatherLayer::setProperty(const std::vector<std::string> &values) {
57 0 : auto remain_props = loadProperties(values, gather_props);
58 0 : if (!remain_props.empty()) {
59 : std::string msg = "[GatherLayer] Unknown Layer Properties count " +
60 0 : std::to_string(values.size());
61 0 : throw exception::not_supported(msg);
62 : }
63 0 : }
64 :
65 : } /* namespace nntrainer */
|