/**
 * Method to train a decision tree model for binary or multiclass classification.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 *              Labels should take values {0, 1, ..., numClasses-1}.
 * @param numClasses number of classes for classification.
 * @param categoricalFeaturesInfo Map storing arity of categorical features.
 *                                E.g., an entry (n -> k) indicates that feature n is categorical
 *                                with k categories indexed from 0: {0, 1, ..., k-1}.
 * @param numTrees Number of trees in the random forest.
 * @param featureSubsetStrategy Number of features to consider for splits at each node.
 *                              Supported: "auto", "all", "sqrt", "log2", "onethird".
 *                              If "auto" is set, this parameter is set based on numTrees:
 *                                if numTrees == 1, set to "all";
 *                                if numTrees > 1 (forest) set to "sqrt".
 * @param impurity Criterion used for information gain calculation.
 *                 Supported values: "gini" (recommended) or "entropy".
 * @param maxDepth Maximum depth of the tree.
 *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
 *                 (suggested value: 4)
 * @param maxBins maximum number of bins used for splitting features
 *                (suggested value: 100)
 * @param seed Random seed for bootstrapping and choosing feature subsets.
 * @return a random forest model that can be used for prediction
 */
def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    numTrees: Int,
    featureSubsetStrategy: String,
    impurity: String,
    maxDepth: Int,
    maxBins: Int,
    seed: Int = Utils.random.nextInt()): RandomForestModel = {
  val impurityType = Impurities.fromString(impurity)
  val strategy = new Strategy(Classification, impurityType, maxDepth,
    numClasses, maxBins, Sort, categoricalFeaturesInfo)
  // Delegates to the overloaded trainClassifier below
  trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
}
/**
 * Method to train a decision tree model for binary or multiclass classification.
 *
 * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
 *              Labels should take values {0, 1, ..., numClasses-1}.
 * @param strategy Parameters for training each tree in the forest.
 * @param numTrees Number of trees in the random forest.
 * @param featureSubsetStrategy Number of features to consider for splits at each node.
 *                              Supported: "auto", "all", "sqrt", "log2", "onethird".
 *                              If "auto" is set, this parameter is set based on numTrees:
 *                                if numTrees == 1, set to "all";
 *                                if numTrees > 1 (forest) set to "sqrt".
 * @param seed Random seed for bootstrapping and choosing feature subsets.
 * @return a random forest model that can be used for prediction
 */
def trainClassifier(
    input: RDD[LabeledPoint],
    strategy: Strategy,
    numTrees: Int,
    featureSubsetStrategy: String,
    seed: Int): RandomForestModel = {
  require(strategy.algo == Classification,
    s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
  // Create the RandomForest object here
  val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
  // Then call its run method with the RDD[LabeledPoint]; it returns a RandomForestModel instance
  rf.run(input)
}
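The "auto" behaviour described in the doc comment can be summarised in a few lines. The helper below only illustrates that rule for classification; resolveFeatureSubsetStrategy is my own name, not an MLlib API:

// Sketch: how "auto" is resolved, following the rule in the doc comment above.
def resolveFeatureSubsetStrategy(strategy: String, numTrees: Int): String =
  strategy match {
    case "auto" => if (numTrees == 1) "all" else "sqrt" // single tree vs. forest
    case other  => other                                // "all", "sqrt", "log2", "onethird"
  }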
/**
 * Method to train a decision tree model over an RDD
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @return a random forest model that can be used for prediction
 */
def run(input: RDD[LabeledPoint]): RandomForestModel = {
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
// Find the split points (splits) and the bin information (bins).
// For continuous features, a sample of the data is used to simplify the split-point computation.
// For categorical features: an unordered feature has at most 2^(arity-1) - 1 candidate splits,
// while an ordered feature has at most numBins - 1.
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
timer.stop("findSplitsBins")
logDebug("numBins: feature: number of bins")
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
  s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
}.mkString("\n"))
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
// Convert to the tree-oriented RDD representation; after the conversion, every sample's feature
// values have been assigned to their bins according to the split points.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
val withReplacement = if (numTrees > 1) true else false
// depth of the decision tree
val maxDepth = strategy.maxDepth
require(maxDepth <= 30,
  s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
// Max memory usage for aggregates
// TODO: Calculate memory usage more precisely.
val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
val maxMemoryPerNode = {
  val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
    // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
    Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
      .take(metadata.numFeaturesPerNode).map(_._2))
  } else {
    None
  }
  // Memory needed per node for the aggregation step
  RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
}
require(maxMemoryPerNode <= maxMemoryUsage,
  s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
  " which is too small for the given features." +
  s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
timer.stop("init")
/*
 * The main idea here is to perform group-wise training of the decision tree nodes thus
 * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
 * Each data sample is handled by a particular node (or it reaches a leaf and is not used
 * in lower levels).
 */
// Create an RDD of node Id cache.
// At first, all the rows belong to the root nodes (node Id == 1).
// Node IDs start from 1: node 1 is the root of a tree, its left child is 2, its right child
// is 3, and so on down the tree.
val nodeIdCache = if (strategy.useNodeIdCache) {
  Some(NodeIdCache.init(
    data = baggedInput,
    numTrees = numTrees,
    checkpointInterval = strategy.checkpointInterval,
    initVal = 1))
} else {
  None
}
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
val rng = new scala.util.Random()
rng.setSeed(seed)
while (nodeQueue.nonEmpty) {
  // Collect some nodes to split, and choose features for each node (if subsampling).
  // Each group of nodes may come from one or multiple trees, and at multiple levels.
  // Select, for every tree, the nodes that still need to be split
  val (nodesForGroup, treeToNodeToIndexInfo) =
    RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
  // Sanity check (should never occur):
  assert(nodesForGroup.size > 0,
    s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
  // Choose node splits, and enqueue new nodes as needed.
  timer.start("findBestSplits")
  // Find the best splits for the selected nodes
  DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
    treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
  timer.stop("findBestSplits")
}
baggedInput.unpersist()
timer.stop("total")
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
// Delete any remaining checkpoints used for node Id cache.
if (nodeIdCache.nonEmpty) {
  try {
    nodeIdCache.get.deleteAllCheckpoints()
  } catch {
    case e: IOException =>
      logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
  }
}
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
new RandomForestModel(strategy.algo, trees)
}
}
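Before moving on, note the heap-style node numbering that nodeIdCache relies on: the root of each tree is node 1, and a node with ID id has children 2*id and 2*id + 1. A simplified sketch of that indexing follows (Spark's Node object provides equivalent helpers; the code below is only illustrative):

// Heap-style node numbering used by nodeIdCache (root = 1).
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1          // 2 * id
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1   // 2 * id + 1
def indexToLevel(nodeIndex: Int): Int =                           // depth of node id
  31 - java.lang.Integer.numberOfLeadingZeros(nodeIndex)

// Example: node 1 (level 0) has children 2 and 3; node 3 has children 6 and 7.

This numbering is also the reason for the require(maxDepth <= 30) check above: at depth 30 the node IDs already approach the limit of a signed 32-bit Int.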
The listing above is the core run method of the RandomForest class. When determining the split points and bin information, it calls DecisionTree.findSplitsBins. Stepping into that method, we see the following code:
/**
 * Returns splits and bins for decision tree calculation.
 * Continuous and categorical features are handled differently.
 *
 * Continuous features:
 *   For each feature, there are numBins - 1 possible splits representing the possible binary
 *   decisions at each node in the tree.
 *   This finds locations (feature values) for splits using a subsample of the data.
 *
 * Categorical features:
 *   For each feature, there is 1 bin per split.
 *   Splits and bins are handled in 2 ways:
 *   (a) "unordered features"
 *       For multiclass classification with a low-arity feature
 *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
 *       the feature is split based on subsets of categories.
 *   (b) "ordered features"
 *       For regression and binary classification,
 *       and for multiclass classification with a high-arity feature,
 *       there is one bin per category.
 *
 * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
 * @param metadata Learning and dataset metadata
 * @return A tuple of (splits, bins).
 *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
 *          of size (numFeatures, numSplits).
 *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
 *          of size (numFeatures, numBins).
 */
protected[tree] def findSplitsBins(
    input: RDD[LabeledPoint],
    metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
  // Sample the input only if there are continuous features.
  // Check whether any of the features is continuous
  val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
  val sampledInput = if (hasContinuousFeatures) {
    // Calculate the number of samples for approximate quantile calculation.
    // At least 10000 samples are used
    val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
    // Compute the sampling fraction
    val fraction = if (requiredSamples < metadata.numExamples) {
      requiredSamples.toDouble / metadata.numExamples
    } else {
      1.0
    }
    logDebug("fraction of data used for calculating quantiles = " + fraction)
    input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
  } else {
    // If all features are categorical, no sampling is needed, so build an empty array
    new Array[LabeledPoint](0)
  }
  // Split-point strategy; currently Spark only implements one: Sort
  metadata.quantileStrategy match {
    case Sort =>
      // One array of split positions per feature
      val splits = new Array[Array[Split]](numFeatures)
      // Bin information corresponding to those split positions
      val bins = new Array[Array[Bin]](numFeatures)
      // Find all splits.
      // Iterate over all features.
      var featureIndex = 0
      while (featureIndex < numFeatures) {
        // Continuous feature case
        if (metadata.isContinuous(featureIndex)) {
          val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
          // findSplitsForContinuousFeature returns all split positions for a continuous feature
          val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
          val numSplits = featureSplits.length
          // For a continuous feature, the number of bins is the number of splits + 1
          val numBins = numSplits + 1
          logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
          // Arrays holding this feature's splits and bins
          splits(featureIndex) = new Array[Split](numSplits)
          bins(featureIndex) = new Array[Bin](numBins)
          var splitIndex = 0
          // Iterate over the split positions
          while (splitIndex < numSplits) {
            // The split value; since the samples are sorted, it acts as a threshold
            val threshold = featureSplits(splitIndex)
            // Record every split position for this feature
            splits(featureIndex)(splitIndex) =
              new Split(featureIndex, threshold, Continuous, List())
            splitIndex += 1
          }
          // The first bin uses the minimum value Double.MinValue as its left boundary
          bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
            splits(featureIndex)(0), Continuous, Double.MinValue)
          splitIndex = 1
          // Build all remaining bins except the last one; each bin covers the interval between
          // two consecutive split thresholds
          while (splitIndex < numSplits) {
            bins(featureIndex)(splitIndex) =
              new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
                Continuous, Double.MinValue)
            splitIndex += 1
          }
          // The last bin uses the maximum value Double.MaxValue as its right boundary
          bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
            new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
        } else {
          // Categorical feature case
          val numSplits = metadata.numSplits(featureIndex)
          val numBins = metadata.numBins(featureIndex)
          // Categorical feature: arity = number of distinct category values
          val featureArity = metadata.featureArity(featureIndex)
          // Unordered feature case
          if (metadata.isUnordered(featureIndex)) {
            // Unordered features
            //   2^(maxFeatureValue - 1) - 1 combinations
            splits(featureIndex) = new Array[Split](numSplits)
            var splitIndex = 0
            while (splitIndex < numSplits) {
              // Extract the subset of category values that this split sends to one side
              val categories: List[Double] =
                extractMultiClassCategories(splitIndex + 1, featureArity)
              splits(featureIndex)(splitIndex) =
                new Split(featureIndex, Double.MinValue, Categorical, categories)
              splitIndex += 1
            }
          } else {
            // Ordered features
            //   Bins correspond to feature values, so we do not need to compute splits or bins
            //   beforehand. Splits are constructed as needed during training.
            splits(featureIndex) = new Array[Split](0)
          }
          // For ordered features, bins correspond to feature values.
          // For unordered categorical features, there is no need to construct the bins,
          // since there is a one-to-one correspondence between the splits and the bins.
          bins(featureIndex) = new Array[Bin](0)
        }
        featureIndex += 1
      }
      (splits, bins)
    case MinMax =>
      throw new UnsupportedOperationException("minmax not supported yet.")
    case ApproxHist =>
      throw new UnsupportedOperationException("approximate histogram not supported yet.")
  }
}
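For an unordered categorical feature with M categories, the value splitIndex + 1 handed to extractMultiClassCategories is effectively a bit mask over the categories, which is where the 2^(M-1) - 1 candidate subsets come from. A self-contained sketch of that decoding (my own simplified version, not the MLlib implementation):

// Decode a split index into the subset of categories it sends to the left child.
// For a feature with featureArity = M there are (1 << (M - 1)) - 1 such subsets.
def categorySubset(maskPlusOne: Int, featureArity: Int): List[Double] =
  (0 until featureArity).filter(c => ((maskPlusOne >> c) & 1) == 1).map(_.toDouble).toList

// Example with 3 categories {0.0, 1.0, 2.0}: 2^(3-1) - 1 = 3 candidate splits
//   categorySubset(1, 3) == List(0.0)
//   categorySubset(2, 3) == List(1.0)
//   categorySubset(3, 3) == List(0.0, 1.0)

The best of these candidate splits is then chosen by binsToBestSplit, shown next.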
/**
 * Find the best split for a node.
 * @param binAggregates Bin statistics.
 * @return tuple for best split: (Split, information gain, prediction at node)
 */
private def binsToBestSplit(
    binAggregates: DTStatsAggregator, // references an ImpurityAggregator, which supplies the impurity-calculation logic
    splits: Array[Array[Split]],
    featuresForNode: Option[Array[Int]],
    node: Node): (Split, InformationGainStats, Predict) = {
  // calculate predict and impurity if current node is top node
  val level = Node.indexToLevel(node.id)
  var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
    None
  } else {
    Some((node.predict, node.impurity))
  }
  // For each (feature, split), calculate the gain, and select the best (feature, split).
  // Compute the information gain of every (feature, split) pair and keep the best one
  val (bestSplit, bestSplitStats) =
    Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
      val featureIndex = if (featuresForNode.nonEmpty) {
        featuresForNode.get.apply(featureIndexIdx)
      } else {
        featureIndexIdx
      }
      val numSplits = binAggregates.metadata.numSplits(featureIndex)
      // Continuous feature case
      if (binAggregates.metadata.isContinuous(featureIndex)) {
        // Cumulative sum (scanLeft) of bin statistics.
        // Afterwards, binAggregates for a bin is the sum of aggregates for
        // that bin + all preceding bins.
        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        var splitIndex = 0
        while (splitIndex < numSplits) {
          binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
          splitIndex += 1
        }
        // Find best split.
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { case splitIdx =>
            // Compute the impurity statistics of the leftChild and rightChild nodes
            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
            rightChildStats.subtract(leftChildStats)
            // Compute the parent node's prediction and impurity (the prediction is an average)
            predictWithImpurity = Some(predictWithImpurity.getOrElse(
              calculatePredictImpurity(leftChildStats, rightChildStats)))
            // Compute the information gain, used to judge whether this split is the best one
            val gainStats = calculateGainForSplit(leftChildStats,
              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
            (splitIdx, gainStats)
          }.maxBy(_._2.gain)
        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      } else if (binAggregates.metadata.isUnordered(featureIndex)) {
        // Unordered categorical feature case
        val (leftChildOffset, rightChildOffset) =
          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { splitIndex =>
            val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
            val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
            predictWithImpurity = Some(predictWithImpurity.getOrElse(
              calculatePredictImpurity(leftChildStats, rightChildStats)))
            val gainStats = calculateGainForSplit(leftChildStats,
              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
            (splitIndex, gainStats)
          }.maxBy(_._2.gain)
        (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
      } else {
        // Ordered categorical feature case
        val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        val numBins = binAggregates.metadata.numBins(featureIndex)
        /* Each bin is one category (feature value).
         * The bins are ordered based on centroidForCategories, and this ordering determines which
         * splits are considered. (With K categories, we consider K - 1 possible splits.)
         *
         * centroidForCategories is a list: (category, centroid)
         */
        // Multiclass classification case
        val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
          // For categorical variables in multiclass classification,
          // the bins are ordered by the impurity of their corresponding labels.
          Range(0, numBins).map { case featureValue =>
            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val centroid = if (categoryStats.count != 0) {
              // The centroid is the impurity of the labels that fall into this category
              categoryStats.calculate()
            } else {
              Double.MaxValue
            }
            (featureValue, centroid)
          }
        } else {
          // Regression or binary classification case
          // For categorical variables in regression and binary classification,
          // the bins are ordered by the centroid of their corresponding labels.
          Range(0, numBins).map { case featureValue =>
            val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val centroid = if (categoryStats.count != 0) {
              // The centroid is the mean of the labels in this category
              categoryStats.predict
            } else {
              Double.MaxValue
            }
            (featureValue, centroid)
          }
        }

        logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

        // bins sorted by centroids
        val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

        logDebug("Sorted centroids for categorical variable = " +
          categoriesSortedByCentroid.mkString(","))

        // Cumulative sum (scanLeft) of bin statistics.
        // Afterwards, binAggregates for a bin is the sum of aggregates for
        // that bin + all preceding bins.
        var splitIndex = 0
        while (splitIndex < numSplits) {
          val currentCategory = categoriesSortedByCentroid(splitIndex)._1
          val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
          // Merge the statistics of the two bins
          binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
          splitIndex += 1
        }
        // lastCategory = index of bin with total aggregates for this (node, feature)
        val lastCategory = categoriesSortedByCentroid.last._1
        // Find best split.
        // Choose the best split by its information gain
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { splitIndex =>
            val featureValue = categoriesSortedByCentroid(splitIndex)._1
            val leftChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
            rightChildStats.subtract(leftChildStats)
            predictWithImpurity = Some(predictWithImpurity.getOrElse(
              calculatePredictImpurity(leftChildStats, rightChildStats)))
            val gainStats = calculateGainForSplit(leftChildStats,
              rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
            (splitIndex, gainStats)
          }.maxBy(_._2.gain)
        val categoriesForSplit =
          categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
        val bestFeatureSplit =
          new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
        (bestFeatureSplit, bestFeatureGainStats)
      }
    }.maxBy(_._2.gain)
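binsToBestSplit ranks every candidate split by its information gain. For Gini impurity, the quantity being maximised can be written down in a few lines; this is a self-contained sketch (the real calculateGainForSplit additionally enforces constraints such as minInstancesPerNode and minInfoGain):

// Information gain of a binary split under Gini impurity (label counts per class).
def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}

def informationGain(parent: Array[Double], left: Array[Double], right: Array[Double]): Double = {
  val n = parent.sum
  gini(parent) - (left.sum / n) * gini(left) - (right.sum / n) * gini(right)
}

// Example: parent counts (10, 10) split into (9, 1) and (1, 9):
//   gain = 0.5 - 0.5 * 0.18 - 0.5 * 0.18 = 0.32

Once all nodes have been split, the per-tree root nodes are wrapped into DecisionTreeModel instances and combined into the RandomForestModel defined below.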
/**
 * :: Experimental ::
 * Represents a random forest model.
 *
 * @param algo algorithm for the ensemble model, either Classification or Regression
 * @param trees tree ensembles
 */
// RandomForestModel extends TreeEnsembleModel
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
  extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
    combiningStrategy = if (algo == Classification) Vote else Average)
  with Saveable {
  private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion

  // Load a previously saved model back into memory
  override def load(sc: SparkContext, path: String): RandomForestModel = {
    val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
        assert(metadata.treeWeights.forall(_ == 1.0))
        val trees = TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
        new RandomForestModel(Algo.fromString(metadata.algo), trees)
      case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
        s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
        s" ($classNameV1_0, 1.0)")
    }
  }
  private object SaveLoadV1_0 {
    // Hard-code class name string in case it changes in the future
    def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel"
  }
}
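Because RandomForestModel mixes in Saveable and has a matching loader, a trained forest can be persisted and reloaded. A minimal usage sketch (the path is illustrative):

// Persist a trained model and load it back later, e.g. in another application.
model.save(sc, "/models/randomForest")
val reloaded: RandomForestModel = RandomForestModel.load(sc, "/models/randomForest")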
When a random forest is used for prediction, the predict method it calls is inherited from TreeEnsembleModel, the general representation of tree-ensemble models; besides random forests, this also covers Gradient-Boosted Trees (GBTs). Part of its core code is shown below:
/**
 * Represents a tree ensemble model.
 *
 * @param algo algorithm for the ensemble model, either Classification or Regression
 * @param trees tree ensembles
 * @param treeWeights tree ensemble weights
 * @param combiningStrategy strategy for combining the predictions, not used for regression.
 */
private[tree] sealed class TreeEnsembleModel(
    protected val algo: Algo,
    protected val trees: Array[DecisionTreeModel],
    protected val treeWeights: Array[Double],
    protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
  require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")

  // other code omitted
  // The final class label is decided by (weighted) voting
  /**
   * Classifies a single data point based on (weighted) majority votes.
   */
  private def predictByVoting(features: Vector): Double = {
    val votes = mutable.Map.empty[Int, Double]
    trees.view.zip(treeWeights).foreach { case (tree, weight) =>
      val prediction = tree.predict(features).toInt
      votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
    }
    votes.maxBy(_._2)._1
  }
  /**
   * Predict values for a single data point using the model trained.
   *
   * @param features array representing a single data point
   * @return predicted category from the trained model
   */
  // Different combining strategies use different prediction methods
  def predict(features: Vector): Double = {
    (algo, combiningStrategy) match {
      case (Regression, Sum) =>
        predictBySumming(features)
      case (Regression, Average) =>
        predictBySumming(features) / sumWeights
      case (Classification, Sum) => // binary classification
        val prediction = predictBySumming(features)
        // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
        if (prediction > 0.0) 1.0 else 0.0
      // Random forests use the predictByVoting method
      case (Classification, Vote) =>
        predictByVoting(features)
      case _ =>
        throw new IllegalArgumentException(
          "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
            s"($algo, $combiningStrategy).")
    }
  }
  // The concrete implementation of predict for a whole data set
  /**
   * Predict values for the given data set.
   *
   * @param features RDD representing data points to be predicted
   * @return RDD[Double] where each entry contains the corresponding prediction
   */
  def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
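To make the voting logic concrete, here is a standalone re-creation of what predictByVoting computes, using plain Scala collections (the names and inputs are illustrative; the real method obtains the per-tree predictions from trees and treeWeights):

// Weighted majority vote over per-tree class predictions.
def voteExample(predictions: Seq[Int], weights: Seq[Double]): Int = {
  val votes = scala.collection.mutable.Map.empty[Int, Double]
  predictions.zip(weights).foreach { case (p, w) =>
    votes(p) = votes.getOrElse(p, 0.0) + w
  }
  votes.maxBy(_._2)._1
}

// Three trees predicting classes 1, 0, 1 with equal weight 1.0: class 1 wins.
// voteExample(Seq(1, 0, 1), Seq(1.0, 1.0, 1.0)) == 1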
Finally, a complete example that trains a random forest and uses it for prediction:

object RandomForestExample {
  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("RandomForestExample")
      .setMaster("spark://sparkmaster:7077")
    val sc = new SparkContext(sparkConf)

    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "/data/sample_data.txt")

    val numClasses = 2
    val featureSubsetStrategy = "auto"
    val numTrees = 3
    val model: RandomForestModel = RandomForest.trainClassifier(
      data, Strategy.defaultStrategy("classification"), numTrees,
      featureSubsetStrategy, new java.util.Random().nextInt())

    val input: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "/data/input.txt")

    val predictResult = input.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    // Print the results (when running in spark-shell)
    predictResult.collect()
    // Or save the results to HDFS
    // predictResult.saveAsTextFile("/data/predictResult")
    sc.stop()
  }
}
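Since the input here is labelled, a natural follow-up is to measure the error rate. The snippet below is my own addition (it would sit inside main, after predictResult is computed and before sc.stop()):

// Fraction of points whose predicted label differs from the true label.
val testErr = predictResult.filter { case (label, prediction) => label != prediction }
  .count().toDouble / input.count()
println(s"Test error = $testErr")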