Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file func_data_producer.cpp
6 : * @date 12 July 2021
7 : * @brief This file contains various data producers from a callback
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #include <func_data_producer.h>
15 :
16 : #include <base_properties.h>
17 : #include <nntrainer_error.h>
18 : #include <node_exporter.h>
19 :
20 : namespace nntrainer {
21 :
22 44 : FuncDataProducer::FuncDataProducer(datagen_cb datagen_cb, void *user_data_) :
23 44 : cb(datagen_cb),
24 44 : user_data_prop(new props::PropsUserData(user_data_)) {}
25 :
26 88 : FuncDataProducer::~FuncDataProducer() {}
27 :
28 0 : const std::string FuncDataProducer::getType() const {
29 0 : return FuncDataProducer::type;
30 : }
31 :
32 51 : void FuncDataProducer::setProperty(const std::vector<std::string> &properties) {
33 51 : auto left = loadProperties(properties, std::tie(*user_data_prop));
34 52 : NNTR_THROW_IF(!left.empty(), std::invalid_argument)
35 : << "properties is not empty, size: " << properties.size();
36 51 : }
37 :
38 : DataProducer::Generator
39 1326 : FuncDataProducer::finalize(const std::vector<TensorDim> &input_dims,
40 : const std::vector<TensorDim> &label_dims,
41 : void *user_data) {
42 1327 : NNTR_THROW_IF(!this->cb, std::invalid_argument)
43 : << "given callback is nullptr!";
44 :
45 1325 : auto input_data = std::shared_ptr<float *>(new float *[input_dims.size()],
46 : std::default_delete<float *[]>());
47 1325 : auto label_data = std::shared_ptr<float *>(new float *[label_dims.size()],
48 : std::default_delete<float *[]>());
49 :
50 100081 : return [cb = this->cb, ud = this->user_data_prop->get(), input_data,
51 1325 : label_data](unsigned int idx, std::vector<Tensor> &inputs,
52 : std::vector<Tensor> &labels) -> bool {
53 : float **input_data_raw = input_data.get();
54 : float **label_data_raw = label_data.get();
55 :
56 194874 : for (unsigned int i = 0; i < inputs.size(); ++i) {
57 97443 : *(input_data_raw + i) = inputs[i].getData();
58 : }
59 :
60 194874 : for (unsigned int i = 0; i < labels.size(); ++i) {
61 97443 : *(label_data_raw + i) = labels[i].getData();
62 : }
63 :
64 97431 : bool last = false;
65 97431 : int status = cb(input_data_raw, label_data_raw, &last, ud);
66 97432 : NNTR_THROW_IF(status != ML_ERROR_NONE, std::invalid_argument)
67 : << "[DataProducer] Callback returned error: " << status << '\n';
68 :
69 97430 : return last;
70 2650 : };
71 : }
72 :
73 0 : void FuncDataProducer::exportTo(Exporter &exporter,
74 0 : const ml::train::ExportMethods &method) const {}
75 :
76 : } // namespace nntrainer
|