LCOV - code coverage report
Current view: top level - nntrainer/layers - bn_layer.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 65.3 % 225 147
Test Date: 2025-12-14 20:38:17 Functions: 72.7 % 11 8

            Line data    Source code
       1              : /**
       2              :  * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved.
       3              :  *
       4              :  * Licensed under the Apache License, Version 2.0 (the "License");
       5              :  * you may not use this file except in compliance with the License.
       6              :  * You may obtain a copy of the License at
       7              :  *   http://www.apache.org/licenses/LICENSE-2.0
       8              :  * Unless required by applicable law or agreed to in writing, software
       9              :  * distributed under the License is distributed on an "AS IS" BASIS,
      10              :  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      11              :  * See the License for the specific language governing permissions and
      12              :  * limitations under the License.
      13              :  *
      14              :  *
      15              :  * @file        bn_layer.cpp
      16              :  * @date        14 May 2020
      17              :  * @brief       This is Batch Normalization Layer Class for Neural Network
      18              :  * @see         https://github.com/nnstreamer/nntrainer
      19              :  * @author      Jijoong Moon <jijoong.moon@samsung.com>
      20              :  * @bug         No known bugs except for NYI items
      21              :  *
      22              :  */
      23              : 
      24              : #include <bn_layer.h>
      25              : #include <layer_context.h>
      26              : #include <lazy_tensor.h>
      27              : #include <nntrainer_error.h>
      28              : #include <nntrainer_log.h>
      29              : #include <node_exporter.h>
      30              : #include <util_func.h>
      31              : 
      32              : namespace nntrainer {
      33              : 
      34              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      35              : 
/**
 * @brief Indices into wt_idx for the weights/tensors this layer requests.
 * @note The enumerator order defines the wt_idx slot assignment made in
 *       finalize(); do not reorder.
 */
enum BNParams {
  mu,        // moving mean ("moving_mean" weight, non-trainable)
  var,       // moving variance ("moving_variance" weight, non-trainable)
  gamma,     // scale weight ("gamma", trainable)
  beta,      // shift weight ("beta", trainable)
  mu_b,      // backup of moving mean, restored on re-run (see reStoreData)
  var_b,     // backup of moving variance, restored on re-run
  deviation, // cached input - avg(input)
  invstd,    // cached inverse standard deviation, (var + eps)^-0.5
  cvar,      // cached batch variance (+ epsilon after forwarding)
  t_reduced, // temporary tensor reduced along axes_to_reduce
  t_full     // temporary full-sized tensor (enables in-place execution)
};
      49              : 
/**
 * @brief Default-construct the layer with empty properties.
 * @note wt_idx is filled with an invalid sentinel so that using a slot
 *       before finalize() assigns it is detectable.
 */
BatchNormalizationLayer::BatchNormalizationLayer() :
  Layer(),
  divider(0),
  bn_props(props::Epsilon(), props::MuInitializer(), props::VarInitializer(),
           props::BetaInitializer(), props::GammaInitializer(),
           props::Momentum(), props::Axis(), props::WeightDecay(),
           props::BiasDecay()) {
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}
      59              : 
      60              : /// @todo add multiple axis support
      61           27 : void BatchNormalizationLayer::finalize(InitLayerContext &context) {
      62           27 :   NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
      63              :     << "Only one input is allowed for batch normalization layer";
      64              : 
      65              :   auto &bnparams_mu = std::get<props::MuInitializer>(bn_props);
      66              :   auto &bnparams_var = std::get<props::VarInitializer>(bn_props);
      67              :   auto &bnparams_beta = std::get<props::BetaInitializer>(bn_props);
      68              :   auto &bnparams_gamma = std::get<props::GammaInitializer>(bn_props);
      69              :   auto &weight_decay = std::get<props::WeightDecay>(bn_props);
      70              :   auto &bias_decay = std::get<props::BiasDecay>(bn_props);
      71              : 
      72              :   /** set output dimensions */
      73              :   auto const &in_dim = context.getInputDimensions()[0];
      74           27 :   context.setOutputDimensions(context.getInputDimensions());
      75              : 
      76           27 :   TensorDim dim(context.getFormat(), context.getWeightDataType());
      77              : 
      78           27 :   if (context.getExecutionMode() == ml::train::ExecutionMode::TRAIN) {
      79              :     dim.setDataType(TensorDim::DataType::FP32);
      80              :   }
      81              : 
      82              :   /// @note this logic cannot tell channel is actually 1 or it is just not used.
      83              :   auto &axis_prop = std::get<props::Axis>(bn_props);
      84              :   unsigned int axis;
      85           27 :   if (axis_prop.empty())
      86           27 :     axis = in_dim.channel() > 1 ? 1 : 3;
      87              :   else
      88            0 :     axis = axis_prop.get();
      89              : 
      90              :   /**
      91              :    * @todo This can be speedup by employing transpose for convolution. With
      92              :    * transpose, the channel dimension can be made last, and the remaining
      93              :    * dimensions can be squeezed. This would allow the sum and average to be
      94              :    * faster, and no temporary allocations inside them.
      95              :    */
      96              : 
      97           27 :   dim.setTensorDim(axis, in_dim.getTensorDim(axis));
      98              : 
      99           27 :   divider = 1;
     100          135 :   for (unsigned int i = 0; i < 4; ++i) {
     101          108 :     if (axis != i) {
     102           81 :       axes_to_reduce.push_back(i);
     103           81 :       divider *= in_dim.getTensorDim(i);
     104              :     }
     105              :   }
     106              : 
     107           27 :   wt_idx[BNParams::mu] =
     108           27 :     context.requestWeight(dim, dim, bnparams_mu, WeightRegularizer::NONE, 1.0f,
     109              :                           0.0f, "moving_mean", false);
     110           27 :   wt_idx[BNParams::var] =
     111           27 :     context.requestWeight(dim, dim, bnparams_var, WeightRegularizer::NONE, 1.0f,
     112              :                           0.0f, "moving_variance", false);
     113           27 :   wt_idx[BNParams::gamma] =
     114           27 :     context.requestWeight(dim, dim, bnparams_gamma, WeightRegularizer::NONE,
     115              :                           1.0f, weight_decay, "gamma", true);
     116           27 :   wt_idx[BNParams::beta] =
     117           27 :     context.requestWeight(dim, dim, bnparams_beta, WeightRegularizer::NONE,
     118              :                           1.0f, bias_decay, "beta", true);
     119              : 
     120           27 :   wt_idx[BNParams::mu_b] =
     121           27 :     context.requestTensor(dim, "moviing_mean_backup", Initializer::NONE, false,
     122              :                           TensorLifespan::ITERATION_LIFESPAN);
     123              : 
     124           27 :   wt_idx[BNParams::var_b] =
     125           27 :     context.requestTensor(dim, "moviing_variance_backup", Initializer::NONE,
     126              :                           false, TensorLifespan::ITERATION_LIFESPAN);
     127              : 
     128              :   /**
     129              :    * caches the deviation -> input - avg(input)
     130              :    * @todo check if avoiding this storage and adding dependency on input (no
     131              :    * more in-place calculation) can save memory during memory optimization.
     132              :    */
     133           27 :   TensorDim in_dim_ = in_dim;
     134              : 
     135           27 :   if (context.getExecutionMode() == ml::train::ExecutionMode::TRAIN) {
     136              :     in_dim_.setDataType(TensorDim::DataType::FP32);
     137              :   }
     138              : 
     139           27 :   wt_idx[BNParams::deviation] =
     140           27 :     context.requestTensor(in_dim_, "deviation", Initializer::NONE, false,
     141              :                           TensorLifespan::ITERATION_LIFESPAN);
     142              :   /** caches the inverse standard deviation */
     143           27 :   wt_idx[BNParams::invstd] =
     144           27 :     context.requestTensor(dim, "invstd", Initializer::NONE, false,
     145              :                           TensorLifespan::ITERATION_LIFESPAN);
     146              :   /**
     147              :    * Temporary tensor to store the full sized tensors in order to allow batch
     148              :    * norm to execute in-place. Running in-place leads to same memory footprint
     149              :    * for this layer in its backwarding, but reduces the peak memory requirement
     150              :    * as the output of this layer need not be stored all the time.
     151              :    */
     152           27 :   wt_idx[BNParams::t_full] =
     153           27 :     context.requestTensor(in_dim_, "tensor_full", Initializer::NONE, false,
     154              :                           TensorLifespan::CALC_DERIV_LIFESPAN);
     155              :   /**
     156              :    * caches variance + epsilon as well.
     157              :    */
     158           27 :   wt_idx[BNParams::cvar] = context.requestTensor(
     159              :     dim, "cvar", Initializer::NONE, false, TensorLifespan::ITERATION_LIFESPAN);
     160              :   /**
     161              :    * Temporary tensor to store the reduced tensors along the axes_to_reduce.
     162              :    */
     163           27 :   wt_idx[BNParams::t_reduced] =
     164           27 :     context.requestTensor(dim, "tensor_reduced", Initializer::NONE, false,
     165              :                           TensorLifespan::FORWARD_DERIV_LIFESPAN);
     166           27 : }
     167              : 
     168          199 : void BatchNormalizationLayer::setProperty(
     169              :   const std::vector<std::string> &values) {
     170          199 :   auto remain_props = loadProperties(values, bn_props);
     171          197 :   NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
     172            2 :     << "[BNLayer] Unknown Layer Properties count " +
     173            4 :          std::to_string(values.size());
     174          197 : }
     175              : 
/**
 * @brief Forward pass of batch normalization.
 *
 * Training: computes batch mean/variance, updates the moving statistics with
 * `momentum`, and normalizes with the batch statistics. Inference: normalizes
 * with the stored moving mean/variance. Output is
 * gamma * (x - mean) * invstd + beta, written into the output tensor.
 *
 * @param context run context providing weights, tensors, and input/output
 * @param training true for the training path (statistics update)
 */
void BatchNormalizationLayer::forwarding(RunLayerContext &context,
                                         bool training) {
  float epsilon = std::get<props::Epsilon>(bn_props);
  float momentum = std::get<props::Momentum>(bn_props);

  Tensor &mu = context.getWeight(wt_idx[BNParams::mu]);
  Tensor &var = context.getWeight(wt_idx[BNParams::var]);
  Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);
  Tensor &beta = context.getWeight(wt_idx[BNParams::beta]);

  Tensor em_input, em_hidden;

  Tensor &input_ = em_input;
  Tensor &hidden_ = em_hidden;

  // During training, promote non-FP32 input/output to FP32 clones so the
  // statistics are computed in full precision; otherwise use them as-is.
  if (training) {
    if (context.getInput(SINGLE_INOUT_IDX).getDataType() !=
        TensorDim::DataType::FP32) {
      input_ =
        context.getInput(SINGLE_INOUT_IDX).clone(TensorDim::DataType::FP32);
    } else {
      input_ = context.getInput(SINGLE_INOUT_IDX);
    }

    if (context.getOutput(SINGLE_INOUT_IDX).getDataType() !=
        TensorDim::DataType::FP32) {
      hidden_ =
        context.getOutput(SINGLE_INOUT_IDX).clone(TensorDim::DataType::FP32);
    } else {
      hidden_ = context.getOutput(SINGLE_INOUT_IDX);
    }
  } else {
    input_ = context.getInput(SINGLE_INOUT_IDX);
    hidden_ = context.getOutput(SINGLE_INOUT_IDX);
  }

  Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
  Tensor &invstd = context.getTensor(wt_idx[BNParams::invstd]);

  /** @todo these are not needed for inference, support optimizing these */
  Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
  /** use hidden_ as temporary tensor before setting the result in hidden */
  Tensor t_full = hidden_;
  Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);

  if (training) {

    Tensor &mu_b = context.getTensor(wt_idx[BNParams::mu_b]);
    Tensor &var_b = context.getTensor(wt_idx[BNParams::var_b]);

    // On a re-run of the same iteration, restore the moving statistics from
    // the backups and clear the cached intermediates; on a fresh run, back
    // the current moving statistics up first.
    if (context.reStoreData()) {
      mu.copyData(mu_b);
      var.copyData(var_b);
      deviation.setZero();
      invstd.setZero();
      t_reduced.setZero();
      cvar.setZero();
    } else {
      mu_b.copyData(mu);
      var_b.copyData(var);
    }

    // batch mean -> t_reduced; deviation = input - mean
    input_.average(axes_to_reduce, t_reduced);
    input_.subtract(t_reduced, deviation);

    // moving mean = momentum * mu + (1 - momentum) * batch mean
    mu.multiply_i(momentum);
    mu.add_i(t_reduced, 1 - momentum);

    // batch variance -> cvar (mean of squared deviations)
    deviation.pow(2.0f, t_full);
    t_full.average(axes_to_reduce, cvar);

    // moving variance = momentum * var + (1 - momentum) * batch variance
    var.multiply_i(momentum);
    var.add_i(cvar, 1 - momentum);

    // cvar becomes variance + epsilon; invstd = (var + eps)^-0.5
    cvar.add_i(epsilon);
    cvar.pow(-0.5f, invstd);
  } else {
    input_.subtract(mu, deviation);
    /** @todo do below 2 lines only for first iteration */
    var.add(epsilon, invstd);
    invstd.pow_i(-0.5f);
  }

  // hidden = gamma * (deviation * invstd) + beta
  deviation.multiply(invstd, hidden_);
  hidden_.multiply_i(gamma);
  hidden_.add_i(beta);

  // If hidden_ was an FP32 clone, cast the result back to the real output.
  if (training && hidden_.getDataType() !=
                    context.getOutput(SINGLE_INOUT_IDX).getDataType())
    context.getOutput(SINGLE_INOUT_IDX).copyData(hidden_);
}
     267              : 
