LCOV - code coverage report
Current view: top level - nntrainer - app_context.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 97.2 % 36 35
Test Date: 2025-12-14 20:38:17 Functions: 95.5 % 22 21

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file   app_context.h
       6              :  * @date   10 November 2020
       7              :  * @brief  This file contains app context related functions and classes that
       8              :  * manages the global configuration of the current environment
       9              :  * @see    https://github.com/nnstreamer/nntrainer
      10              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      11              :  * @bug    No known bugs except for NYI items
      12              :  *
      13              :  */
      14              : 
      15              : #ifndef __APP_CONTEXT_H__
      16              : #define __APP_CONTEXT_H__
      17              : 
      18              : #include <algorithm>
      19              : #include <functional>
      20              : #include <memory>
      21              : #include <mutex>
      22              : #include <sstream>
      23              : #include <stdexcept>
      24              : #include <string>
      25              : #include <type_traits>
      26              : #include <unordered_map>
      27              : #include <vector>
      28              : 
      29              : #include <layer.h>
      30              : #include <layer_devel.h>
      31              : #include <optimizer.h>
      32              : #include <optimizer_devel.h>
      33              : 
      34              : #include <context.h>
      35              : #include <mem_allocator.h>
      36              : #include <nntrainer_error.h>
      37              : 
      38              : #include "singleton.h"
      39              : 
      40              : namespace nntrainer {
      41              : 
      42              : extern std::mutex factory_mutex;
      43              : namespace {} // namespace
      44              : 
      45              : /**
      46              :  * @class AppContext contains user-dependent configuration
      47              :  * @brief App
      48              :  */
      49              : class AppContext : public Context, public Singleton<AppContext> {
      50              : public:
      51              :   /**
      52              :    * @brief   Default constructor
      53              :    */
      54          114 :   AppContext() : Context(std::make_shared<ContextData>()) {}
      55              : 
      56              :   /**
      57              :    * @brief   Default destructor
      58              :    */
      59           77 :   ~AppContext() override = default;
      60              : 
      61              :   /**
      62              :    * @brief Set Working Directory for a relative path. working directory is set
      63              :    * canonically
      64              :    * @param[in] base base directory
      65              :    * @throw std::invalid_argument if path is not valid for current system
      66              :    */
      67              :   void setWorkingDirectory(const std::string &base);
      68              : 
      69              :   /**
      70              :    * @brief unset working directory
      71              :    *
      72              :    */
      73              :   void unsetWorkingDirectory() { working_path_base = ""; }
      74              : 
      75              :   /**
      76              :    * @brief query if the appcontext has working directory set
      77              :    *
      78              :    * @retval true working path base is set
      79              :    * @retval false working path base is not set
      80              :    */
      81              :   bool hasWorkingDirectory() { return !working_path_base.empty(); }
      82              : 
      83              :   /**
      84              :    * @brief register a layer factory from a shared library
      85              :    * plugin must have **extern "C" LayerPluggable *ml_train_layer_pluggable**
      86              :    * defined else error
      87              :    *
      88              :    * @param library_path a file name of the library
      89              :    * @param base_path    base path to make a full path (optional)
      90              :    * @return int integer key to create the layer
      91              :    * @throws std::invalid_parameter if library_path is invalid or library is
      92              :    * invalid
      93              :    */
      94              :   int registerLayer(const std::string &library_path,
      95              :                     const std::string &base_path = "");
      96              : 
      97              :   /**
      98              :    * @brief register a optimizer factory from a shared library
      99              :    * plugin must have **extern "C" OptimizerPluggable
     100              :    * *ml_train_optimizer_pluggable** defined else error
     101              :    *
     102              :    * @param library_path a file name of the library
     103              :    * @param base_path    base path to make a full path (optional)
     104              :    * @return int integer key to create the optimizer
     105              :    * @throws std::invalid_parameter if library_path is invalid or library is
     106              :    * invalid
     107              :    */
     108              :   int registerOptimizer(const std::string &library_path,
     109              :                         const std::string &base_path = "");
     110              : 
     111              :   /**
     112              :    * @brief register pluggables from a directory.
     113              :    * @note if you have a clashing type with already registered pluggable, it
     114              :    * will throw from `registerFactory` function
     115              :    *
     116              :    * @param base_path a directory path to search pluggables's
     117              :    * @return std::vector<int> list of integer key to create a pluggable
     118              :    */
     119              :   std::vector<int> registerPluggableFromDirectory(const std::string &base_path);
     120              : 
     121              :   /**
     122              :    * @brief Get Working Path from a relative or representation of a path
     123              :    * starting from @a working_path_base.
     124              :    * @param[in] path to make full path
     125              :    * @return If absolute path is given, returns @a path
     126              :    * If relative path is given and working_path_base is not set, return
     127              :    * relative path.
     128              :    * If relative path is given and working_path_base has set, return absolute
     129              :    * path from current working directory
     130              :    */
     131              :   const std::string getWorkingPath(const std::string &path = "");
     132              : 
     133              :   /**
     134              :    * @brief Get memory fsu file path from configuration file
     135              :    * @return memory fsu path.
     136              :    * If memory fsu path is not presented in configuration file, it returns
     137              :    * empty string
     138              :    */
     139              :   const std::vector<std::string> getProperties(void);
     140              : 
     141              :   /**
     142              :    * @brief Factory register function, use this function to register custom
     143              :    * object
     144              :    *
     145              :    * @tparam T object to create. Currently Optimizer, Layer is supported
     146              :    * @param factory factory function that creates std::unique_ptr<T>
     147              :    * @param key key to access the factory, if key is empty, try to find key by
     148              :    * calling factory({})->getType();
     149              :    * @param int_key key to access the factory by integer, if it is -1(default),
     150              :    * the function automatically unsigned the key and return
     151              :    * @return const int unique integer value to access the current factory
     152              :    * @throw invalid argument when key and/or int_key is already taken
     153              :    */
     154              :   template <typename T>
     155         1881 :   const int registerFactory(const PtrFactoryType<T> factory,
     156              :                             const std::string &key = "",
     157              :                             const int int_key = -1) {
     158              :     FactoryType<T> f = factory;
     159         5640 :     return registerFactory(f, key, int_key);
     160              :   }
     161              : 
     162              :   /**
     163              :    * @brief Factory register function, use this function to register custom
     164              :    * object
     165              :    *
     166              :    * @tparam T object to create. Currently Optimizer, Layer is supported
     167              :    * @param factory factory function that creates std::unique_ptr<T>
     168              :    * @param key key to access the factory, if key is empty, try to find key by
     169              :    * calling factory({})->getType();
     170              :    * @param int_key key to access the factory by integer, if it is -1(default),
     171              :    * the function automatically unsigned the key and return
     172              :    * @return const int unique integer value to access the current factory
     173              :    * @throw invalid argument when key and/or int_key is already taken
     174              :    */
     175              :   template <typename T>
     176              :   const int registerFactory(const FactoryType<T> factory,
     177              :                             const std::string &key = "",
     178              :                             const int int_key = -1);
     179              : 
     180              :   std::unique_ptr<nntrainer::Layer>
     181         6126 :   createLayerObject(const std::string &type,
     182              :                     const std::vector<std::string> &properties = {}) override {
     183         6126 :     return createObject<nntrainer::Layer>(type, properties);
     184              :   }
     185              : 
     186          755 :   std::unique_ptr<nntrainer::Optimizer> createOptimizerObject(
     187              :     const std::string &type,
     188              :     const std::vector<std::string> &properties = {}) override {
     189          755 :     return createObject<nntrainer::Optimizer>(type, properties);
     190              :   }
     191              : 
     192              :   std::unique_ptr<ml::train::LearningRateScheduler>
     193           54 :   createLearningRateSchedulerObject(
     194              :     const std::string &type,
     195              :     const std::vector<std::string> &properties = {}) override {
     196           54 :     return createObject<ml::train::LearningRateScheduler>(type, properties);
     197              :   }
     198              : 
     199              :   std::unique_ptr<nntrainer::Layer>
     200           91 :   createLayerObject(const int int_key,
     201              :                     const std::vector<std::string> &properties = {}) override {
     202           91 :     return createObject<nntrainer::Layer>(int_key, properties);
     203              :   }
     204              : 
     205           29 :   std::unique_ptr<nntrainer::Optimizer> createOptimizerObject(
     206              :     const int int_key,
     207              :     const std::vector<std::string> &properties = {}) override {
     208           29 :     return createObject<nntrainer::Optimizer>(int_key, properties);
     209              :   }
     210              : 
     211              :   std::unique_ptr<ml::train::LearningRateScheduler>
     212           22 :   createLearningRateSchedulerObject(
     213              :     const int int_key,
     214              :     const std::vector<std::string> &properties = {}) override {
     215           22 :     return createObject<ml::train::LearningRateScheduler>(int_key, properties);
     216              :   }
     217              : 
     218              :   /**
     219              :    * @brief Create an Object from the integer key
     220              :    *
     221              :    * @tparam T Type of Object, currently, Only optimizer is supported
     222              :    * @param int_key integer key
     223              :    * @param props property
     224              :    * @return PtrType<T> unique pointer to the object
     225              :    */
     226              :   template <typename T>
     227          147 :   PtrType<T> createObject(const int int_key,
     228              :                           const PropsType &props = {}) const {
     229              :     static_assert(isSupported<T>::value,
     230              :                   "given type is not supported for current app context");
     231              :     auto &index = std::get<IndexType<T>>(factory_map);
     232              :     auto &int_map = std::get<IntIndexType>(index);
     233              : 
     234              :     const auto &entry = int_map.find(int_key);
     235              : 
     236          147 :     if (entry == int_map.end()) {
     237            2 :       std::stringstream ss;
     238            2 :       ss << "Int Key is not found for the object. Key: " << int_key;
     239            4 :       throw exception::not_supported(ss.str().c_str());
     240            2 :     }
     241              : 
     242          145 :     return createObject<T>(entry->second, props);
     243              :   }
     244              : 
     245              :   /**
     246              :    * @brief Create an Object from the string key
     247              :    *
     248              :    * @tparam T Type of object, currently, only optimizer is supported
     249              :    * @param key integer key
     250              :    * @param props property
     251              :    * @return PtrType<T> unique pointer to the object
     252              :    */
     253              :   template <typename T>
     254         7172 :   PtrType<T> createObject(const std::string &key,
     255              :                           const PropsType &props = {}) const {
     256              :     auto &index = std::get<IndexType<T>>(factory_map);
     257              :     auto &str_map = std::get<StrIndexType<T>>(index);
     258              : 
     259              :     std::string lower_key;
     260              :     lower_key.resize(key.size());
     261              : 
     262              :     std::transform(key.begin(), key.end(), lower_key.begin(),
     263        60326 :                    [](unsigned char c) { return std::tolower(c); });
     264              : 
     265              :     const auto &entry = str_map.find(lower_key);
     266              : 
     267         7172 :     if (entry == str_map.end()) {
     268           36 :       std::stringstream ss;
     269              :       ss << "Key is not found for the object. Key: " << lower_key;
     270           72 :       throw exception::not_supported(ss.str().c_str());
     271           36 :     }
     272              : 
     273         7123 :     return entry->second(props);
     274              :   }
     275              : 
     276              :   /**
     277              :    * @brief special factory that throws for unknown
     278              :    *
     279              :    * @tparam T object to create
     280              :    * @param props props to pass, not used
     281              :    * @throw always throw runtime_error
     282              :    */
     283              :   template <typename T>
     284            5 :   static PtrType<T> unknownFactory(const PropsType &props) {
     285            5 :     throw std::invalid_argument("cannot create unknown object");
     286              :   }
     287              : 
     288            0 :   std::string getName() override { return "cpu"; }
     289              : 
     290           27 :   void setMemAllocator(std::shared_ptr<MemAllocator> mem) {
     291           27 :     getContextData()->setMemAllocator(mem);
     292           27 :   }
     293              : 
     294              : private:
     295              :   /**
     296              :    * @brief   Overriden initialization function
     297              :    */
     298              :   void initialize() noexcept override;
     299              : 
     300              :   void add_default_object();
     301              : 
     302              :   void add_extension_object();
     303              : 
     304              :   FactoryMap<nntrainer::Optimizer, nntrainer::Layer,
     305              :              ml::train::LearningRateScheduler>
     306              :     factory_map;
     307              :   std::string working_path_base;
     308              : 
     309              :   template <typename Args, typename T> struct isSupportedHelper;
     310              : 
     311              :   /**
     312              :    * @brief supportHelper to check if given type is supported within appcontext
     313              :    */
     314              :   template <typename T, typename... Args>
     315              :   struct isSupportedHelper<T, AppContext::FactoryMap<Args...>> {
     316              :     static constexpr bool value =
     317              :       (std::is_same_v<std::decay_t<T>, std::decay_t<Args>> || ...);
     318              :   };
     319              : 
     320              :   /**
     321              :    * @brief supportHelper to check if given type is supported within appcontext
     322              :    */
     323              :   template <typename T>
     324              :   struct isSupported : isSupportedHelper<T, decltype(factory_map)> {};
     325              : };
     326              : 
     327              : /**
     328              :  * @copydoc const int AppContext::registerFactory
     329              :  */
     330              : extern template const int AppContext::registerFactory<nntrainer::Optimizer>(
     331              :   const FactoryType<nntrainer::Optimizer> factory, const std::string &key,
     332              :   const int int_key);
     333              : 
     334              : /**
     335              :  * @copydoc const int AppContext::registerFactory
     336              :  */
     337              : extern template const int AppContext::registerFactory<nntrainer::Layer>(
     338              :   const FactoryType<nntrainer::Layer> factory, const std::string &key,
     339              :   const int int_key);
     340              : 
     341              : /**
     342              :  * @copydoc const int AppContext::registerFactory
     343              :  */
     344              : extern template const int
     345              : AppContext::registerFactory<ml::train::LearningRateScheduler>(
     346              :   const FactoryType<ml::train::LearningRateScheduler> factory,
     347              :   const std::string &key, const int int_key);
     348              : 
     349              : namespace plugin {}
     350              : 
     351              : } // namespace nntrainer
     352              : 
     353              : #endif /* __APP_CONTEXT_H__ */
        

Generated by: LCOV version 2.0-1