mlpack  3.4.2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
cosine_tree.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
13 #define MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <boost/heap/priority_queue.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
21 // Predeclare classes for CosineNodeQueue typedef.
22 class CompareCosineNode;
23 class CosineTree;
24 
25 // CosineNodeQueue: priority queue of CosineTree pointers, ordered by CompareCosineNode.
26 typedef boost::heap::priority_queue<CosineTree*,
27  boost::heap::compare<CompareCosineNode> > CosineNodeQueue;
28 
30 {
31  public:
40  CosineTree(const arma::mat& dataset);
41 
51  CosineTree(CosineTree& parentNode, const std::vector<size_t>& subIndices);
52 
67  CosineTree(const arma::mat& dataset,
68  const double epsilon,
69  const double delta);
70 
76  CosineTree(const CosineTree& other);
77 
84  CosineTree(CosineTree&& other);
85 
91  CosineTree& operator=(const CosineTree& other);
92 
99 
103  ~CosineTree();
104 
114  void ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
115  arma::vec& centroid,
116  arma::vec& newBasisVector,
117  arma::vec* addBasisVector = NULL);
118 
131  double MonteCarloError(CosineTree* node,
132  CosineNodeQueue& treeQueue,
133  arma::vec* addBasisVector1 = NULL,
134  arma::vec* addBasisVector2 = NULL);
135 
141  void ConstructBasis(CosineNodeQueue& treeQueue);
142 
148  void CosineNodeSplit();
149 
156  void ColumnSamplesLS(std::vector<size_t>& sampledIndices,
157  arma::vec& probabilities, size_t numSamples);
158 
165  size_t ColumnSampleLS();
166 
179  size_t BinarySearch(arma::vec& cDistribution, double value, size_t start,
180  size_t end);
181 
189  void CalculateCosines(arma::vec& cosines);
190 
195  void CalculateCentroid();
196 
198  void GetFinalBasis(arma::mat& finalBasis) { finalBasis = basis; }
199 
201  const arma::mat& GetDataset() const { return *dataset; }
202 
204  std::vector<size_t>& VectorIndices() { return indices; }
205 
207  void L2Error(const double error) { this->l2Error = error; }
209  double L2Error() const { return l2Error; }
210 
212  arma::vec& Centroid() { return centroid; }
213 
215  void BasisVector(arma::vec& bVector) { this->basisVector = bVector; }
216 
218  arma::vec& BasisVector() { return basisVector; }
219 
221  CosineTree* Parent() const { return parent; }
223  CosineTree*& Parent() { return parent; }
224 
226  CosineTree* Left() const { return left; }
228  CosineTree*& Left() { return left; }
229 
231  CosineTree* Right() const { return right; }
233  CosineTree*& Right() { return right; }
234 
236  size_t NumColumns() const { return numColumns; }
237 
239  double FrobNormSquared() const { return frobNormSquared; }
240 
242  size_t SplitPointIndex() const { return indices[splitPointIndex]; }
243 
244  private:
246  const arma::mat* dataset;
248  double delta;
250  arma::mat basis;
252  CosineTree* parent;
254  CosineTree* left;
256  CosineTree* right;
258  std::vector<size_t> indices;
260  arma::vec l2NormsSquared;
262  arma::vec centroid;
264  arma::vec basisVector;
266  size_t splitPointIndex;
268  size_t numColumns;
270  double l2Error;
272  double frobNormSquared;
274  bool localDataset;
275 };
276 
278 {
279  public:
280  // Comparison function for construction of priority queue.
281  bool operator() (const CosineTree* a, const CosineTree* b) const
282  {
283  return a->L2Error() < b->L2Error();
284  }
285 };
286 
287 } // namespace tree
288 } // namespace mlpack
289 
290 #endif
bool operator()(const CosineTree *a, const CosineTree *b) const
double FrobNormSquared() const
Get the Frobenius norm squared of columns in the node.
void ModifiedGramSchmidt(CosineNodeQueue &treeQueue, arma::vec &centroid, arma::vec &newBasisVector, arma::vec *addBasisVector=NULL)
Calculates the orthonormalization of the passed centroid, with respect to the current vector subspace...
arma::vec & Centroid()
Get a reference to the centroid vector.
void GetFinalBasis(arma::mat &finalBasis)
Returns the basis of the constructed subspace.
CosineTree *& Left()
Modify the pointer to the left child of the node.
double MonteCarloError(CosineTree *node, CosineNodeQueue &treeQueue, arma::vec *addBasisVector1=NULL, arma::vec *addBasisVector2=NULL)
Estimates the squared error of the projection of the input node's matrix onto the current vector subs...
void ConstructBasis(CosineNodeQueue &treeQueue)
Constructs the final basis matrix, after the cosine tree construction.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void L2Error(const double error)
Set the Monte Carlo error.
const arma::mat & GetDataset() const
Get a reference to the dataset matrix.
CosineTree * Left() const
Get pointer to the left child of the node.
void ColumnSamplesLS(std::vector< size_t > &sampledIndices, arma::vec &probabilities, size_t numSamples)
Sample 'numSamples' points from the Length-Squared distribution of the cosine node.
size_t SplitPointIndex() const
Get the column index of split point of the node.
size_t ColumnSampleLS()
Sample a point from the Length-Squared distribution of the cosine node.
CosineTree(const arma::mat &dataset)
CosineTree constructor for the root node of the tree.
void CosineNodeSplit()
This function splits the cosine node into two children based on the cosines of the columns contained ...
double L2Error() const
Get the Monte Carlo error.
CosineTree *& Right()
Modify the pointer to the right child of the node.
CosineTree *& Parent()
Modify the pointer to the parent node.
std::vector< size_t > & VectorIndices()
Get the indices of columns in the node.
CosineTree * Parent() const
Get pointer to the parent node.
~CosineTree()
Clean up the CosineTree: release allocated memory (including children).
CosineTree & operator=(const CosineTree &other)
Copy the given Cosine Tree.
CosineTree * Right() const
Get pointer to the right child of the node.
size_t BinarySearch(arma::vec &cDistribution, double value, size_t start, size_t end)
Sample a column based on the cumulative Length-Squared distribution of the cosine node...
void BasisVector(arma::vec &bVector)
Set the basis vector of the node.
void CalculateCentroid()
Calculate centroid of the columns present in the node.
size_t NumColumns() const
Get number of columns of input matrix in the node.
arma::vec & BasisVector()
Get the basis vector of the node.
void CalculateCosines(arma::vec &cosines)
Calculate cosines of the columns present in the node, with respect to the sampled splitting point...
boost::heap::priority_queue< CosineTree *, boost::heap::compare< CompareCosineNode > > CosineNodeQueue
Definition: cosine_tree.hpp:23