/**
 * @brief Compute the derivative w.r.t. the input (dx) and gamma's gradient.
 *
 * Uses the intermediates cached in forwarding (deviation, invstd, cvar).
 * dgamma is produced here; dbeta is produced in calcGradient and is reused
 * here (dbeta / divider == mean of the incoming derivative).
 * @note deviation, invstd and t_reduced are mutated in place; depends on
 *       calcGradient having filled dbeta for the trainable path.
 *
 * @param context run context providing weights, tensors, and derivatives
 */
void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

  Tensor &gamma = context.getWeight(wt_idx[BNParams::gamma]);

  Tensor em_dx, deriv32;
  bool deriv_copyed = false;

  const Tensor deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);

  // Promote a non-FP32 incoming derivative to an FP32 copy so the math below
  // runs in full precision.
  if (deriv.getDataType() != TensorDim::DataType::FP32) {
    deriv_copyed = true;
    TensorDim dim = deriv.getDim();
    dim.setDataType(TensorDim::DataType::FP32);
    deriv32 = Tensor(dim, true);
    deriv32.copyData(deriv);
  }

  // dx aliases the outgoing derivative when it is already FP32, otherwise an
  // FP32 clone (em_dx) that is cast back at the end.
  Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() ==
                   TensorDim::DataType::FP32
                 ? context.getOutgoingDerivative(SINGLE_INOUT_IDX)
                 : em_dx;

  if (dx.empty())
    dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX)
           .clone(TensorDim::DataType::FP32);

  Tensor &deviation = context.getTensor(wt_idx[BNParams::deviation]);
  Tensor &invstd = context.getTensor(wt_idx[BNParams::invstd]);
  Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);

  Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
  Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);

  t_full.setZero();

  // t_reduced = mean(deriv * deviation) / cvar;
  // deviation becomes deviation * that factor (the variance term of dx)
  deviation.multiply((deriv_copyed ? deriv32 : deriv), t_full);
  t_full.average(axes_to_reduce, t_reduced);
  t_reduced.divide_i(cvar);
  deviation.multiply_i(t_reduced);

  if (context.getTrainable()) {
    /**
     * This calculates dgamma tensor.
     */
    Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
    t_full.multiply_i(invstd);
    t_full.sum(axes_to_reduce, dgamma);

    /**
     * This implementation depends on the pre-calculated dbeta calculated.
     */
    // t_reduced is overwritten here: dbeta / divider == mean(deriv)
    Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
    dbeta.divide(divider, t_reduced);
  } else {
    (deriv_copyed ? deriv32 : deriv).average(axes_to_reduce, t_reduced);
  }

  // dx = gamma * invstd * (deriv - mean(deriv) - deviation-term)
  (deriv_copyed ? deriv32 : deriv).subtract(t_reduced, dx);
  dx.subtract_i(deviation);

  invstd.multiply_i(gamma);
  dx.multiply_i(invstd);

  // Cast back when dx was an FP32 clone of a lower-precision outgoing deriv.
  if (dx.getDataType() !=
      context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType())
    context.getOutgoingDerivative(SINGLE_INOUT_IDX).copyData(dx);
}
     335              : 
     336           35 : void BatchNormalizationLayer::calcGradient(RunLayerContext &context) {
     337              :   /** dgamma is calculated in calcDerivative. dbeta is calculated here */
     338           35 :   Tensor &dbeta = context.getWeightGrad(wt_idx[BNParams::beta]);
     339           35 :   dbeta.setZero();
     340              : 
     341           35 :   Tensor deriv32;
     342              :   bool deriv_copyed = false;
     343              : 
     344           35 :   const Tensor deriv = context.getIncomingDerivative(SINGLE_INOUT_IDX);
     345           35 :   if (deriv.getDataType() != TensorDim::DataType::FP32) {
     346              :     deriv_copyed = true;
     347            0 :     TensorDim dim = deriv.getDim();
     348              :     dim.setDataType(TensorDim::DataType::FP32);
     349            0 :     deriv32 = Tensor(dim, true);
     350            0 :     deriv32.copyData(deriv);
     351              :   }
     352              : 
     353           35 :   (deriv_copyed ? deriv32 : deriv).sum(axes_to_reduce, dbeta);
     354           35 : }
     355              : 
