Line data Source code
1 : /**
2 : * Copyright (C) 2020 Samsung Electronics Co., Ltd. All Rights Reserved.
3 : *
4 : * Licensed under the Apache License, Version 2.0 (the "License");
5 : * you may not use this file except in compliance with the License.
6 : * You may obtain a copy of the License at
7 : * http://www.apache.org/licenses/LICENSE-2.0
8 : * Unless required by applicable law or agreed to in writing, software
9 : * distributed under the License is distributed on an "AS IS" BASIS,
10 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 : * See the License for the specific language governing permissions and
12 : * limitations under the License.
13 : *
14 : * @file util_func.cpp
15 : * @date 08 April 2020
16 : * @brief This is a collection of math functions
17 : * @see https://github.com/nnstreamer/nntrainer
18 : * @author Jijoong Moon <jijoong.moon@samsung.com>
19 : * @bug No known bugs except for NYI items
20 : *
21 : */
22 :
23 : #ifdef _WIN32
24 : #define MAX_PATH_LENGTH 1024
25 : #endif
26 :
27 : #include <cmath>
28 : #include <fstream>
29 : #include <random>
30 :
31 : #include <acti_func.h>
32 : #include <nntrainer_log.h>
33 : #include <util_func.h>
34 :
35 : namespace nntrainer {
36 :
/**
 * @brief File-local uniform real distribution with bounds -0.5 and 0.5
 */
static std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
38 :
/**
 * @brief Square root of a double precision value
 * @param[in] x input value
 * @return square root of x
 */
double sqrtDouble(double x) {
  const double root = std::sqrt(x);
  return root;
}
40 :
/**
 * @brief Natural logarithm of (x + 1.0e-20)
 * @param[in] x input value
 * @return log(x + 1.0e-20) narrowed to float
 */
float logFloat(float x) {
  // The 1.0e-20 literal is a double, so the sum and the logarithm are
  // evaluated in double precision before narrowing the result.
  const double shifted = static_cast<double>(x) + 1.0e-20;
  return static_cast<float>(std::log(shifted));
}
42 :
/**
 * @brief Exponential function
 * @param[in] x input value
 * @return e raised to the power of x
 */
float exp_util(float x) {
  // The call is intentionally left unqualified so that the same exp overload
  // is selected as before.
  const float result = exp(x);
  return result;
}
44 :
/**
 * @brief Integer division rounded up
 * @param[in] a dividend
 * @param[in] b divisor, must not be zero (division by zero is not guarded)
 * @return smallest integer q such that q * b >= a
 * @note Computed as quotient plus a remainder flag so that the intermediate
 *       value cannot wrap around; the former (a + b - 1) / b form overflowed
 *       for large a.
 */
uint32_t ceilDiv(uint32_t a, uint32_t b) {
  const uint32_t quotient = a / b;
  const uint32_t has_remainder = (a % b != 0) ? 1u : 0u;
  return quotient + has_remainder;
}
46 :
/**
 * @brief Round a up to the next multiple of b
 * @param[in] a value to align
 * @param[in] b alignment unit, must not be zero (a % b is evaluated)
 * @return a when it is already a multiple of b, otherwise a - a % b + b
 */
uint32_t align(uint32_t a, uint32_t b) {
  const uint32_t remainder = a % b;
  if (remainder == 0) {
    return a;
  }
  return a - remainder + b;
}
50 :
51 0 : Tensor rotate_180(Tensor in) {
52 0 : Tensor output(in.getDim());
53 0 : output.setZero();
54 0 : for (unsigned int i = 0; i < in.batch(); ++i) {
55 0 : for (unsigned int j = 0; j < in.channel(); ++j) {
56 0 : for (unsigned int k = 0; k < in.height(); ++k) {
57 0 : for (unsigned int l = 0; l < in.width(); ++l) {
58 0 : output.setValue(
59 : i, j, k, l,
60 0 : in.getValue(i, j, (in.height() - k - 1), (in.width() - l - 1)));
61 : }
62 : }
63 : }
64 : }
65 0 : return output;
66 0 : }
67 :
/**
 * @brief Check whether a file can be opened for reading
 * @param[in] file_name path of the file to probe
 * @return true when an input file stream opened on file_name is in a good
 *         state, false otherwise
 */
bool isFileExist(std::string file_name) {
  // the temporary stream is closed again as soon as the state has been read
  return std::ifstream(file_name).good();
}
72 :
/**
 * @brief Throw when a stream is not in a good state
 * @tparam T stream type exposing good()
 * @param[in] file stream to inspect
 * @param[in] error_msg message carried by the thrown exception
 * @throws std::runtime_error when the stream is not good
 */
template <typename T>
static void checkFile(const T &file, const char *error_msg) {
  // For standard streams good() is true only when none of badbit, eofbit and
  // failbit is set, so this single test covers bad(), eof() and fail().
  if (!file.good()) {
    throw std::runtime_error(error_msg);
  }
}
79 :
80 26643 : void checkedRead(std::ifstream &file, char *array, std::streamsize size,
81 : const char *error_msg, size_t start_offset,
82 : bool read_from_offset) {
83 26643 : if (read_from_offset) {
84 0 : file.seekg(start_offset, std::ios::beg);
85 0 : checkFile(file, "failed to move offset");
86 : }
87 26643 : file.read(array, size);
88 26643 : checkFile(file, error_msg);
89 26641 : }
90 :
91 0 : void checkedRead(ReadSource src, char *array, std::streamsize size,
92 : const char *error_msg, size_t start_offset,
93 : bool read_from_offset) {
94 :
95 : if (auto f = std::get_if<std::ifstream *>(&src)) {
96 0 : if (read_from_offset) {
97 0 : (*f)->seekg(start_offset, std::ios::beg);
98 : }
99 0 : (*f)->read(static_cast<char *>(array), static_cast<std::streamsize>(size));
100 : // checkFile((*f), error_msg);
101 : } else if (auto p = std::get_if<const char *>(&src)) {
102 0 : if (read_from_offset) {
103 0 : std::memcpy(array, (*p) + start_offset, size);
104 : } else {
105 0 : std::memcpy(array, (*p), size);
106 : }
107 : }
108 0 : }
109 :
110 2139 : void checkedWrite(std::ostream &file, const char *array, std::streamsize size,
111 : const char *error_msg) {
112 2139 : file.write(array, size);
113 :
114 2139 : checkFile(file, error_msg);
115 2137 : }
116 :
117 1 : std::string readString(std::ifstream &file, const char *error_msg) {
118 : std::string str;
119 : size_t size;
120 :
121 1 : checkedRead(file, (char *)&size, sizeof(size), error_msg);
122 :
123 0 : std::streamsize sz = static_cast<std::streamsize>(size);
124 1 : NNTR_THROW_IF(sz < 0, std::invalid_argument)
125 : << "read string size: " << sz
126 : << " is too big. It cannot be represented by std::streamsize";
127 :
128 : str.resize(size);
129 0 : checkedRead(file, (char *)&str[0], sz, error_msg);
130 :
131 0 : return str;
132 : }
133 :
134 1 : void writeString(std::ofstream &file, const std::string &str,
135 : const char *error_msg) {
136 1 : size_t size = str.size();
137 :
138 1 : checkedWrite(file, (char *)&size, sizeof(size), error_msg);
139 :
140 0 : std::streamsize sz = static_cast<std::streamsize>(size);
141 0 : NNTR_THROW_IF(sz < 0, std::invalid_argument)
142 : << "write string size: " << size
143 : << " is too big. It cannot be represented by std::streamsize";
144 :
145 0 : checkedWrite(file, (char *)&str[0], sz, error_msg);
146 0 : }
147 :
/**
 * @brief Check whether a string ends with the given suffix
 * @param[in] target string to inspect
 * @param[in] suffix expected ending
 * @return true when the last suffix.size() characters of target equal suffix
 */
bool endswith(const std::string &target, const std::string &suffix) {
  const size_t target_len = target.size();
  const size_t suffix_len = suffix.size();

  if (suffix_len > target_len) {
    return false;
  }

  // compare the tail of target in place
  return target.compare(target_len - suffix_len, suffix_len, suffix) == 0;
}
155 :
156 137321 : int getKeyValue(const std::string &input_str, std::string &key,
157 : std::string &value) {
158 : int status = ML_ERROR_NONE;
159 : auto input_trimmed = input_str;
160 :
161 : std::vector<std::string> list;
162 137321 : static const std::regex words_regex("[^\\s=]+");
163 137321 : input_trimmed.erase(
164 137321 : std::remove(input_trimmed.begin(), input_trimmed.end(), ' '),
165 : input_trimmed.end());
166 : auto words_begin = std::sregex_iterator(input_trimmed.begin(),
167 137321 : input_trimmed.end(), words_regex);
168 137321 : auto words_end = std::sregex_iterator();
169 137321 : int nwords = std::distance(words_begin, words_end);
170 :
171 137321 : if (nwords != 2) {
172 106 : ml_loge("Error: input string must be 'key = value' format "
173 : "(e.g.{\"key1=value1\",\"key2=value2\"}), \"%s\" given",
174 : input_trimmed.c_str());
175 106 : return ML_ERROR_INVALID_PARAMETER;
176 : }
177 :
178 411645 : for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
179 548860 : list.push_back((*i).str());
180 : }
181 :
182 : key = list[0];
183 : value = list[1];
184 :
185 : return status;
186 137321 : }
187 :
188 1118 : int getValues(int n_str, std::string str, int *value) {
189 : int status = ML_ERROR_NONE;
190 1118 : static const std::regex words_regex("[^\\s.,:;!?]+");
191 1118 : str.erase(std::remove(str.begin(), str.end(), ' '), str.end());
192 1118 : auto words_begin = std::sregex_iterator(str.begin(), str.end(), words_regex);
193 1118 : auto words_end = std::sregex_iterator();
194 :
195 1118 : int num = std::distance(words_begin, words_end);
196 1118 : if (num != n_str) {
197 0 : ml_loge("Number of Data is not match");
198 0 : return ML_ERROR_INVALID_PARAMETER;
199 : }
200 : int cn = 0;
201 4472 : for (std::sregex_iterator i = words_begin; i != words_end; ++i) {
202 3354 : value[cn] = std::stoi((*i).str());
203 3354 : cn++;
204 : }
205 1118 : return status;
206 : }
207 :
/**
 * @brief Split a string into tokens separated by a regular expression
 * @param[in] s string to split
 * @param[in] reg regular expression describing the delimiter
 * @return tokens found between the delimiter matches, in order
 * @note Every ' ', '[' and ']' character is removed from the input before it
 *       is tokenized.
 */
