mlpack  3.4.2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
serialization_catch.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_SERIALIZATION_CATCH_HPP
13 #define MLPACK_TESTS_SERIALIZATION_CATCH_HPP
14 
15 #include <boost/serialization/serialization.hpp>
16 #include <boost/archive/xml_iarchive.hpp>
17 #include <boost/archive/xml_oarchive.hpp>
18 #include <boost/archive/text_iarchive.hpp>
19 #include <boost/archive/text_oarchive.hpp>
20 #include <boost/archive/binary_iarchive.hpp>
21 #include <boost/archive/binary_oarchive.hpp>
22 #include <mlpack/core.hpp>
23 
24 
25 #include "test_catch_tools.hpp"
26 #include "catch.hpp"
27 
28 namespace mlpack {
29 
30 // Test function for loading and saving Armadillo objects.
31 template<typename CubeType,
32  typename IArchiveType,
33  typename OArchiveType>
34 void TestArmadilloSerialization(arma::Cube<CubeType>& x)
35 {
36  // First save it.
37  // Use type_info name to get unique file name for serialization test files.
38  std::string fileName = FilterFileName(typeid(IArchiveType).name());
39  std::ofstream ofs(fileName, std::ios::binary);
40  bool success = true;
41 
42  {
43  OArchiveType o(ofs);
44 
45  try
46  {
47  o << BOOST_SERIALIZATION_NVP(x);
48  }
49  catch (boost::archive::archive_exception& e)
50  {
51  success = false;
52  }
53  }
54 
55  REQUIRE(success == true);
56  ofs.close();
57 
58  // Now load it.
59  arma::Cube<CubeType> orig(x);
60  success = true;
61  std::ifstream ifs(fileName, std::ios::binary);
62 
63  {
64  IArchiveType i(ifs);
65 
66  try
67  {
68  i >> BOOST_SERIALIZATION_NVP(x);
69  }
70  catch (boost::archive::archive_exception& e)
71  {
72  success = false;
73  }
74  }
75  ifs.close();
76 
77  remove(fileName.c_str());
78 
79  REQUIRE(success == true);
80 
81  REQUIRE(x.n_rows == orig.n_rows);
82  REQUIRE(x.n_cols == orig.n_cols);
83  REQUIRE(x.n_elem_slice == orig.n_elem_slice);
84  REQUIRE(x.n_slices == orig.n_slices);
85  REQUIRE(x.n_elem == orig.n_elem);
86 
87  for (size_t slice = 0; slice != x.n_slices; ++slice)
88  {
89  const auto& origSlice = orig.slice(slice);
90  const auto& xSlice = x.slice(slice);
91  for (size_t i = 0; i < x.n_cols; ++i)
92  {
93  for (size_t j = 0; j < x.n_rows; ++j)
94  {
95  if (double(origSlice(j, i)) == 0.0)
96  REQUIRE(double(xSlice(j, i)) == Approx(0.0).margin(1e-8 / 100));
97  else
98  REQUIRE(double(origSlice(j, i)) ==
99  Approx(double(xSlice(j, i))).epsilon(1e-8 / 100));
100  }
101  }
102  }
103 }
104 
105 // Test all serialization strategies.
106 template<typename CubeType>
107 void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
108 {
109  TestArmadilloSerialization<CubeType, boost::archive::xml_iarchive,
110  boost::archive::xml_oarchive>(x);
111  TestArmadilloSerialization<CubeType, boost::archive::text_iarchive,
112  boost::archive::text_oarchive>(x);
113  TestArmadilloSerialization<CubeType, boost::archive::binary_iarchive,
114  boost::archive::binary_oarchive>(x);
115 }
116 
117 // Test function for loading and saving Armadillo objects.
118 template<typename MatType,
119  typename IArchiveType,
120  typename OArchiveType>
121 void TestArmadilloSerialization(MatType& x)
122 {
123  // First save it.
124  std::string fileName = FilterFileName(typeid(IArchiveType).name());
125  std::ofstream ofs(fileName, std::ios::binary);
126  bool success = true;
127 
128  {
129  OArchiveType o(ofs);
130 
131  try
132  {
133  o << BOOST_SERIALIZATION_NVP(x);
134  }
135  catch (boost::archive::archive_exception& e)
136  {
137  success = false;
138  }
139  }
140 
141  REQUIRE(success == true);
142  ofs.close();
143 
144  // Now load it.
145  MatType orig(x);
146  success = true;
147  std::ifstream ifs(fileName, std::ios::binary);
148 
149  {
150  IArchiveType i(ifs);
151 
152  try
153  {
154  i >> BOOST_SERIALIZATION_NVP(x);
155  }
156  catch (boost::archive::archive_exception& e)
157  {
158  success = false;
159  }
160  }
161  ifs.close();
162 
163  remove(fileName.c_str());
164 
165  REQUIRE(success == true);
166 
167  REQUIRE(x.n_rows == orig.n_rows);
168  REQUIRE(x.n_cols == orig.n_cols);
169  REQUIRE(x.n_elem == orig.n_elem);
170 
171  for (size_t i = 0; i < x.n_cols; ++i)
172  for (size_t j = 0; j < x.n_rows; ++j)
173  if (double(orig(j, i)) == 0.0)
174  REQUIRE(double(x(j, i)) == Approx(0.0).margin(1e-8 / 100));
175  else
176  REQUIRE(double(orig(j, i)) ==
177  Approx(double(x(j, i))).epsilon(1e-8 / 100));
178 }
179 
180 // Test all serialization strategies.
181 template<typename MatType>
182 void TestAllArmadilloSerialization(MatType& x)
183 {
184  TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
185  boost::archive::xml_oarchive>(x);
186  TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
187  boost::archive::text_oarchive>(x);
188  TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
189  boost::archive::binary_oarchive>(x);
190 }
191 
192 // Save and load an mlpack object.
193 // The re-loaded copy is placed in 'newT'.
194 template<typename T, typename IArchiveType, typename OArchiveType>
195 void SerializeObject(T& t, T& newT)
196 {
197  std::string fileName = FilterFileName(typeid(T).name());
198  std::ofstream ofs(fileName, std::ios::binary);
199  bool success = true;
200 
201  {
202  OArchiveType o(ofs);
203 
204  try
205  {
206  o << BOOST_SERIALIZATION_NVP(t);
207  }
208  catch (boost::archive::archive_exception& e)
209  {
210  std::cerr << e.what() << std::endl;
211  success = false;
212  }
213  }
214  ofs.close();
215 
216  REQUIRE(success == true);
217 
218  std::ifstream ifs(fileName, std::ios::binary);
219 
220  {
221  IArchiveType i(ifs);
222 
223  try
224  {
225  i >> BOOST_SERIALIZATION_NVP(newT);
226  }
227  catch (boost::archive::archive_exception& e)
228  {
229  std::cout << e.what() << "\n";
230  success = false;
231  }
232  }
233  ifs.close();
234 
235  remove(fileName.c_str());
236 
237  REQUIRE(success == true);
238 }
239 
240 // Test mlpack serialization with all three archive types.
241 template<typename T>
242 void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
243 {
244  SerializeObject<T, boost::archive::xml_iarchive,
245  boost::archive::xml_oarchive>(t, xmlT);
246  SerializeObject<T, boost::archive::text_iarchive,
247  boost::archive::text_oarchive>(t, textT);
248  SerializeObject<T, boost::archive::binary_iarchive,
249  boost::archive::binary_oarchive>(t, binaryT);
250 }
251 
252 // Save and load a non-default-constructible mlpack object.
253 template<typename T, typename IArchiveType, typename OArchiveType>
254 void SerializePointerObject(T* t, T*& newT)
255 {
256  std::string fileName = FilterFileName(typeid(T).name());
257  std::ofstream ofs(fileName, std::ios::binary);
258  bool success = true;
259 
260  {
261  OArchiveType o(ofs);
262  try
263  {
264  o << BOOST_SERIALIZATION_NVP(t);
265  }
266  catch (boost::archive::archive_exception& e)
267  {
268  std::cout << e.what() << "\n";
269  success = false;
270  }
271  }
272  ofs.close();
273 
274  REQUIRE(success == true);
275 
276  std::ifstream ifs(fileName, std::ios::binary);
277 
278  {
279  IArchiveType i(ifs);
280 
281  try
282  {
283  i >> BOOST_SERIALIZATION_NVP(newT);
284  }
285  catch (std::exception& e)
286  {
287  std::cout << e.what() << "\n";
288  success = false;
289  }
290  }
291  ifs.close();
292 
293  remove(fileName.c_str());
294 
295  REQUIRE(success == true);
296 }
297 
298 template<typename T>
299 void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
300 {
301  SerializePointerObject<T, boost::archive::text_iarchive,
302  boost::archive::text_oarchive>(t, textT);
303  SerializePointerObject<T, boost::archive::binary_iarchive,
304  boost::archive::binary_oarchive>(t, binaryT);
305  SerializePointerObject<T, boost::archive::xml_iarchive,
306  boost::archive::xml_oarchive>(t, xmlT);
307 }
308 
309 // Utility function to check the equality of two Armadillo matrices.
310 void CheckMatrices(const arma::mat& x,
311  const arma::mat& xmlX,
312  const arma::mat& textX,
313  const arma::mat& binaryX);
314 
315 void CheckMatrices(const arma::Mat<size_t>& x,
316  const arma::Mat<size_t>& xmlX,
317  const arma::Mat<size_t>& textX,
318  const arma::Mat<size_t>& binaryX);
319 
320 void CheckMatrices(const arma::cube& x,
321  const arma::cube& xmlX,
322  const arma::cube& textX,
323  const arma::cube& binaryX);
324 
325 } // namespace mlpack
326 
327 #endif
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
std::string FilterFileName(const std::string &inputString)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void SerializeObject(T &t, T &newT)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializePointerObject(T *t, T *&newT)
src mlpack core util version hpp VERSION_HPP_CONTENTS string(REGEX REPLACE".*#define MLPACK_VERSION_MAJOR ([0-9]+).*""\\1"MLPACK_VERSION_MAJOR"${VERSION_HPP_CONTENTS}") string(REGEX REPLACE".* MLPACK_VERSION_MINOR "$
Definition: CMakeLists.txt:79