LCOV - code coverage report
Current view: top level - nntrainer/layers - centroid_knn.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 89.1 % 55 49
Test Date: 2025-12-14 20:38:17 Functions: 77.8 % 9 7

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file   centroid_knn.cpp
       6              :  * @date   09 Jan 2021
       7              :  * @brief  This file contains the simple nearest neighbor layer
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  * @details This layer takes centroid and calculate l2 distance
      13              :  */
      14              : 
      15              : #include <iostream>
      16              : #include <limits>
      17              : #include <regex>
      18              : #include <sstream>
      19              : 
      20              : #include <centroid_knn.h>
      21              : #include <layer_context.h>
      22              : #include <nntrainer_error.h>
      23              : #include <nntrainer_log.h>
      24              : #include <node_exporter.h>
      25              : #include <tensor.h>
      26              : #include <weight.h>
      27              : 
      28              : namespace nntrainer {
      29              : 
      30              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      31              : 
      32              : enum KNNParams { map, num_samples };
      33              : 
      34            6 : CentroidKNN::CentroidKNN() : Layer(), centroid_knn_props(props::NumClass()) {
      35              :   weight_idx.fill(std::numeric_limits<unsigned>::max());
      36            6 : }
      37              : 
      38           12 : CentroidKNN::~CentroidKNN() {}
      39              : 
      40           15 : void CentroidKNN::setProperty(const std::vector<std::string> &values) {
      41           15 :   auto left = loadProperties(values, centroid_knn_props);
      42           17 :   NNTR_THROW_IF(!left.empty(), std::invalid_argument)
      43              :     << "[Centroid KNN] there are unparsed properties " << left.front();
      44           15 : }
      45              : 
      46            3 : void CentroidKNN::finalize(nntrainer::InitLayerContext &context) {
      47              :   auto const &input_dim = context.getInputDimensions()[0];
      48            3 :   if (input_dim.channel() != 1 || input_dim.height() != 1) {
      49            0 :     ml_logw("centroid nearest layer is designed for flattend feature for now, "
      50              :             "please check");
      51              :   }
      52              : 
      53              :   auto num_class = std::get<props::NumClass>(centroid_knn_props);
      54              : 
      55            3 :   auto output_dim = nntrainer::TensorDim({num_class});
      56            3 :   context.setOutputDimensions({output_dim});
      57              : 
      58              :   /// weight is a distance map that contains centroid of features of each class
      59            3 :   auto map_dim = nntrainer::TensorDim({num_class, input_dim.getFeatureLen()});
      60              : 
      61              :   /// samples seen for the current run to calculate the centroid
      62            3 :   auto samples_seen = nntrainer::TensorDim({num_class});
      63              : 
      64            3 :   weight_idx[KNNParams::map] = context.requestWeight(
      65              :     map_dim, nntrainer::Initializer::ZEROS, nntrainer::WeightRegularizer::NONE,
      66              :     1.0f, 0.0f, "map", false);
      67              : 
      68            3 :   weight_idx[KNNParams::num_samples] = context.requestWeight(
      69              :     samples_seen, nntrainer::Initializer::ZEROS,
      70              :     nntrainer::WeightRegularizer::NONE, 1.0f, 0.0f, "num_samples", false);
      71            3 : }
      72              : 
      73          175 : void CentroidKNN::forwarding(nntrainer::RunLayerContext &context,
      74              :                              bool training) {
      75          175 :   auto &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
      76          175 :   auto &input_ = context.getInput(SINGLE_INOUT_IDX);
      77          175 :   const auto &input_dim = input_.getDim();
      78              : 
      79          175 :   auto &map = context.getWeight(weight_idx[KNNParams::map]);
      80          175 :   auto &num_samples = context.getWeight(weight_idx[KNNParams::num_samples]);
      81          175 :   auto feature_len = input_dim.getFeatureLen();
      82              : 
      83          859 :   auto get_distance = [](const nntrainer::Tensor &a,
      84              :                          const nntrainer::Tensor &b) {
      85          859 :     return -a.subtract(b).l2norm();
      86              :   };
      87              : 
      88          175 :   if (training) {
      89          100 :     auto &label = context.getLabel(SINGLE_INOUT_IDX);
      90          100 :     auto ans = label.argmax();
      91              : 
      92          200 :     for (unsigned int b = 0; b < input_.batch(); ++b) {
      93              :       auto saved_feature =
      94          100 :         map.getSharedDataTensor({feature_len}, ans[b] * feature_len);
      95              : 
      96              :       //  nntrainer::Tensor::Map(map.getData(), {feature_len},
      97              :       // ans[b] * feature_len);
      98          100 :       float num_sample = num_samples.getValue<float>(0, 0, 0, ans[b]);
      99          100 :       auto current_feature = input_.getBatchSlice(b, 1);
     100          100 :       saved_feature.multiply_i(num_sample);
     101          100 :       saved_feature.add_i(current_feature);
     102          100 :       saved_feature.divide_i(num_sample + 1);
     103          100 :       num_samples.setValue(0, 0, 0, ans[b], num_sample + 1);
     104          100 :     }
     105          100 :   }
     106              : 
     107         1050 :   for (unsigned int i = 0; i < std::get<props::NumClass>(centroid_knn_props);
     108              :        ++i) {
     109              :     auto saved_feature =
     110          875 :       map.getSharedDataTensor({feature_len}, i * feature_len);
     111              :     // nntrainer::Tensor::Map(map.getData(), {feature_len}, i * feature_len);
     112              : 
     113          875 :     auto num_sample = num_samples.getValue(0, 0, 0, i);
     114              : 
     115         1750 :     for (unsigned int b = 0; b < input_.batch(); ++b) {
     116          875 :       auto current_feature = input_.getBatchSlice(b, 1);
     117              : 
     118          875 :       if (num_sample == 0) {
     119           16 :         hidden_.setValue(b, 0, 0, i, std::numeric_limits<float>::min());
     120              :       } else {
     121          859 :         hidden_.setValue(b, 0, 0, i,
     122              :                          get_distance(current_feature, saved_feature));
     123              :       }
     124          875 :     }
     125          875 :   }
     126          175 : }
     127              : 
     128            0 : void CentroidKNN::calcDerivative(nntrainer::RunLayerContext &context) {
     129              :   throw std::invalid_argument("[CentroidKNN::calcDerivative] This Layer "
     130            0 :                               "does not support backward propagation");
     131              : }
     132              : 
     133            0 : void CentroidKNN::exportTo(nntrainer::Exporter &exporter,
     134              :                            const ml::train::ExportMethods &method) const {
     135            0 :   exporter.saveResult(centroid_knn_props, method, this);
     136            0 : }
     137              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1