public class DecisionTree
extends java.lang.Object
implements scala.Serializable
param: strategy The configuration parameters for the tree algorithm which specify the type of decision tree (classification or regression), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.
param: seed Random seed.
Constructor and Description |
---|
DecisionTree(Strategy strategy) |
Modifier and Type | Method and Description |
---|---|
protected static void | initializeLogIfNecessary(boolean isInterpreter) |
protected static boolean | isTraceEnabled() |
protected static org.slf4j.Logger | log() |
protected static void | logDebug(scala.Function0<java.lang.String> msg) |
protected static void | logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static void | logError(scala.Function0<java.lang.String> msg) |
protected static void | logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static void | logInfo(scala.Function0<java.lang.String> msg) |
protected static void | logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static java.lang.String | logName() |
protected static void | logTrace(scala.Function0<java.lang.String> msg) |
protected static void | logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static void | logWarning(scala.Function0<java.lang.String> msg) |
protected static void | logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
DecisionTreeModel | run(RDD<LabeledPoint> input) - Method to train a decision tree model over an RDD. |
static DecisionTreeModel | train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth) - Method to train a decision tree model. |
static DecisionTreeModel | train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses) - Method to train a decision tree model. |
static DecisionTreeModel | train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeaturesInfo) - Method to train a decision tree model. |
static DecisionTreeModel | train(RDD<LabeledPoint> input, Strategy strategy) - Method to train a decision tree model. |
static DecisionTreeModel | trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<java.lang.Integer,java.lang.Integer> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins) - Java-friendly API for DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int). |
static DecisionTreeModel | trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins) - Method to train a decision tree model for binary or multiclass classification. |
static DecisionTreeModel | trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<java.lang.Integer,java.lang.Integer> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins) - Java-friendly API for DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int). |
static DecisionTreeModel | trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins) - Method to train a decision tree model for regression. |
public DecisionTree(Strategy strategy)
strategy - The configuration parameters for the tree algorithm which specify the type of decision tree (classification or regression), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.

public static DecisionTreeModel train(RDD<LabeledPoint> input, Strategy strategy)
Note: Using DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
and DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
is recommended instead, because they keep classification and regression clearly separate.
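The tree type selected through strategy determines which impurity criteria apply: "gini" or "entropy" for classification and "variance" for regression, as listed for trainClassifier and trainRegressor. The sketch below computes those three measures with their textbook formulas. ImpuritySketch and its methods are hypothetical names chosen for illustration; this is not Spark's Impurity implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical helper: textbook impurity formulas for the criteria named on
 * this page. This is not Spark's org.apache.spark.mllib.tree.impurity code.
 */
final class ImpuritySketch {
    private ImpuritySketch() {}

    /** Gini impurity, 1 - sum(p_i^2), from per-class label counts. */
    static double gini(double[] classCounts) {
        double total = Arrays.stream(classCounts).sum();
        if (total == 0) {
            return 0.0; // no samples: treat the node as pure
        }
        double sumOfSquares = 0.0;
        for (double count : classCounts) {
            double p = count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Entropy in bits, -sum(p_i * log2(p_i)); empty classes contribute 0. */
    static double entropy(double[] classCounts) {
        double total = Arrays.stream(classCounts).sum();
        if (total == 0) {
            return 0.0;
        }
        double bits = 0.0;
        for (double count : classCounts) {
            if (count == 0) {
                continue; // 0 * log(0) is taken as 0
            }
            double p = count / total;
            bits -= p * (Math.log(p) / Math.log(2));
        }
        return bits;
    }

    /** Population variance of real-valued regression labels. */
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double mean = Arrays.stream(labels).average().orElse(0.0);
        double sumSq = 0.0;
        for (double y : labels) {
            sumSq += (y - mean) * (y - mean);
        }
        return sumSq / labels.length;
    }
}
```

For a node with class counts {5, 5}, gini gives 0.5 and entropy gives 1 bit; a pure node such as {10, 0} scores 0 under both.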
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
strategy - The configuration parameters for the tree algorithm which specify the type of decision tree (classification or regression), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.

public static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth)
Note: Using DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
and DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
is recommended instead, because they keep classification and regression clearly separate.
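The maxDepth parameter of this overload counts levels of splits below the root: depth 0 is a single leaf, and depth 1 is one internal node with two leaves. Assuming every internal node has exactly two children, maxDepth bounds the tree size as in the sketch below. DepthBounds is a hypothetical helper written for illustration, not part of Spark.

```java
/**
 * Hypothetical helper: size bounds implied by maxDepth, assuming every
 * internal node has exactly two children. Depth 0 is a single leaf.
 */
final class DepthBounds {
    private DepthBounds() {}

    /** At most 2^maxDepth leaves. */
    static long maxLeaves(int maxDepth) {
        // Upper limit keeps the shifts below within the range of a long.
        if (maxDepth < 0 || maxDepth > 61) {
            throw new IllegalArgumentException("maxDepth out of range for this sketch: " + maxDepth);
        }
        return 1L << maxDepth;
    }

    /** A full binary tree with L leaves has L - 1 internal nodes. */
    static long maxInternalNodes(int maxDepth) {
        return maxLeaves(maxDepth) - 1;
    }

    /** Leaves plus internal nodes: 2^(maxDepth + 1) - 1. */
    static long maxTotalNodes(int maxDepth) {
        return maxLeaves(maxDepth) + maxInternalNodes(maxDepth);
    }
}
```

With the suggested maxDepth of 5 (see trainClassifier and trainRegressor), a tree therefore has at most 32 leaves and 63 nodes in total.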
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
algo - Type of decision tree, either classification or regression.
impurity - Criterion used for information gain calculation.
maxDepth - Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes).

public static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses)
Note: Using DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
and DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
is recommended instead, because they keep classification and regression clearly separate.
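For classification, this overload expects every label to be one of {0, 1, ..., numClasses-1}, stored as a double in LabeledPoint. A pre-flight check of that requirement could look like the sketch below. LabelCheck is a hypothetical helper, not part of Spark, and it works on a plain double[] rather than an RDD.

```java
/**
 * Hypothetical helper: checks that classification labels are whole numbers
 * in {0, 1, ..., numClasses - 1}.
 */
final class LabelCheck {
    private LabelCheck() {}

    /** True if label is an integer-valued double in [0, numClasses). NaN fails. */
    static boolean isValidClassLabel(double label, int numClasses) {
        return label >= 0 && label < numClasses && label == Math.floor(label);
    }

    /** Index of the first invalid label, or -1 if all labels are valid. */
    static int firstInvalidIndex(double[] labels, int numClasses) {
        for (int i = 0; i < labels.length; i++) {
            if (!isValidClassLabel(labels[i], numClasses)) {
                return i;
            }
        }
        return -1;
    }
}
```

With numClasses = 2, labels {0.0, 1.0, 1.0, 0.0} pass, while a label of 2.0 or 0.5 is rejected.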
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
algo - Type of decision tree, either classification or regression.
impurity - Criterion used for information gain calculation.
maxDepth - Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes).
numClasses - Number of classes for classification. Default value: 2.

public static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeaturesInfo)
Note: Using DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
and DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
is recommended instead, because they keep classification and regression clearly separate.
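In the categoricalFeaturesInfo map taken by this overload, an entry (n -> k) marks feature n as categorical with values in {0, 1, ..., k-1}. The sketch below checks one feature vector against such a map, using java.util.Map<Integer, Integer> as the Java-friendly overloads do. CategoricalCheck is a hypothetical helper, not part of Spark; features absent from the map are treated as continuous and left unchecked.

```java
import java.util.Map;

/**
 * Hypothetical helper: verifies that every feature listed in
 * categoricalFeaturesInfo holds a whole-number value in {0, ..., k - 1}.
 */
final class CategoricalCheck {
    private CategoricalCheck() {}

    static boolean conforms(double[] features, Map<Integer, Integer> categoricalFeaturesInfo) {
        for (Map.Entry<Integer, Integer> entry : categoricalFeaturesInfo.entrySet()) {
            int featureIndex = entry.getKey();
            int arity = entry.getValue();
            if (featureIndex < 0 || featureIndex >= features.length) {
                return false; // map refers to a feature the vector does not have
            }
            double value = features[featureIndex];
            // Written as a negated conjunction so that NaN is rejected too.
            if (!(value >= 0 && value < arity && value == Math.floor(value))) {
                return false;
            }
        }
        return true; // continuous features (not in the map) are not checked
    }
}
```

For example, with the map {0 -> 3}, the vector {2.0, 7.5} conforms (feature 1 is continuous), while {3.0, 7.5} does not.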
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
algo - Type of decision tree, either classification or regression.
impurity - Criterion used for information gain calculation.
maxDepth - Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes).
numClasses - Number of classes for classification. Default value: 2.
maxBins - Maximum number of bins used for splitting features.
quantileCalculationStrategy - Algorithm for calculating quantiles.
categoricalFeaturesInfo - Map storing the arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.

public static DecisionTreeModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins)
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - Number of classes for classification.
categoricalFeaturesInfo - Map storing the arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (suggested value: 5)
maxBins - Maximum number of bins used for splitting features. (suggested value: 32)

public static DecisionTreeModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<java.lang.Integer,java.lang.Integer> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins)
Java-friendly API for DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int).
The parameters have the same meaning as in the RDD-based trainClassifier above, except that input is a JavaRDD and categoricalFeaturesInfo is a java.util.Map<java.lang.Integer,java.lang.Integer>.

public static DecisionTreeModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins)
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing the arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. The only supported value for regression is "variance".
maxDepth - Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (suggested value: 5)
maxBins - Maximum number of bins used for splitting features. (suggested value: 32)

public static DecisionTreeModel trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<java.lang.Integer,java.lang.Integer> categoricalFeaturesInfo, java.lang.String impurity, int maxDepth, int maxBins)
Java-friendly API for DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int).
The parameters have the same meaning as in the RDD-based trainRegressor above, except that input is a JavaRDD and categoricalFeaturesInfo is a java.util.Map<java.lang.Integer,java.lang.Integer>.

protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)

public DecisionTreeModel run(RDD<LabeledPoint> input)
input - Training data: RDD of LabeledPoint.
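The impurity criterion is documented above as the basis for information gain calculation. To illustrate that idea, the sketch below scores one candidate binary split with the standard definition gain = I(parent) - (nL/n) * I(left) - (nR/n) * I(right), using Gini impurity over per-class counts. SplitGain is a hypothetical name, and Spark's distributed split search (binning, aggregation across partitions) is not reproduced here.

```java
/**
 * Hypothetical helper: information gain of a binary split, scored with Gini
 * impurity. Counts are per-class label counts on each side of the split.
 */
final class SplitGain {
    private SplitGain() {}

    static double gini(long[] counts) {
        long total = 0;
        for (long c : counts) {
            total += c;
        }
        if (total == 0) {
            return 0.0;
        }
        double sumOfSquares = 0.0;
        for (long c : counts) {
            double p = (double) c / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** gain = I(parent) - (nL / n) * I(left) - (nR / n) * I(right). */
    static double gain(long[] leftCounts, long[] rightCounts) {
        if (leftCounts.length != rightCounts.length) {
            throw new IllegalArgumentException("both sides need the same number of classes");
        }
        long[] parent = new long[leftCounts.length];
        long nLeft = 0;
        long nRight = 0;
        for (int i = 0; i < leftCounts.length; i++) {
            parent[i] = leftCounts[i] + rightCounts[i]; // parent = union of both sides
            nLeft += leftCounts[i];
            nRight += rightCounts[i];
        }
        long n = nLeft + nRight;
        if (n == 0) {
            return 0.0;
        }
        return gini(parent)
                - ((double) nLeft / n) * gini(leftCounts)
                - ((double) nRight / n) * gini(rightCounts);
    }
}
```

A split that perfectly separates {5, 5} into {5, 0} and {0, 5} recovers the full parent impurity of 0.5, whereas a split into {4, 1} and {1, 4} gains only 0.18.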