mlpack  3.4.2
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
18 #include "information_gain.hpp"
21 #include "all_dimension_select.hpp"
22 #include <type_traits>
23 
24 namespace mlpack {
25 namespace tree {
26 
34 template<typename FitnessFunction = GiniGain,
35  template<typename> class NumericSplitType = BestBinaryNumericSplit,
36  template<typename> class CategoricalSplitType = AllCategoricalSplit,
37  typename DimensionSelectionType = AllDimensionSelect,
38  typename ElemType = double,
39  bool NoRecursion = false>
40 class DecisionTree :
41  public NumericSplitType<FitnessFunction>::template
42  AuxiliarySplitInfo<ElemType>,
43  public CategoricalSplitType<FitnessFunction>::template
44  AuxiliarySplitInfo<ElemType>
45 {
46  public:
48  typedef NumericSplitType<FitnessFunction> NumericSplit;
50  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
52  typedef DimensionSelectionType DimensionSelection;
53 
71  template<typename MatType, typename LabelsType>
72  DecisionTree(MatType data,
73  const data::DatasetInfo& datasetInfo,
74  LabelsType labels,
75  const size_t numClasses,
76  const size_t minimumLeafSize = 10,
77  const double minimumGainSplit = 1e-7,
78  const size_t maximumDepth = 0,
79  DimensionSelectionType dimensionSelector =
80  DimensionSelectionType());
81 
98  template<typename MatType, typename LabelsType>
99  DecisionTree(MatType data,
100  LabelsType labels,
101  const size_t numClasses,
102  const size_t minimumLeafSize = 10,
103  const double minimumGainSplit = 1e-7,
104  const size_t maximumDepth = 0,
105  DimensionSelectionType dimensionSelector =
106  DimensionSelectionType());
107 
127  template<typename MatType, typename LabelsType, typename WeightsType>
128  DecisionTree(
129  MatType data,
130  const data::DatasetInfo& datasetInfo,
131  LabelsType labels,
132  const size_t numClasses,
133  WeightsType weights,
134  const size_t minimumLeafSize = 10,
135  const double minimumGainSplit = 1e-7,
136  const size_t maximumDepth = 0,
137  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
138  const std::enable_if_t<arma::is_arma_type<
139  typename std::remove_reference<WeightsType>::type>::value>* = 0);
140 
159  template<typename MatType, typename LabelsType, typename WeightsType>
160  DecisionTree(
161  const DecisionTree& other,
162  MatType data,
163  const data::DatasetInfo& datasetInfo,
164  LabelsType labels,
165  const size_t numClasses,
166  WeightsType weights,
167  const size_t minimumLeafSize = 10,
168  const double minimumGainSplit = 1e-7,
169  const std::enable_if_t<arma::is_arma_type<
170  typename std::remove_reference<WeightsType>::type>::value>* = 0);
189  template<typename MatType, typename LabelsType, typename WeightsType>
190  DecisionTree(
191  MatType data,
192  LabelsType labels,
193  const size_t numClasses,
194  WeightsType weights,
195  const size_t minimumLeafSize = 10,
196  const double minimumGainSplit = 1e-7,
197  const size_t maximumDepth = 0,
198  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
199  const std::enable_if_t<arma::is_arma_type<
200  typename std::remove_reference<WeightsType>::type>::value>* = 0);
201 
220  template<typename MatType, typename LabelsType, typename WeightsType>
221  DecisionTree(
222  const DecisionTree& other,
223  MatType data,
224  LabelsType labels,
225  const size_t numClasses,
226  WeightsType weights,
227  const size_t minimumLeafSize = 10,
228  const double minimumGainSplit = 1e-7,
229  const size_t maximumDepth = 0,
230  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
231  const std::enable_if_t<arma::is_arma_type<
232  typename std::remove_reference<WeightsType>::type>::value>* = 0);
233 
240  DecisionTree(const size_t numClasses = 1);
241 
248  DecisionTree(const DecisionTree& other);
249 
255  DecisionTree(DecisionTree&& other);
256 
263  DecisionTree& operator=(const DecisionTree& other);
264 
271 
275  ~DecisionTree();
276 
296  template<typename MatType, typename LabelsType>
297  double Train(MatType data,
298  const data::DatasetInfo& datasetInfo,
299  LabelsType labels,
300  const size_t numClasses,
301  const size_t minimumLeafSize = 10,
302  const double minimumGainSplit = 1e-7,
303  const size_t maximumDepth = 0,
304  DimensionSelectionType dimensionSelector =
305  DimensionSelectionType());
306 
324  template<typename MatType, typename LabelsType>
325  double Train(MatType data,
326  LabelsType labels,
327  const size_t numClasses,
328  const size_t minimumLeafSize = 10,
329  const double minimumGainSplit = 1e-7,
330  const size_t maximumDepth = 0,
331  DimensionSelectionType dimensionSelector =
332  DimensionSelectionType());
333 
355  template<typename MatType, typename LabelsType, typename WeightsType>
356  double Train(MatType data,
357  const data::DatasetInfo& datasetInfo,
358  LabelsType labels,
359  const size_t numClasses,
360  WeightsType weights,
361  const size_t minimumLeafSize = 10,
362  const double minimumGainSplit = 1e-7,
363  const size_t maximumDepth = 0,
364  DimensionSelectionType dimensionSelector =
365  DimensionSelectionType(),
366  const std::enable_if_t<arma::is_arma_type<typename
367  std::remove_reference<WeightsType>::type>::value>* = 0);
368 
388  template<typename MatType, typename LabelsType, typename WeightsType>
389  double Train(MatType data,
390  LabelsType labels,
391  const size_t numClasses,
392  WeightsType weights,
393  const size_t minimumLeafSize = 10,
394  const double minimumGainSplit = 1e-7,
395  const size_t maximumDepth = 0,
396  DimensionSelectionType dimensionSelector =
397  DimensionSelectionType(),
398  const std::enable_if_t<arma::is_arma_type<typename
399  std::remove_reference<WeightsType>::type>::value>* = 0);
400 
407  template<typename VecType>
408  size_t Classify(const VecType& point) const;
409 
419  template<typename VecType>
420  void Classify(const VecType& point,
421  size_t& prediction,
422  arma::vec& probabilities) const;
423 
431  template<typename MatType>
432  void Classify(const MatType& data,
433  arma::Row<size_t>& predictions) const;
434 
445  template<typename MatType>
446  void Classify(const MatType& data,
447  arma::Row<size_t>& predictions,
448  arma::mat& probabilities) const;
449 
453  template<typename Archive>
454  void serialize(Archive& ar, const unsigned int /* version */);
455 
457  size_t NumChildren() const { return children.size(); }
458 
460  const DecisionTree& Child(const size_t i) const { return *children[i]; }
462  DecisionTree& Child(const size_t i) { return *children[i]; }
463 
466  size_t SplitDimension() const { return splitDimension; }
467 
475  template<typename VecType>
476  size_t CalculateDirection(const VecType& point) const;
477 
481  size_t NumClasses() const;
482 
483  private:
485  std::vector<DecisionTree*> children;
487  size_t splitDimension;
490  size_t dimensionTypeOrMajorityClass;
498  arma::vec classProbabilities;
499 
503  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
504  NumericAuxiliarySplitInfo;
505  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
506  CategoricalAuxiliarySplitInfo;
507 
511  template<bool UseWeights, typename RowType, typename WeightsRowType>
512  void CalculateClassProbabilities(const RowType& labels,
513  const size_t numClasses,
514  const WeightsRowType& weights);
515 
533  template<bool UseWeights, typename MatType>
534  double Train(MatType& data,
535  const size_t begin,
536  const size_t count,
537  const data::DatasetInfo& datasetInfo,
538  arma::Row<size_t>& labels,
539  const size_t numClasses,
540  arma::rowvec& weights,
541  const size_t minimumLeafSize,
542  const double minimumGainSplit,
543  const size_t maximumDepth,
544  DimensionSelectionType& dimensionSelector);
545 
562  template<bool UseWeights, typename MatType>
563  double Train(MatType& data,
564  const size_t begin,
565  const size_t count,
566  arma::Row<size_t>& labels,
567  const size_t numClasses,
568  arma::rowvec& weights,
569  const size_t minimumLeafSize,
570  const double minimumGainSplit,
571  const size_t maximumDepth,
572  DimensionSelectionType& dimensionSelector);
573 };
574 
578 template<typename FitnessFunction = GiniGain,
579  template<typename> class NumericSplitType = BestBinaryNumericSplit,
580  template<typename> class CategoricalSplitType = AllCategoricalSplit,
581  typename DimensionSelectType = AllDimensionSelect,
582  typename ElemType = double>
583 using DecisionStump = DecisionTree<FitnessFunction,
584  NumericSplitType,
585  CategoricalSplitType,
586  DimensionSelectType,
587  ElemType,
588  false>;
589 
598  double,
600 } // namespace tree
601 } // namespace mlpack
602 
603 // Include implementation.
604 #include "decision_tree_impl.hpp"
605 
606 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
This class implements a generic decision tree learner.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
The standard information gain criterion, used for calculating gain in decision trees.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
size_t NumChildren() const
Get the number of children.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectType, ElemType, false > DecisionStump
Convenience typedef for decision stumps (single level decision trees).
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
This dimension selection policy allows any dimension to be selected for splitting.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node this point would go to.
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).