/**
 * @brief Export this layer's properties via the given exporter.
 * @param exporter exporter that records the result
 * @param method export method to apply
 */
void BatchNormalizationLayer::exportTo(
  Exporter &exporter, const ml::train::ExportMethods &method) const {
  exporter.saveResult(bn_props, method, this);
}
     360              : 
     361           19 : void BatchNormalizationLayer::setBatch(RunLayerContext &context,
     362              :                                        unsigned int batch) {
     363           19 :   context.updateTensor(wt_idx[BNParams::deviation], batch);
     364           19 :   context.updateTensor(wt_idx[BNParams::t_full], batch);
     365              : 
     366              :   /// reset divider
     367           19 :   divider = 1;
     368           19 :   auto input_dim = context.getInput(0).getDim();
     369           76 :   for (auto axis : axes_to_reduce) {
     370           57 :     if (axis == 0) {
     371              :       /// @note input dim batch is not updated, it will be more sensible we
     372              :       /// update batch before any node comes to this spot
     373           19 :       divider *= batch;
     374              :     }
     375           57 :     divider *= input_dim.getTensorDim(axis);
     376              :   }
     377           19 : }
     378              : 
/**
 * @brief Save weights (or their optimizer variables) to a file stream.
 *
 * With opt_var set, saves each trainable weight's optimizer variables at
 * their first gradient access. Otherwise saves the weights themselves; in
 * TRAIN mode with a non-FP32 defined weight type, the FP32 weight is first
 * copied into a tensor of the defined type before saving.
 *
 * @param file output stream to write into
 * @param run_context run context owning the weights
 * @param opt_var true to save optimizer variables instead of weights
 * @param mode current execution mode
 * @param trainable whether the layer is trainable (gates opt-var saving)
 * @param definedWeightDataType weight data type defined by the model
 */