std::vector<std::string> split(const std::string &s, const std::regex &reg) {
  // drop the characters that must never appear in a token, in a single pass
  std::string str = s;
  str.erase(std::remove_if(str.begin(), str.end(),
                           [](char c) {
                             return c == ' ' || c == '[' || c == ']';
                           }),
            str.end());

  // submatch index -1 selects the text between matches, i.e. the tokens
  const std::sregex_token_iterator first(str.cbegin(), str.cend(), reg, -1);
  const std::sregex_token_iterator last;

  std::vector<std::string> out;
  for (auto iter = first; iter != last; ++iter) {
    out.push_back(*iter);
  }
  return out;
}
228 :
/**
 * @brief Case-insensitive string comparison
 * @param[in] a first string
 * @param[in] b second string
 * @return true when both strings have the same length and every pair of
 *         characters is equal after conversion with tolower
 */
bool istrequal(const std::string &a, const std::string &b) {
  if (a.size() != b.size())
    return false;

  return std::equal(a.begin(), a.end(), b.begin(), [](char a_, char b_) {
    // tolower requires a value representable as unsigned char (or EOF);
    // passing a negative plain char is undefined behaviour.
    return tolower(static_cast<unsigned char>(a_)) ==
           tolower(static_cast<unsigned char>(b_));
  });
}
237 :
/**
 * @brief Resolve a path name to an absolute path
 * @param[in] name path to resolve
 * @param[in] resolved destination buffer; forwarded to _fullpath on Windows,
 *            not used on other platforms
 * @return pointer to the resolved path, or the failure value of the
 *         underlying call
 * @note On non-Windows platforms realpath is always called with a null
 *       buffer, so the result is whatever realpath allocates and the
 *       caller-supplied buffer is left untouched.
 *       NOTE(review): callers presumably have to free() that result; confirm
 *       against the call sites.
 */
