public final class GBTClassificationModel extends PredictionModel<Vector,GBTClassificationModel> implements scala.Serializable
Gradient-Boosted Trees (GBTs) model for classification.
It supports binary labels, as well as both continuous and categorical features.
Note: Multiclass labels are not currently supported.
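The ensemble's members are regression trees, as the constructor's DecisionTreeRegressionModel[] parameter shows. The plain-Java sketch below illustrates how one such tree can route a feature vector using both continuous and categorical splits. It is not Spark's Node/Split API. The classes GbtNode, GbtLeaf, GbtContinuousSplit and GbtCategoricalSplit are hypothetical, and the sketch assumes that categorical values are integer codes 0, 1, 2, ... stored in the feature vector as doubles.

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/** Hypothetical minimal regression-tree node; not Spark's API. */
abstract class GbtNode {
    /** Follows splits from this node down to a leaf and returns the leaf's value. */
    abstract double predict(double[] features);
}

/** Leaf: the regression output for every point that reaches it. */
final class GbtLeaf extends GbtNode {
    private final double value;

    GbtLeaf(double value) {
        this.value = value;
    }

    @Override
    double predict(double[] features) {
        return value;
    }
}

/** Continuous split: go left when features[featureIndex] <= threshold, else right. */
final class GbtContinuousSplit extends GbtNode {
    private final int featureIndex;
    private final double threshold;
    private final GbtNode left;
    private final GbtNode right;

    GbtContinuousSplit(int featureIndex, double threshold, GbtNode left, GbtNode right) {
        this.featureIndex = featureIndex;
        this.threshold = threshold;
        this.left = left;
        this.right = right;
    }

    @Override
    double predict(double[] features) {
        return features[featureIndex] <= threshold ? left.predict(features) : right.predict(features);
    }
}

/** Categorical split: go left when the integer-coded category is in leftCategories. */
final class GbtCategoricalSplit extends GbtNode {
    private final int featureIndex;
    private final Set<Integer> leftCategories;
    private final GbtNode left;
    private final GbtNode right;

    GbtCategoricalSplit(int featureIndex, Set<Integer> leftCategories, GbtNode left, GbtNode right) {
        this.featureIndex = featureIndex;
        this.leftCategories = Collections.unmodifiableSet(new HashSet<>(leftCategories));
        this.left = left;
        this.right = right;
    }

    @Override
    double predict(double[] features) {
        // Assumption: category codes are stored as 0.0, 1.0, 2.0, ...
        int category = (int) features[featureIndex];
        return leftCategories.contains(category) ? left.predict(features) : right.predict(features);
    }
}
```

A categorical split tests set membership rather than order, so it needs no threshold; a continuous split only compares against one.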
param: _trees Decision trees in the ensemble.
param: _treeWeights Weights for the decision trees in the ensemble.

Constructor | Description |
---|---|
GBTClassificationModel(java.lang.String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights) | Construct a GBTClassificationModel. |
Modifier and Type | Method | Description |
---|---|---|
GBTClassificationModel | copy(ParamMap extra) | Creates a copy of this instance with the same UID and some extra params. |
Param<java.lang.String> | featuresCol() | Param for features column name. |
static GBTClassificationModel | fromOld(GradientBoostedTreesModel oldModel, GBTClassifier parent, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeatures, int numFeatures) | (private[ml]) Convert a model from the old API. |
java.lang.String | getFeaturesCol() | |
java.lang.String | getLabelCol() | |
java.lang.String | getPredictionCol() | |
Param<java.lang.String> | labelCol() | Param for label column name. |
int | numFeatures() | Returns the number of features the model was trained on. |
protected double | predict(Vector features) | Predict label for the given features. |
Param<java.lang.String> | predictionCol() | Param for prediction column name. |
java.lang.String | toString() | |
protected DataFrame | transformImpl(DataFrame dataset) | |
org.apache.spark.ml.tree.DecisionTreeModel[] | trees() | |
double[] | treeWeights() | |
java.lang.String | uid() | An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) | Validates and transforms the input schema with the provided param map. |

Methods inherited from class PredictionModel: featuresDataType, setFeaturesCol, setPredictionCol, transform, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public GBTClassificationModel(java.lang.String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
Construct a GBTClassificationModel.
Parameters:
_trees - Decision trees in the ensemble.
_treeWeights - Weights for the decision trees in the ensemble.
uid - (undocumented)
public static GBTClassificationModel fromOld(GradientBoostedTreesModel oldModel, GBTClassifier parent, scala.collection.immutable.Map<java.lang.Object,java.lang.Object> categoricalFeatures, int numFeatures)
(private[ml]) Convert a model from the old API.
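The constructor above takes _trees and _treeWeights as parallel arrays in which weight i scales the output of tree i. The sketch below illustrates the invariants such a pair of arrays is expected to satisfy (at least one tree, exactly one weight per tree) using a hypothetical generic class, GbtEnsembleSketch. It is an assumption-based illustration, not a copy of the Spark constructor; the defensive copies are a design choice of this sketch.

```java
import java.util.Objects;

/**
 * Hypothetical container for a GBT ensemble's parallel arrays (not Spark's API).
 * Weight i scales the output of tree i.
 */
final class GbtEnsembleSketch<T> {
    private final String uid;
    private final T[] trees;
    private final double[] treeWeights;

    GbtEnsembleSketch(String uid, T[] trees, double[] treeWeights) {
        this.uid = Objects.requireNonNull(uid, "uid");
        Objects.requireNonNull(trees, "trees");
        Objects.requireNonNull(treeWeights, "treeWeights");
        if (trees.length == 0) {
            throw new IllegalArgumentException("An ensemble requires at least 1 tree.");
        }
        if (trees.length != treeWeights.length) {
            throw new IllegalArgumentException(String.format(
                "trees and treeWeights have non-matching lengths (%d, %d).",
                trees.length, treeWeights.length));
        }
        // Defensive copies: later edits to the caller's arrays cannot change the ensemble.
        this.trees = trees.clone();
        this.treeWeights = treeWeights.clone();
    }

    String uid() { return uid; }

    int numTrees() { return trees.length; }

    T tree(int i) { return trees[i]; }

    double weight(int i) { return treeWeights[i]; }
}
```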
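public java.lang.String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable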
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,GBTClassificationModel>
public org.apache.spark.ml.tree.DecisionTreeModel[] trees()
public double[] treeWeights()
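trees() and treeWeights() expose the ensemble that predict(Vector features) combines. A minimal sketch of that combination, assuming the usual gradient-boosting scheme for binary labels: each tree is a regression model, the raw score (margin) is the weighted sum of the tree outputs, and a positive margin maps to label 1.0 while anything else maps to 0.0. Trees are modeled here as ToDoubleFunction<double[]> values rather than Spark's DecisionTreeRegressionModel, and GbtPredictSketch is a hypothetical name.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Hypothetical helper showing how a GBT ensemble turns tree outputs into a 0/1 label. */
final class GbtPredictSketch {
    private GbtPredictSketch() {}

    /** Raw score: sum over i of treeWeights[i] * tree_i(features). */
    static double margin(List<ToDoubleFunction<double[]>> trees, double[] treeWeights, double[] features) {
        if (trees.size() != treeWeights.length) {
            throw new IllegalArgumentException("need exactly one weight per tree");
        }
        double sum = 0.0;
        for (int i = 0; i < treeWeights.length; i++) {
            sum += treeWeights[i] * trees.get(i).applyAsDouble(features);
        }
        return sum;
    }

    /** Binary label (assumed thresholding rule): 1.0 for a positive margin, otherwise 0.0. */
    static double predict(List<ToDoubleFunction<double[]>> trees, double[] treeWeights, double[] features) {
        return margin(trees, treeWeights, features) > 0.0 ? 1.0 : 0.0;
    }
}
```

For example, with illustrative weights {1.0, 0.1}, a first tree returning 1.0 and a second returning -0.5 give a margin of 1.0 - 0.05 = 0.95, so the predicted label is 1.0.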
protected DataFrame transformImpl(DataFrame dataset)
Overrides: transformImpl in class PredictionModel<Vector,GBTClassificationModel>
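protected double predict(Vector features)
Description copied from class: PredictionModel
Predict label for the given features. This internal method is used to implement transform() and output predictionCol.
Specified by: predict in class PredictionModel<Vector,GBTClassificationModel>
Parameters:
features - (undocumented)
public GBTClassificationModel copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Model<GBTClassificationModel>
Parameters:
extra - (undocumented)
See Also: defaultCopy()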
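copy(ParamMap extra) is documented as returning a new instance with the same UID plus some extra params. The sketch below shows that contract with a hypothetical ParamCopySketch class whose params are a plain Map<String, Object> standing in for ParamMap: the UID is carried over unchanged, values in extra override existing values with the same name, and the original instance is left untouched.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical immutable model whose params are a name-to-value map (stand-in for ParamMap). */
final class ParamCopySketch {
    private final String uid;
    private final Map<String, Object> params;

    ParamCopySketch(String uid, Map<String, Object> params) {
        this.uid = uid;
        this.params = Collections.unmodifiableMap(new HashMap<>(params));
    }

    /** Returns a new instance with the same uid; entries in extra override existing ones. */
    ParamCopySketch copy(Map<String, Object> extra) {
        Map<String, Object> merged = new HashMap<>(params);
        merged.putAll(extra);
        return new ParamCopySketch(uid, merged);
    }

    String uid() { return uid; }

    Object get(String name) { return params.get(name); }
}
```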
public java.lang.String toString()
Specified by: toString in interface Identifiable
Overrides: toString in class java.lang.Object
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is called during fitting (training)
featuresDataType - SQL DataType for FeaturesType; for example, VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()
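validateAndTransformSchema, together with the labelCol, featuresCol and predictionCol params, decides which columns the model reads and writes. The sketch below is a simplified, hypothetical stand-in (SchemaCheckSketch): the schema is a LinkedHashMap from column name to a type-name string instead of a StructType. Its rules are assumptions about typical predictor behavior rather than a transcription of Spark's code: the features column must have the expected type, the label column must be "double" only when fitting, and the prediction column must not already exist and is appended as "double".

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical, simplified schema check for a predictor; not Spark's StructType API. */
final class SchemaCheckSketch {
    private SchemaCheckSketch() {}

    static Map<String, String> validateAndTransform(
            Map<String, String> schema, boolean fitting, String featuresType,
            String featuresCol, String labelCol, String predictionCol) {
        requireType(schema, featuresCol, featuresType);
        if (fitting) {
            // The label column is only needed while training.
            requireType(schema, labelCol, "double");
        }
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("Column " + predictionCol + " already exists.");
        }
        // Copy so the input schema is not modified, then append the output column last.
        Map<String, String> out = new LinkedHashMap<>(schema);
        out.put(predictionCol, "double");
        return out;
    }

    private static void requireType(Map<String, String> schema, String column, String expected) {
        String actual = schema.get(column);
        if (actual == null) {
            throw new IllegalArgumentException("Column " + column + " does not exist.");
        }
        if (!actual.equals(expected)) {
            throw new IllegalArgumentException(
                "Column " + column + " must be of type " + expected + " but was " + actual + ".");
        }
    }
}
```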