void BatchNormalizationLayer::save(
  std::ofstream &file, RunLayerContext &run_context, bool opt_var,
  ml::train::ExecutionMode mode, bool trainable,
  TensorDim::DataType definedWeightDataType) const {
  if (opt_var) {
    for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
      if (run_context.isGradientFirstAccess(i) && trainable) {
        // @note save optimizer variables
        if (run_context.weightHasGradient(i)) {
          for (unsigned int j = 0; j < run_context.getNumWeightOptVar(i); ++j) {
            run_context.getWeightOptVar(i, j).save(file);
          }
        }
      }
    }
  } else {
    // @note shared weights are only be saved at the first access
    for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
      if (run_context.isGradientFirstAccess(i)) {

        // @note For batch normalization layer, we do need full precision for
        // training and the data type of weight is full precision. But for
        // inference, We do have to save them as activation data type.
        if ((mode == ml::train::ExecutionMode::TRAIN) &&
            (definedWeightDataType != TensorDim::DataType::FP32)) {
          TensorDim dim = run_context.getWeight(i).getDim();

          dim.setDataType(definedWeightDataType);

          Tensor T_save(dim, true);

          // cast FP32 training weight down to the defined type for saving
          T_save.copyData(run_context.getWeight(i));

          T_save.save(file);
        } else {
          run_context.getWeight(i).save(file);
        }
      }
    }
  }
}
     420              : 
