LCOV - code coverage report
Current view: top level - nntrainer - context.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 27.3 % 22 6
Test Date: 2025-12-14 20:38:17 Functions: 18.2 % 11 2

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2024 Jijoong Moon <jijoong.moon@samsung.com>
       4              :  *
       5              :  * @file    context.h
       6              :  * @date    10 Dec 2024
       7              :  * @see     https://github.com/nnstreamer/nntrainer
       8              :  * @author  Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @bug     No known bugs except for NYI items
      10              :  * @brief   This file contains app context related functions and classes that
      11              :  * manages the global configuration of the current environment.
      12              :  */
      13              : 
      14              : #ifndef __CONTEXT_H__
      15              : #define __CONTEXT_H__
      16              : 
      17              : #include <algorithm>
      18              : #include <functional>
      19              : #include <memory>
      20              : #include <mutex>
      21              : #include <sstream>
      22              : #include <stdexcept>
      23              : #include <string>
      24              : #include <type_traits>
      25              : #include <unordered_map>
      26              : #include <vector>
      27              : 
      28              : #include <context.h>
      29              : #include <layer.h>
      30              : #include <layer_devel.h>
      31              : #include <mem_allocator.h>
      32              : #include <optimizer.h>
      33              : #include <optimizer_devel.h>
      34              : 
      35              : #include <nntrainer_log.h>
      36              : 
      37              : namespace nntrainer {
      38              : 
      39              : /**
      40              :  * @class ContextData contains Data which generated by Context to use in Layers
      41              :  * created by Context
      42              :  * @brief Container to hold context data generated by Context. Eventaully
      43              :  * RunLayerContext in layer_node will hold the ContextData, so that, Layer can
      44              :  * access this Context Data.
      45              :  */
      46              : class ContextData {
      47              : public:
      48           57 :   ContextData() = default;
      49           57 :   virtual ~ContextData() = default;
      50              : 
      51              :   std::shared_ptr<MemAllocator> getMemAllocator() { return mem_allocator; }
      52              : 
      53              :   void setMemAllocator(std::shared_ptr<MemAllocator> m) { mem_allocator = m; }
      54              : 
      55              : private:
      56              :   std::shared_ptr<MemAllocator> mem_allocator = nullptr;
      57              : };
      58              : 
      59              : /**
      60              :  * @class Context contains user-dependent configuration for  support
      61              :  * @brief  support for app context
      62              :  */
      63              : 
      64              : class Context {
      65              : public:
      66              :   using PropsType = std::vector<std::string>;
      67              : 
      68              :   template <typename T> using PtrType = std::unique_ptr<T>;
      69              : 
      70              :   template <typename T>
      71              :   using FactoryType = std::function<PtrType<T>(const PropsType &)>;
      72              : 
      73              :   template <typename T>
      74              :   using PtrFactoryType = PtrType<T> (*)(const PropsType &);
      75              : 
      76              :   template <typename T>
      77              :   using StrIndexType = std::unordered_map<std::string, FactoryType<T>>;
      78              : 
      79              :   /** integer to string key */
      80              :   using IntIndexType = std::unordered_map<int, std::string>;
      81              : 
      82              :   /**
      83              :    * This type contains tuple of
      84              :    * 1) integer -> string index
      85              :    * 2) string -> factory index
      86              :    */
      87              :   template <typename T>
      88              :   using IndexType = std::tuple<StrIndexType<T>, IntIndexType>;
      89              : 
      90              :   template <typename... Ts> using FactoryMap = std::tuple<IndexType<Ts>...>;
      91              : 
      92              :   /**
      93              :    * @brief   Default constructor
      94              :    */
      95           57 :   Context(std::shared_ptr<ContextData> data_ = nullptr) : data(data_) {}
      96              : 
      97              :   /**
      98              :    * @brief   Destructor
      99              :    */
     100           57 :   virtual ~Context() = default;
     101              : 
     102              :   /**
     103              :    *
     104              :    * @brief Initialization of Context.
     105              :    *
     106              :    * @return status &
     107              :    */
     108            0 :   virtual int init() { return 0; };
     109              : 
     110              :   /**
     111              :    * @brief Create an Layer Object from the type (string)
     112              :    *
     113              :    * @param type type of layer
     114              :    * @param props property
     115              :    * @return PtrType<nntrainer::Layer> unique pointer to the object
     116              :    */
     117              :   virtual PtrType<nntrainer::Layer>
     118            0 :   createLayerObject(const std::string &type,
     119              :                     const std::vector<std::string> &props = {}) {
     120            0 :     ml_logw(
     121              :       "[Warning] Implement createLayerObject for the concrete context class to "
     122              :       "properly create the layer");
     123            0 :     return nullptr;
     124              :   };
     125              : 
     126              :   /**
     127              :    * @brief Create an Layer Object from the integer key
     128              :    *
     129              :    * @param int_key integer key
     130              :    * @param props property
     131              :    * @return PtrType<nntrainer::Layer> unique pointer to the object
     132              :    */
     133              :   virtual PtrType<nntrainer::Layer>
     134            0 :   createLayerObject(const int int_key,
     135              :                     const std::vector<std::string> &props = {}) {
     136            0 :     ml_logw(
     137              :       "[Warning] Implement createLayerObject for the concrete context class to "
     138              :       "properly create the layer");
     139            0 :     return nullptr;
     140              :   };
     141              : 
     142              :   /**
     143              :    * @brief Create an Optimizer Object from the type (string)
     144              :    *
     145              :    * @param type type of optimizer
     146              :    * @param props property
     147              :    * @return PtrType<nntrainer::Optimizer> unique pointer to the object
     148              :    */
     149              :   virtual PtrType<nntrainer::Optimizer>
     150            0 :   createOptimizerObject(const std::string &type,
     151              :                         const std::vector<std::string> &props = {}) {
     152            0 :     return nullptr;
     153              :   };
     154              : 
     155              :   /**
     156              :    * @brief Create an Layer Object from the integer key
     157              :    *
     158              :    * @param int_key integer key
     159              :    * @param props property
     160              :    * @return PtrType<nntrainer::Optimizer> unique pointer to the object
     161              :    */
     162              :   virtual PtrType<nntrainer::Optimizer>
     163            0 :   createOptimizerObject(const int int_key,
     164              :                         const std::vector<std::string> &properties = {}) {
     165            0 :     return nullptr;
     166              :   };
     167              : 
     168              :   /**
     169              :    * @brief Create an LearningRateScheduler Object from the type (stirng)
     170              :    *
     171              :    * @param type type of optimizer
     172              :    * @param props property
     173              :    * @return PtrType<ml::train::LearningRateScheduler> unique pointer to the
     174              :    * object
     175              :    */
     176              :   virtual PtrType<ml::train::LearningRateScheduler>
     177            0 :   createLearningRateSchedulerObject(
     178              :     const std::string &type, const std::vector<std::string> &propeties = {}) {
     179            0 :     return nullptr;
     180              :   }
     181              : 
     182              :   /**
     183              :    * @brief Create an LearningRateScheduler Object from the integer key
     184              :    *
     185              :    * @param int_key integer key
     186              :    * @param props property
     187              :    * @return PtrType<ml::train::LearningRateScheduler> unique pointer to the
     188              :    * object
     189              :    */
     190              :   virtual std::unique_ptr<ml::train::LearningRateScheduler>
     191            0 :   createLearningRateSchedulerObject(
     192              :     const int int_key, const std::vector<std::string> &propeties = {}) {
     193            0 :     return nullptr;
     194              :   }
     195              : 
     196              :   /**
     197              :    * @brief getter of context name
     198              :    *
     199              :    * @return string name of the context
     200              :    */
     201              :   virtual std::string getName() = 0;
     202              : 
     203              :   std::shared_ptr<ContextData> getContextData() { return data; }
     204              : 
     205           27 :   std::shared_ptr<MemAllocator> getMemAllocator() {
     206           27 :     return getContextData()->getMemAllocator();
     207              :   };
     208              : 
     209              :   /**
     210              :    * @brief load weight and graph for the specific context
     211              :    *
     212              :    * @return return 0 for success
     213              :    */
     214            0 :   virtual int load(const std::string &file_path) { return 0; };
     215              : 
     216              : private:
     217              :   /**
     218              :    * @brief map of context
     219              :    */
     220              :   static inline std::unordered_map<std::string, Context *> ContextMap;
     221              : 
     222              :   std::shared_ptr<ContextData> data = nullptr;
     223              : };
     224              : 
     225              : using CreateContextFunc = nntrainer::Context *(*)();
     226              : using DestroyContextFunc = void (*)(nntrainer::Context *);
     227              : 
     228              : /**
     229              :  * @brief  Context Pluggable struct that enables pluggable layer
     230              :  *
     231              :  */
     232              : typedef struct {
     233              :   CreateContextFunc createfunc;   /**< create layer function */
     234              :   DestroyContextFunc destroyfunc; /**< destory function */
     235              : } ContextPluggable;
     236              : 
     237              : /**
     238              :  * @brief pluggable Context must have this structure defined
     239              :  */
     240              : extern "C" ContextPluggable ml_train_context_pluggable;
     241              : 
     242              : } // namespace nntrainer
     243              : 
     244              : #endif /* __CONTEXT_H__ */
        

Generated by: LCOV version 2.0-1