37#ifndef VIGRA_RANDOM_FOREST_HXX
38#define VIGRA_RANDOM_FOREST_HXX
46#include "mathutil.hxx"
47#include "array_vector.hxx"
48#include "sized_int.hxx"
50#include "metaprogramming.hxx"
52#include "functorexpression.hxx"
53#include "random_forest/rf_common.hxx"
54#include "random_forest/rf_nodeproxy.hxx"
55#include "random_forest/rf_split.hxx"
56#include "random_forest/rf_decisionTree.hxx"
57#include "random_forest/rf_visitors.hxx"
58#include "random_forest/rf_region.hxx"
59#include "sampling.hxx"
60#include "random_forest/rf_preprocessing.hxx"
61#include "random_forest/rf_online_prediction_set.hxx"
62#include "random_forest/rf_earlystopping.hxx"
63#include "random_forest/rf_ridge_split.hxx"
83inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
85 SamplerOptions return_opt;
86 return_opt.withReplacement(RF_opt.sample_with_replacement_);
87 return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
146template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
153 typedef detail::DecisionTree DecisionTree_t;
155 typedef GiniSplit Default_Split_t;
160 typedef LabelType LabelT;
168 ProblemSpec_t ext_param_;
198 ProblemSpec_t
const &
ext_param = ProblemSpec_t())
227 template<
class TopologyIterator,
class ParameterIterator>
229 TopologyIterator topology_begin,
230 ParameterIterator parameter_begin,
231 ProblemSpec_t
const & problem_spec,
232 Options_t
const &
options = Options_t())
234 trees_(treeCount, DecisionTree_t(problem_spec)),
235 ext_param_(problem_spec),
241 for(
int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
243 trees_[k].topology_ = *topology_begin;
244 trees_[k].parameters_ = *parameter_begin;
262 vigra_precondition(ext_param_.used() ==
true,
263 "RandomForest::ext_param(): "
264 "Random forest has not been trained yet.");
281 vigra_precondition(ext_param_.used() ==
false,
282 "RandomForest::set_ext_param():"
283 "Random forest has been trained! Call reset()"
284 "before specifying new extrinsic parameters.");
308 DecisionTree_t
const &
tree(
int index)
const
310 return trees_[index];
315 DecisionTree_t &
tree(
int index)
317 return trees_[index];
325 return ext_param_.column_count_;
336 return ext_param_.column_count_;
344 return ext_param_.class_count_;
351 return options_.tree_count_;
392 template <
class U,
class C1,
403 Random_t
const & random);
405 template <
class U,
class C1,
426 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
427 void learn( MultiArrayView<2, U, C1>
const & features,
428 MultiArrayView<2, U2,C2>
const & labels,
438 template <
class U,
class C1,
class U2,
class C2,
439 class Visitor_t,
class Split_t>
440 void learn( MultiArrayView<2, U, C1>
const & features,
441 MultiArrayView<2, U2,C2>
const & labels,
470 template <
class U,
class C1,
class U2,
class C2>
482 template<
class U,
class C1,
495 bool adjust_thresholds=
false);
497 template <
class U,
class C1,
class U2,
class C2>
502 onlineLearn(features,
512 template<
class U,
class C1,
518 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
519 MultiArrayView<2,U2,C2>
const & response,
526 template<
class U,
class C1,
class U2,
class C2>
527 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
528 MultiArrayView<2, U2, C2>
const & labels,
531 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
561 template <
class U,
class C,
class Stop>
562 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features, Stop & stop)
const;
564 template <
class U,
class C>
565 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features)
575 template <
class U,
class C>
576 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
577 ArrayVectorView<double> prior)
const;
589 template <
class U,
class C1,
class T,
class C2>
593 vigra_precondition(features.
shape(0) == labels.
shape(0),
594 "RandomForest::predictLabels(): Label array has wrong size.");
595 for(
int k=0; k<features.
shape(0); ++k)
597 vigra_precondition(!detail::contains_nan(
rowVector(features, k)),
598 "RandomForest::predictLabels(): NaN in feature matrix.");
613 template <
class U,
class C1,
class T,
class C2>
616 LabelType nanLabel)
const
618 vigra_precondition(features.
shape(0) == labels.
shape(0),
619 "RandomForest::predictLabels(): Label array has wrong size.");
620 for(
int k=0; k<features.
shape(0); ++k)
622 if(detail::contains_nan(
rowVector(features, k)))
623 labels(k,0) = nanLabel;
638 template <
class U,
class C1,
class T,
class C2,
class Stop>
643 vigra_precondition(features.
shape(0) == labels.
shape(0),
644 "RandomForest::predictLabels(): Label array has wrong size.");
645 for(
int k=0; k<features.
shape(0); ++k)
660 template <
class U,
class C1,
class T,
class C2,
class Stop>
664 template <
class T1,
class T2,
class C>
674 template <
class U,
class C1,
class T,
class C2>
681 template <
class U,
class C1,
class T,
class C2>
691template <
class LabelType,
class PreprocessorTag>
692template<
class U,
class C1,
698void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
699 MultiArrayView<2,U2,C2>
const & response,
705 bool adjust_thresholds)
707 online_visitor_.activate();
708 online_visitor_.adjust_thresholds=adjust_thresholds;
712 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
713 typedef UniformIntRandomFunctor<Random_t>
720 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
721 Default_Stop_t default_stop(options_);
722 typename RF_CHOOSER(Stop_t)::type stop
723 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
724 Default_Split_t default_split;
725 typename RF_CHOOSER(Split_t)::type split
726 = RF_CHOOSER(Split_t)::choose(split_, default_split);
727 rf::visitors::StopVisiting stopvisiting;
728 typedef rf::visitors::detail::VisitorNode
729 <rf::visitors::OnlineLearnVisitor,
730 typename RF_CHOOSER(Visitor_t)::type>
733 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
735 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
741 ext_param_.class_count_=0;
742 Preprocessor_t preprocessor( features, response,
743 options_, ext_param_);
746 RandFunctor_t randint ( random);
749 split.set_external_parameters(ext_param_);
750 stop.set_external_parameters(ext_param_);
754 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
760 for(
int ii = 0; ii < static_cast<int>(trees_.
size()); ++ii)
762 online_visitor_.tree_id=ii;
763 poisson_sampler.sample();
764 std::map<int,int> leaf_parents;
765 leaf_parents.clear();
767 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
769 int sample=poisson_sampler[s];
770 online_visitor_.current_label=preprocessor.response()(sample,0);
771 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
776 online_visitor_.add_to_index_list(ii,leaf,sample);
779 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
781 leaf_parents[leaf]=online_visitor_.last_node_id;
786 std::map<int,int>::iterator leaf_iterator;
787 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
789 int leaf=leaf_iterator->first;
790 int parent=leaf_iterator->second;
791 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
794 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
795 StackEntry_t stack_entry(indeces.begin(),
797 ext_param_.class_count_);
802 if(
NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
804 stack_entry.leftParent=parent;
808 vigra_assert(
NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
809 stack_entry.rightParent=parent;
813 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
815 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
828 online_visitor_.deactivate();
831template<
class LabelType,
class PreprocessorTag>
832template<
class U,
class C1,
853 ext_param_.class_count_=0;
861 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
862 Default_Stop_t default_stop(options_);
863 typename RF_CHOOSER(Stop_t)::type stop
864 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
865 Default_Split_t default_split;
866 typename RF_CHOOSER(Split_t)::type split
867 = RF_CHOOSER(Split_t)::choose(split_, default_split);
871 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
873 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
875 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876 online_visitor_.activate();
879 RandFunctor_t randint ( random);
885 Preprocessor_t preprocessor( features, response,
886 options_, ext_param_);
889 split.set_external_parameters(ext_param_);
890 stop.set_external_parameters(ext_param_);
897 preprocessor.strata().end(),
898 detail::make_sampler_opt(options_)
899 .sampleSize(
ext_param().actual_msample_),
908 ext_param_.class_count_);
912 online_visitor_.reset_tree(treeId);
913 online_visitor_.tree_id=treeId;
914 trees_[treeId].reset();
916 .learn( preprocessor.features(),
917 preprocessor.response(),
924 .visit_after_tree( *
this,
930 online_visitor_.deactivate();
933template <
class LabelType,
class PreprocessorTag>
934template <
class U,
class C1,
946 Random_t
const & random)
957 vigra_precondition(features.
shape(0) == response.
shape(0),
958 "RandomForest::learn(): shape mismatch between features and response.");
965 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
966 Default_Stop_t default_stop(options_);
967 typename RF_CHOOSER(Stop_t)::type stop
968 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
969 Default_Split_t default_split;
970 typename RF_CHOOSER(Split_t)::type split
971 = RF_CHOOSER(Split_t)::choose(split_, default_split);
975 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
977 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
979 if(options_.prepare_online_learning_)
980 online_visitor_.activate();
982 online_visitor_.deactivate();
986 RandFunctor_t randint ( random);
993 Preprocessor_t preprocessor( features, response,
994 options_, ext_param_);
997 split.set_external_parameters(ext_param_);
998 stop.set_external_parameters(ext_param_);
1002 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
1005 preprocessor.strata().end(),
1006 detail::make_sampler_opt(options_)
1007 .sampleSize(
ext_param().actual_msample_),
1010 visitor.visit_at_beginning(*
this, preprocessor);
1013 for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
1021 ext_param_.class_count_);
1026 .learn( preprocessor.features(),
1027 preprocessor.response(),
1034 .visit_after_tree( *
this,
1041 visitor.visit_at_end(*
this, preprocessor);
1043 online_visitor_.deactivate();
1049template <
class LabelType,
class Tag>
1050template <
class U,
class C,
class Stop>
1054 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1055 "RandomForestn::predictLabel():"
1056 " Too few columns in feature matrix.");
1057 vigra_precondition(
rowCount(features) == 1,
1058 "RandomForestn::predictLabel():"
1059 " Feature matrix must have a singlerow.");
1063 ext_param_.to_classlabel(
argMax(probabilities), d);
1069template <
class LabelType,
class PreprocessorTag>
1070template <
class U,
class C>
1075 using namespace functor;
1076 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1077 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078 vigra_precondition(
rowCount(features) == 1,
1079 "RandomForestn::predictLabel():"
1080 " Feature matrix must have a single row.");
1083 std::transform( prob.begin(), prob.end(),
1084 priors.
begin(), prob.begin(),
1087 ext_param_.to_classlabel(
argMax(prob), d);
1091template<
class LabelType,
class PreprocessorTag>
1092template <
class T1,
class T2,
class C>
1101 "RandomFroest::predictProbabilities():"
1102 " Feature matrix and probability matrix size mismatch.");
1105 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1106 "RandomForestn::predictProbabilities():"
1107 " Too few columns in feature matrix.");
1110 "RandomForestn::predictProbabilities():"
1111 " Probability matrix must have as many columns as there are classes.");
1114 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1117 for(
int k=0; k<options_.tree_count_; ++k)
1119 set_id=(set_id+1) % predictionSet.indices[0].size();
1120 typedef std::set<SampleRange<T1> > my_set;
1121 typedef typename my_set::iterator set_it;
1124 std::vector<std::pair<int,set_it> > stack;
1126 for(set_it i=predictionSet.ranges[set_id].begin();
1127 i!=predictionSet.ranges[set_id].end();++i)
1128 stack.push_back(std::pair<int,set_it>(2,i));
1130 int num_decisions=0;
1131 while(!stack.empty())
1133 set_it range=stack.back().second;
1134 int index=stack.back().first;
1138 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1140 ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1141 trees_[k].parameters_,
1142 index).prob_begin();
1143 for(
int i=range->start;i!=range->end;++i)
1146 for(
int l=0; l<ext_param_.class_count_; ++l)
1148 prob(predictionSet.indices[set_id][i], l) +=
static_cast<T2
>(weights[l]);
1150 totalWeights[predictionSet.indices[set_id][i]] +=
static_cast<T1
>(weights[l]);
1157 if(trees_[k].topology_[index]!=i_ThresholdNode)
1159 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1161 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1162 if(range->min_boundaries[node.column()]>=node.threshold())
1165 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1168 if(range->max_boundaries[node.column()]<node.threshold())
1171 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1175 SampleRange<T1> new_range=*range;
1176 new_range.min_boundaries[node.column()]=FLT_MAX;
1177 range->max_boundaries[node.column()]=-FLT_MAX;
1178 new_range.start=new_range.end=range->end;
1180 while(i!=range->end)
1183 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1185 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1186 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1189 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1194 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1195 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1200 if(range->start==range->end)
1202 predictionSet.ranges[set_id].erase(range);
1206 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1209 if(new_range.start!=new_range.end)
1211 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1212 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1216 predictionSet.cumulativePredTime[k]=num_decisions;
1218 for(
unsigned int i=0;i<totalWeights.size();++i)
1222 for(
int l=0; l<ext_param_.class_count_; ++l)
1225 prob(i, l) /= totalWeights[i];
1227 assert(test==totalWeights[i]);
1228 assert(totalWeights[i]>0.0);
1232template <
class LabelType,
class PreprocessorTag>
1233template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1237 Stop_t & stop_)
const
1243 "RandomForestn::predictProbabilities():"
1244 " Feature matrix and probability matrix size mismatch.");
1248 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1249 "RandomForestn::predictProbabilities():"
1250 " Too few columns in feature matrix.");
1253 "RandomForestn::predictProbabilities():"
1254 " Probability matrix must have as many columns as there are classes.");
1256 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1257 Default_Stop_t default_stop(options_);
1258 typename RF_CHOOSER(Stop_t)::type & stop
1259 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1261 stop.set_external_parameters(ext_param_, tree_count());
1262 prob.init(NumericTraits<T>::zero());
1272 for(
int row=0; row <
rowCount(features); ++row)
1278 if(detail::contains_nan(currentRow))
1284 ArrayVector<double>::const_iterator weights;
1287 double totalWeight = 0.0;
1290 for(
int k=0; k<options_.tree_count_; ++k)
1293 weights = trees_[k ].predict(currentRow);
1296 int weighted = options_.predict_weighted_;
1297 for(
int l=0; l<ext_param_.class_count_; ++l)
1299 double cur_w = weights[l] * (weighted * (*(weights-1))
1301 prob(row, l) +=
static_cast<T
>(cur_w);
1303 totalWeight += cur_w;
1305 if(stop.after_prediction(weights,
1315 for(
int l=0; l< ext_param_.class_count_; ++l)
1317 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1323template <
class LabelType,
class PreprocessorTag>
1324template <
class U,
class C1,
class T,
class C2>
1333 "RandomForestn::predictProbabilities():"
1334 " Feature matrix and probability matrix size mismatch.");
1338 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1339 "RandomForestn::predictProbabilities():"
1340 " Too few columns in feature matrix.");
1343 "RandomForestn::predictProbabilities():"
1344 " Probability matrix must have as many columns as there are classes.");
1346 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1347 prob.init(NumericTraits<T>::zero());
1357 for(
int row=0; row <
rowCount(features); ++row)
1359 ArrayVector<double>::const_iterator weights;
1362 for(
int k=0; k<options_.tree_count_; ++k)
1365 weights = trees_[k ].predict(
rowVector(features, row));
1368 int weighted = options_.predict_weighted_;
1369 for(
int l=0; l<ext_param_.class_count_; ++l)
1371 double cur_w = weights[l] * (weighted * (*(weights-1))
1373 prob(row, l) +=
static_cast<T
>(cur_w);
1377 prob/= options_.tree_count_;
1383#include "random_forest/rf_algorithm.hxx"
Definition array_vector.hxx:77
const_iterator begin() const
Definition array_vector.hxx:223
size_type size() const
Definition array_vector.hxx:358
const_iterator end() const
Definition array_vector.hxx:237
Definition array_vector.hxx:514
Definition rf_region.hxx:58
Standard early stopping criterion.
Definition rf_common.hxx:886
Base class for, and view to, MultiArray.
Definition multi_array.hxx:705
const difference_type & shape() const
Definition multi_array.hxx:1650
MultiArrayView & init(const U &init)
Definition multi_array.hxx:1208
Main MultiArray class containing the memory management.
Definition multi_array.hxx:2479
Definition rf_nodeproxy.hxx:88
problem specification class for the random forest.
Definition rf_common.hxx:539
Definition rf_preprocessing.hxx:63
Options object for the random forest.
Definition rf_common.hxx:171
Random forest version 2 (see also RandomForest for version 3)
Definition random_forest.hxx:148
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition random_forest.hxx:228
Options_t const & options() const
access const random forest options
Definition random_forest.hxx:301
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition random_forest.hxx:941
DecisionTree_t const & tree(int index) const
access const trees
Definition random_forest.hxx:308
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition random_forest.hxx:278
DecisionTree_t & tree(int index)
access trees
Definition random_forest.hxx:315
int tree_count() const
return number of trees
Definition random_forest.hxx:349
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition random_forest.hxx:197
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple labels
int column_count() const
return number of features used while training.
Definition random_forest.hxx:334
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition random_forest.hxx:614
int feature_count() const
return number of features used while training.
Definition random_forest.hxx:323
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition random_forest.hxx:590
int class_count() const
return number of classes used while training.
Definition random_forest.hxx:342
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a feature.
Definition random_forest.hxx:1052
Options_t & set_options()
access random forest options
Definition random_forest.hxx:291
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition random_forest.hxx:260
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition random_forest.hxx:838
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition random_forest.hxx:471
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition random_forest.hxx:639
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple labels
Definition random_forest.hxx:675
Definition random.hxx:346
Create random samples from a sequence of indices.
Definition sampling.hxx:233
IndexArrayViewType sampledIndices() const
Definition sampling.hxx:435
void sample()
Definition sampling.hxx:467
IndexArrayViewType oobIndices() const
Definition sampling.hxx:443
Definition matrix.hxx:125
Definition rf_visitors.hxx:585
Definition rf_visitors.hxx:236
Definition rf_visitors.hxx:256
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:684
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition algorithm.hxx:96
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:671
std::ptrdiff_t MultiArrayIndex
Definition multi_fwd.hxx:60
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition rf_common.hxx:131