/**
 * @brief Read weights (or their optimizer variables) from a file stream.
 *
 * With opt_var set, reads optimizer variables for trainable weights at their
 * last gradient access. Otherwise reads the weights; in TRAIN mode with a
 * non-FP32 defined weight type, reads into a tensor of the defined type and
 * copies it into the FP32 weight. Mixed-precision FP32 master copies are
 * refreshed after the read.
 *
 * @param file input stream to read from
 * @param run_context run context owning the weights
 * @param opt_var true to read optimizer variables instead of weights
 * @param mode current execution mode
 * @param trainable whether the layer is trainable
 * @param definedWeightDataType weight data type defined by the model
 * @param fsu unused here — TODO confirm (flash storage usage flag?)
 * @param start_offset file offset to start reading weights from
 * @param read_from_offset unused here — TODO confirm
 * @param file_fd unused here — TODO confirm
 */
void BatchNormalizationLayer::read(std::ifstream &file,
                                   RunLayerContext &run_context, bool opt_var,
                                   ml::train::ExecutionMode mode,
                                   bool trainable,
                                   TensorDim::DataType definedWeightDataType,
                                   bool fsu, size_t start_offset,
                                   bool read_from_offset, int file_fd) {
  if (opt_var) {
    for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
      if (run_context.isGradientLastAccess(i) && trainable) {
        /// @note read optimizer variables
        for (unsigned int j = 0; j < run_context.getNumWeightOptVar(i); ++j) {
          run_context.getWeightOptVar(i, j).read(file);
        }
      }
    }
  } else {
    for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
      /// @note shared weights are only be read at the first access
      //      if (run_context->isGradientLastAccess(i)) {
      if (run_context.isGradientFirstAccess(i)) {
        if ((mode == ml::train::ExecutionMode::TRAIN) &&
            (definedWeightDataType != TensorDim::DataType::FP32)) {

          /** @note for batch normalization layer, we do need full
          precision
           * for training. but weight can be saved with other type. for
           * training, bn weight type is fixed with full precision */

          TensorDim dim = run_context.getWeight(i).getDim();
          dim.setDataType(definedWeightDataType);
          Tensor T_read(dim, true);
          T_read.read(file);
          // cast the stored weight up into the FP32 training weight
          run_context.getWeight(i).copyData(T_read);
        } else {
          run_context.getWeight(i).read(file, start_offset);
        }

        // keep the FP32 master copy in sync for mixed-precision training
        if (run_context.isMixedPrecision(i) && trainable &&
            !run_context.getWeightFP32(i).empty()) {
          run_context.getWeightFP32(i).copyData(run_context.getWeight(i));
        }
      }
    }
  }
}
     467              : 
