public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements MLWritable, LogisticRegressionParams, HasTrainingSummary<LogisticRegressionTrainingSummary>
Model produced by LogisticRegression.
Modifier and Type | Method and Description |
---|---|
IntParam | aggregationDepth(): Param for suggested depth for treeAggregate (>= 2). |
BinaryLogisticRegressionTrainingSummary | binarySummary(): Gets summary of model on training set. |
Matrix | coefficientMatrix() |
Vector | coefficients(): A vector of model coefficients for "binomial" logistic regression. |
LogisticRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
DoubleParam | elasticNetParam(): Param for the ElasticNet mixing parameter, in range [0, 1]. |
LogisticRegressionSummary | evaluate(Dataset<?> dataset): Evaluates the model on a test dataset. |
Param<String> | family(): Param for the name of the family, which describes the label distribution to be used in the model. |
BooleanParam | fitIntercept(): Param for whether to fit an intercept term. |
double | getThreshold(): Get threshold for binary classification. |
double[] | getThresholds(): Get thresholds for binary or multiclass classification. |
double | intercept(): The model intercept for "binomial" logistic regression. |
Vector | interceptVector() |
static LogisticRegressionModel | load(String path) |
Param<Matrix> | lowerBoundsOnCoefficients(): The lower bounds on coefficients if fitting under bound constrained optimization. |
Param<Vector> | lowerBoundsOnIntercepts(): The lower bounds on intercepts if fitting under bound constrained optimization. |
IntParam | maxIter(): Param for maximum number of iterations (>= 0). |
int | numClasses(): Number of classes (values which the label can take). |
int | numFeatures(): Returns the number of features the model was trained on. |
double | predict(Vector features): Predict label for the given feature vector. |
static MLReader<LogisticRegressionModel> | read() |
DoubleParam | regParam(): Param for regularization parameter (>= 0). |
LogisticRegressionModel | setThreshold(double value): Set threshold in binary classification, in range [0, 1]. |
LogisticRegressionModel | setThresholds(double[] value): Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
BooleanParam | standardization(): Param for whether to standardize the training features before fitting the model. |
LogisticRegressionTrainingSummary | summary(): Gets summary of model on training set. |
DoubleParam | threshold(): Param for threshold in binary classification prediction, in range [0, 1]. |
DoubleParam | tol(): Param for the convergence tolerance for iterative algorithms (>= 0). |
String | toString() |
String | uid(): An immutable unique ID for the object and its derivatives. |
Param<Matrix> | upperBoundsOnCoefficients(): The upper bounds on coefficients if fitting under bound constrained optimization. |
Param<Vector> | upperBoundsOnIntercepts(): The upper bounds on intercepts if fitting under bound constrained optimization. |
Param<String> | weightCol(): Param for weight column name. |
MLWriter | write(): Returns an MLWriter instance for this ML instance. |
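The regParam and elasticNetParam rows are easiest to read next to the penalty they weight. The following is a minimal plain-Java sketch, with no Spark dependency, of the standard elastic net penalty regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2), where alpha is elasticNetParam: alpha = 1 gives a pure L1 penalty and alpha = 0 a pure L2 penalty. The class ElasticNetPenalty and its method are hypothetical, and leaving the intercept out of the weights is an assumption.

```java
// Plain-Java sketch (no Spark dependency) of the elastic net penalty that
// regParam and elasticNetParam are assumed to control. ElasticNetPenalty is a
// hypothetical name; the intercept is assumed NOT to be in the weights array.
final class ElasticNetPenalty {

    private ElasticNetPenalty() {}

    /**
     * Returns regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2),
     * where alpha is the elastic net mixing parameter in [0, 1].
     */
    static double penalty(double[] weights, double regParam, double alpha) {
        if (regParam < 0.0) {
            throw new IllegalArgumentException("regParam must be >= 0");
        }
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("elasticNetParam must be in [0, 1]");
        }
        double l1 = 0.0;        // sum of |w_j|
        double l2Squared = 0.0; // sum of w_j^2
        for (double w : weights) {
            l1 += Math.abs(w);
            l2Squared += w * w;
        }
        return regParam * (alpha * l1 + (1.0 - alpha) / 2.0 * l2Squared);
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0};
        System.out.println("pure L1 (alpha = 1): " + penalty(w, 0.1, 1.0));
        System.out.println("pure L2 (alpha = 0): " + penalty(w, 0.1, 0.0));
        System.out.println("mixed (alpha = 0.5): " + penalty(w, 0.1, 0.5));
    }
}
```

Setting regParam to 0 removes the penalty entirely, whatever the value of elasticNetParam.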
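Methods inherited from the superclasses and implemented interfaces (each line lists the methods inherited from one type):
normalizeToProbabilitiesInPlace, probabilityCol, setProbabilityCol, thresholds, transform, transformSchema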
rawPredictionCol, setRawPredictionCol, transformImpl
featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
transform, transform, transform
params
save
checkThresholdConsistency, getFamily, getLowerBoundsOnCoefficients, getLowerBoundsOnIntercepts, getUpperBoundsOnCoefficients, getUpperBoundsOnIntercepts, usingBoundConstrainedOptimization, validateAndTransformSchema
extractInstances
extractInstances, extractInstances
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
getRawPredictionCol, rawPredictionCol
getProbabilityCol, probabilityCol
thresholds
getRegParam
getElasticNetParam
getMaxIter
getFitIntercept
getStandardization
getWeightCol
getAggregationDepth
hasSummary, setSummary
initializeForcefully, initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public static MLReader<LogisticRegressionModel> read()
public static LogisticRegressionModel load(String path)
public final Param<String> family()
Specified by: family in interface LogisticRegressionParams
public Param<Matrix> lowerBoundsOnCoefficients()
Specified by: lowerBoundsOnCoefficients in interface LogisticRegressionParams
public Param<Matrix> upperBoundsOnCoefficients()
Specified by: upperBoundsOnCoefficients in interface LogisticRegressionParams
public Param<Vector> lowerBoundsOnIntercepts()
Specified by: lowerBoundsOnIntercepts in interface LogisticRegressionParams
public Param<Vector> upperBoundsOnIntercepts()
Specified by: upperBoundsOnIntercepts in interface LogisticRegressionParams
public final IntParam aggregationDepth()
Specified by: aggregationDepth in interface HasAggregationDepth
public DoubleParam threshold()
Specified by: threshold in interface HasThreshold
public final Param<String> weightCol()
Specified by: weightCol in interface HasWeightCol
public final BooleanParam standardization()
Specified by: standardization in interface HasStandardization
public final DoubleParam tol()
Specified by: tol in interface HasTol
public final BooleanParam fitIntercept()
Specified by: fitIntercept in interface HasFitIntercept
public final IntParam maxIter()
Specified by: maxIter in interface HasMaxIter
public final DoubleParam elasticNetParam()
Specified by: elasticNetParam in interface HasElasticNetParam
public final DoubleParam regParam()
Specified by: regParam in interface HasRegParam
public String uid()
Specified by: uid in interface Identifiable
public Matrix coefficientMatrix()
public Vector interceptVector()
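coefficientMatrix() and interceptVector() carry no description on this page. The following minimal plain-Java sketch assumes a multinomial model stored as one coefficient row and one intercept per class (numClasses rows by numFeatures columns) and shows how such a matrix and vector can turn a feature vector into class probabilities via softmax. MultinomialSketch and its methods are hypothetical names, not Spark API.

```java
import java.util.Arrays;

// Plain-Java sketch: class probabilities from a coefficient matrix (assumed to
// hold one row per class) and an intercept vector, via softmax over margins.
// MultinomialSketch is a hypothetical name, not part of Spark.
final class MultinomialSketch {

    private MultinomialSketch() {}

    /** margins[k] = coefficients[k] . features + intercepts[k] */
    static double[] margins(double[][] coefficients, double[] intercepts, double[] features) {
        if (coefficients.length != intercepts.length) {
            throw new IllegalArgumentException("need one intercept per coefficient row");
        }
        double[] m = new double[coefficients.length];
        for (int k = 0; k < coefficients.length; k++) {
            if (coefficients[k].length != features.length) {
                throw new IllegalArgumentException("row " + k + " does not match numFeatures");
            }
            double dot = intercepts[k];
            for (int j = 0; j < features.length; j++) {
                dot += coefficients[k][j] * features[j];
            }
            m[k] = dot;
        }
        return m;
    }

    /** Softmax; subtracting the largest margin first avoids overflow in exp. */
    static double[] softmax(double[] margins) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : margins) {
            max = Math.max(max, v);
        }
        double[] p = new double[margins.length];
        double sum = 0.0;
        for (int k = 0; k < margins.length; k++) {
            p[k] = Math.exp(margins[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < p.length; k++) {
            p[k] /= sum;
        }
        return p;
    }

    public static void main(String[] args) {
        double[][] coefficients = {{1.0, 0.0}, {0.0, 1.0}, {0.0, 0.0}}; // 3 classes x 2 features
        double[] intercepts = {0.0, 0.0, 0.0};
        double[] features = {2.0, 1.0};
        double[] probabilities = softmax(margins(coefficients, intercepts, features));
        System.out.println(Arrays.toString(probabilities));
    }
}
```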
public int numClasses()
Specified by: numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
public Vector coefficients()
public double intercept()
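For binomial logistic regression, coefficients() and intercept() define the margin w . x + b. The following minimal plain-Java sketch assumes the standard logistic link, computes P(label = 1) from that margin, and then applies the rule described for setThreshold (predict 1 if that probability is greater than the threshold, else 0). BinomialSketch is a hypothetical name; this is not Spark's internal implementation.

```java
// Plain-Java sketch of binomial prediction from a coefficient vector and an
// intercept: P(label = 1 | x) = 1 / (1 + exp(-(w . x + b))), then the
// "greater than threshold -> 1, else 0" rule. Hypothetical names; not Spark code.
final class BinomialSketch {

    private BinomialSketch() {}

    /** Logistic (sigmoid) of the margin w . x + b. */
    static double probabilityOfOne(double[] coefficients, double intercept, double[] features) {
        if (coefficients.length != features.length) {
            throw new IllegalArgumentException("coefficients and features differ in length");
        }
        double margin = intercept;
        for (int j = 0; j < features.length; j++) {
            margin += coefficients[j] * features[j];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Predict 1.0 if P(label = 1) is greater than threshold, else 0.0. */
    static double predict(double[] coefficients, double intercept, double[] features, double threshold) {
        return probabilityOfOne(coefficients, intercept, features) > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] w = {1.0};
        double b = 0.0;
        double[] x = {1.0};
        System.out.println(probabilityOfOne(w, b, x)); // sigmoid(1)
        System.out.println(predict(w, b, x, 0.5));     // default threshold
        System.out.println(predict(w, b, x, 0.8));     // higher threshold favors 0
    }
}
```

Raising the threshold from 0.5 to 0.8 flips the same input from class 1 to class 0, which is the "high threshold encourages the model to predict 0" behavior described for setThreshold.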
public LogisticRegressionModel setThreshold(double value)
If the estimated probability of class label 1 is greater than threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
Specified by: setThreshold in interface LogisticRegressionParams
Parameters: value - (undocumented)
public double getThreshold()
If thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)).
Otherwise, returns threshold if set, or its default value if unset.
Throws: IllegalArgumentException - if thresholds is set to an array of length other than 2.
Specified by: getThreshold in interface LogisticRegressionParams
Specified by: getThreshold in interface HasThreshold
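The threshold/thresholds equivalence stated for setThreshold and getThreshold can be checked directly. This plain-Java sketch (the class ThresholdConversion is hypothetical) maps a threshold p to (1 - p, p) and maps a length-2 array back with 1 / (1 + thresholds(0) / thresholds(1)); the round trip returns p.

```java
// Plain-Java sketch of the two documented conversions:
// threshold p -> thresholds (1 - p, p), and a length-2 thresholds array ->
// 1 / (1 + thresholds[0] / thresholds[1]). ThresholdConversion is hypothetical.
final class ThresholdConversion {

    private ThresholdConversion() {}

    /** What setThreshold(p) is documented to be equivalent to: thresholds (1 - p, p). */
    static double[] toThresholds(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("threshold must be in [0, 1]");
        }
        return new double[] {1.0 - p, p};
    }

    /** Equivalent binary threshold for a length-2 thresholds array. */
    static double toThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                    "thresholds must have length 2, got " + thresholds.length);
        }
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    public static void main(String[] args) {
        double[] t = toThresholds(0.25);
        System.out.println(java.util.Arrays.toString(t));
        System.out.println(toThreshold(t));
        // Only the ratio of the two entries matters, so they need not sum to 1.
        System.out.println(toThreshold(new double[] {2.0, 2.0}));
    }
}
```

Because only the ratio thresholds(0) / thresholds(1) enters the formula, (2.0, 2.0) and (0.5, 0.5) both correspond to a threshold of 0.5.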
public LogisticRegressionModel setThresholds(double[] value)
Note: When setThresholds() is called, any user-set value for threshold will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Specified by: setThresholds in interface LogisticRegressionParams
Overrides: setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: value - (undocumented)
public double[] getThresholds()
If thresholds is set, return its value.
Otherwise, if threshold is set, return the equivalent thresholds for binary classification: (1-threshold, threshold).
If neither is set, throw an exception.
Specified by: getThresholds in interface LogisticRegressionParams
Specified by: getThresholds in interface HasThresholds
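setThresholds is described as adjusting the probability of predicting each class. A minimal plain-Java sketch of how such an adjustment can be applied, assuming the rule given in Spark's documentation for the thresholds param: predict the class k with the largest probabilities[k] / thresholds[k]. Treating p / 0 as +infinity (at most one threshold may be 0) is also an assumption, and ThresholdedPrediction is a hypothetical name.

```java
// Plain-Java sketch: pick the class with the largest probability / threshold.
// The p / t rule and the "t == 0 -> +infinity" convention are assumptions
// based on Spark's thresholds documentation; ThresholdedPrediction is hypothetical.
final class ThresholdedPrediction {

    private ThresholdedPrediction() {}

    static int predict(double[] probabilities, double[] thresholds) {
        if (probabilities.length != thresholds.length || probabilities.length == 0) {
            throw new IllegalArgumentException("need one threshold per class");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < probabilities.length; k++) {
            double score = thresholds[k] == 0.0
                    ? Double.POSITIVE_INFINITY
                    : probabilities[k] / thresholds[k];
            if (score > bestScore) { // strict '>' so ties keep the lower class index
                best = k;
                bestScore = score;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.5, 0.3, 0.2};
        System.out.println(predict(probabilities, new double[] {1.0, 1.0, 1.0})); // plain argmax
        System.out.println(predict(probabilities, new double[] {1.0, 0.5, 1.0})); // class 1 boosted
    }
}
```

In the binary case with thresholds (1 - p, p) and 0 < p < 1, class 1 wins exactly when q / p > (1 - q) / (1 - p), i.e. when q > p, which matches the single-threshold rule described for setThreshold.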
public int numFeatures()
Overrides: numFeatures in class PredictionModel<Vector,LogisticRegressionModel>
public LogisticRegressionTrainingSummary summary()
Gets summary of model on training set. An exception is thrown if hasSummary is false.
Specified by: summary in interface HasTrainingSummary<LogisticRegressionTrainingSummary>
public BinaryLogisticRegressionTrainingSummary binarySummary()
Gets summary of model on training set. An exception is thrown if hasSummary is false or it is a multiclass model.
public LogisticRegressionSummary evaluate(Dataset<?> dataset)
Parameters: dataset - Test dataset to evaluate model on.
public double predict(Vector features)
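Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
Overrides: predict in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters: features - (undocumented)
public LogisticRegressionModel copy(ParamMap extra)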
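Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Model<LogisticRegressionModel>
Parameters: extra - (undocumented)
public MLWriter write()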
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary. An option to save summary may be added in the future.
This also does not save the parent currently.
Specified by: write in interface MLWritable
public String toString()
Specified by: toString in interface Identifiable
Overrides: toString in class Object