char *getRealpath(const char *name, char *resolved) {
#ifdef _WIN32
  char *full_path = _fullpath(resolved, name, MAX_PATH_LENGTH);
  return full_path;
#else
  (void)resolved; // only the Windows branch writes to this buffer
  return realpath(name, nullptr);
#endif
}
246 :
/**
 * @brief Fill a tm structure with the current local time
 * @param[out] tp structure receiving the broken-down local time
 * @return tp on Windows, otherwise the return value of localtime_r
 */
tm *getLocaltime(tm *tp) {
  const time_t now = time(nullptr);
#ifdef _WIN32
  localtime_s(tp, &now);
  return tp;
#else
  return localtime_r(&now, tp);
#endif
}
256 :
257 3086 : std::regex getRegex(const std::string &str) {
258 : std::regex result;
259 :
260 : try {
261 3086 : result = std::regex(str);
262 0 : } catch (const std::regex_error &e) {
263 0 : ml_loge("regex_error caught: %s", e.what());
264 0 : }
265 :
266 3086 : return result;
267 0 : }
268 :
/**
 * @brief Decompose a float into a fixed-point mantissa and a binary exponent
 * @param[in] input value to decompose
 * @param[out] fixedpoint mantissa returned by std::frexp, scaled by the float
 *             representation of the largest int
 * @param[out] exponent power-of-two exponent returned by std::frexp
 * @note NOTE(review): non-finite inputs are not handled specially; confirm
 *       that callers only pass finite values.
 */
void floatToFixedPointAndExponent(float input, int &fixedpoint, int &exponent) {
  const float int_max_as_float =
    static_cast<float>(std::numeric_limits<int>::max());

  // split input into mantissa * 2^power
  int power = 0;
  const float mantissa = std::frexp(input, &power);
  exponent = power;

  // stretch the mantissa over the int range to keep as many bits as possible
  fixedpoint = static_cast<int>(mantissa * int_max_as_float);
}
277 :
/**
 * @brief Rebuild a float from a fixed-point mantissa and a binary exponent
 * @param[in] fixedpoint mantissa scaled by the float representation of the
 *            largest int
 * @param[in] exponent power-of-two exponent
 * @return (fixedpoint / largest int as float) * 2^exponent
 */
float fixedPointAndExponentToFloat(int fixedpoint, int exponent) {
  const float int_max_as_float =
    static_cast<float>(std::numeric_limits<int>::max());

  // undo the scaling, then apply the power of two
  const float mantissa = static_cast<float>(fixedpoint) / int_max_as_float;
  return std::ldexp(mantissa, exponent);
}
285 :
286 : } // namespace nntrainer
|