/**
 * @brief Read weights (or their optimizer variables) from a ReadSource.
 *
 * Mirrors the std::ifstream overload: with opt_var set, reads optimizer
 * variables for trainable weights at their last gradient access; otherwise
 * reads the weights, casting through a tensor of definedWeightDataType when
 * training with non-FP32 weights, and refreshes FP32 master copies for
 * mixed precision.
 *
 * @param src source abstraction to read from
 * @param run_context run context owning the weights
 * @param opt_var true to read optimizer variables instead of weights
 * @param mode current execution mode
 * @param trainable whether the layer is trainable
 * @param definedWeightDataType weight data type defined by the model
 * @param fsu unused here — TODO confirm
 * @param start_offset offset to start reading weights from
 * @param read_from_offset unused here — TODO confirm
 */
void BatchNormalizationLayer::read(ReadSource src, RunLayerContext &run_context,
                                   bool opt_var, ml::train::ExecutionMode mode,
                                   bool trainable,
                                   TensorDim::DataType definedWeightDataType,
                                   bool fsu, size_t start_offset,
                                   bool read_from_offset) {
  if (opt_var) {
    for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
      if (run_context.isGradientLastAccess(i) && trainable) {
        /// @note read optimizer variables
        for (unsigned int j = 0; j < run_context.getNumWeightOptVar(i); ++j) {
          run_context.getWeightOptVar(i, j).read(src);
        }
      }
    }
  } else {
    for (unsigned int i = 0; i < run_context.getNumWeights(); ++i) {
      /// @note shared weights are only be read at the first access
      //      if (run_context->isGradientLastAccess(i)) {
      if (run_context.isGradientFirstAccess(i)) {
        if ((mode == ml::train::ExecutionMode::TRAIN) &&
            (definedWeightDataType != TensorDim::DataType::FP32)) {

          /** @note for batch normalization layer, we do need full
          precision
           * for training. but weight can be saved with other type. for
           * training, bn weight type is fixed with full precision */

          TensorDim dim = run_context.getWeight(i).getDim();
          dim.setDataType(definedWeightDataType);
          Tensor T_read(dim, true);
          T_read.read(src);
          // cast the stored weight up into the FP32 training weight
          run_context.getWeight(i).copyData(T_read);
        } else {
          run_context.getWeight(i).read(src, start_offset);
        }

        // keep the FP32 master copy in sync for mixed-precision training
        if (run_context.isMixedPrecision(i) && trainable &&
            !run_context.getWeightFP32(i).empty()) {
          run_context.getWeightFP32(i).copyData(run_context.getWeight(i));
        }
      }
    }
  }
}
     513              : 
     514              : } /* namespace nntrainer */
        

Generated by: LCOV version 2.0-1