public class LinearRegression extends Predictor<FeaturesType,Learner,M> implements LinearRegressionParams, DefaultParamsWritable, Logging
The learning objective is to minimize the specified loss function, with regularization. This supports two kinds of loss:

- squaredError (a.k.a. squared loss)
- huber (a hybrid of squared error for relatively small errors and absolute error for relatively large ones; the scale parameter is estimated from the training data)

This supports multiple types of regularization:

- none (a.k.a. ordinary least squares)
- L2 (ridge regression)
- L1 (Lasso)
- L2 + L1 (elastic net)

The squared error objective function is:
$$ \begin{align} \min_{w}\frac{1}{2n}{\sum_{i=1}^n(X_{i}w - y_{i})^{2} + \lambda\left[\frac{1-\alpha}{2}{||w||_{2}}^{2} + \alpha{||w||_{1}}\right]} \end{align} $$
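The squared error objective can be evaluated by hand for a small dataset. The following is a minimal sketch, not Spark's implementation: the class and method names are hypothetical, `lambda` is assumed to correspond to the regularization parameter and `alpha` to the ElasticNet mixing parameter, and the regularization term is read as sitting outside the 1/(2n) factor.

```java
// Hypothetical helper, for illustration only; not part of Spark.
final class SquaredErrorObjective {

    /**
     * Evaluates 1/(2n) * sum_i (X_i w - y_i)^2
     *   + lambda * [ (1 - alpha)/2 * ||w||_2^2 + alpha * ||w||_1 ].
     */
    static double objective(double[][] x, double[] y, double[] w,
                            double lambda, double alpha) {
        int n = x.length;
        double sumSquared = 0.0;
        for (int i = 0; i < n; i++) {
            double prediction = 0.0;
            for (int j = 0; j < w.length; j++) {
                prediction += x[i][j] * w[j];
            }
            double residual = prediction - y[i];
            sumSquared += residual * residual;
        }
        double l1 = 0.0;      // ||w||_1
        double l2Squared = 0.0; // ||w||_2^2
        for (double wj : w) {
            l1 += Math.abs(wj);
            l2Squared += wj * wj;
        }
        double penalty = lambda * ((1.0 - alpha) / 2.0 * l2Squared + alpha * l1);
        return sumSquared / (2.0 * n) + penalty;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 0.0}, {0.0, 1.0}};
        double[] y = {1.0, 2.0};
        double[] w = {1.0, 1.0};
        // alpha = 0 gives a pure L2 penalty, alpha = 1 a pure L1 penalty.
        System.out.println("ridge:       " + objective(x, y, w, 0.5, 0.0));
        System.out.println("lasso:       " + objective(x, y, w, 0.5, 1.0));
        System.out.println("elastic net: " + objective(x, y, w, 0.5, 0.5));
    }
}
```

Setting `lambda` to 0 in this sketch leaves only the data term, which matches the "none" (ordinary least squares) case listed above.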
The huber objective function is:
$$ \begin{align} \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma + H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2} \end{align} $$
where
$$ \begin{align} H_m(z) = \begin{cases} z^2, & \text {if } |z| < \epsilon, \\ 2\epsilon|z| - \epsilon^2, & \text{otherwise} \end{cases} \end{align} $$
Note: Fitting with huber loss only supports none and L2 regularization.
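To make the piecewise definition of H_m concrete, here is a small sketch that evaluates H_m(z) and the huber objective for fixed `w` and `sigma`. It is an illustration under assumptions, not Spark's code: the class and method names are hypothetical, and the L2 term is read as sitting outside the 1/(2n) factor. In an actual fit both `w` and `sigma` are optimized; this sketch only evaluates the objective at given values.

```java
// Hypothetical helper, for illustration only; not part of Spark.
final class HuberObjective {

    /** H_m(z): z^2 if |z| < epsilon, otherwise 2*epsilon*|z| - epsilon^2. */
    static double huber(double z, double epsilon) {
        double absZ = Math.abs(z);
        return absZ < epsilon
            ? z * z
            : 2.0 * epsilon * absZ - epsilon * epsilon;
    }

    /**
     * Evaluates 1/(2n) * sum_i (sigma + H_m((X_i w - y_i) / sigma) * sigma)
     *   + 1/2 * lambda * ||w||_2^2 for fixed w and sigma (sigma > 0).
     */
    static double objective(double[][] x, double[] y, double[] w,
                            double sigma, double epsilon, double lambda) {
        int n = x.length;
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double prediction = 0.0;
            for (int j = 0; j < w.length; j++) {
                prediction += x[i][j] * w[j];
            }
            double residual = prediction - y[i];
            sum += sigma + huber(residual / sigma, epsilon) * sigma;
        }
        double l2Squared = 0.0;
        for (double wj : w) {
            l2Squared += wj * wj;
        }
        return sum / (2.0 * n) + 0.5 * lambda * l2Squared;
    }

    public static void main(String[] args) {
        double epsilon = 1.35;
        // Small scaled residual: quadratic branch. Large one: linear branch.
        System.out.println("H(1.0) = " + huber(1.0, epsilon));
        System.out.println("H(2.0) = " + huber(2.0, epsilon));
    }
}
```

The two branches agree at |z| = epsilon (both give epsilon^2), so the loss is continuous while growing only linearly for large residuals.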
Constructor and Description |
---|
LinearRegression() |
LinearRegression(String uid) |
Modifier and Type | Method and Description |
---|---|
LinearRegression | copy(ParamMap extra) Creates a copy of this instance with the same UID and some extra params. |
static LinearRegression | load(String path) |
static int | MAX_FEATURES_FOR_NORMAL_SOLVER() When using LinearRegression.solver == "normal", the solver must limit the number of features to at most this number. |
static MLReader<T> | read() |
LinearRegression | setAggregationDepth(int value) Suggested depth for treeAggregate (greater than or equal to 2). |
LinearRegression | setElasticNetParam(double value) Set the ElasticNet mixing parameter. |
LinearRegression | setEpsilon(double value) Sets the value of param epsilon. |
LinearRegression | setFitIntercept(boolean value) Set if we should fit the intercept. |
LinearRegression | setLoss(String value) Sets the value of param loss. |
LinearRegression | setMaxIter(int value) Set the maximum number of iterations. |
LinearRegression | setRegParam(double value) Set the regularization parameter. |
LinearRegression | setSolver(String value) Set the solver algorithm used for optimization. |
LinearRegression | setStandardization(boolean value) Whether to standardize the training features before fitting the model. |
LinearRegression | setTol(double value) Set the convergence tolerance of iterations. |
LinearRegression | setWeightCol(String value) Whether to over-/under-sample training instances according to the given weights in weightCol. |
String | uid() An immutable unique ID for the object and its derivatives. |

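Methods inherited from class Predictor: fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema

Methods inherited from class Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface LinearRegressionParams: epsilon, getEpsilon, loss, solver, validateAndTransformSchema

Methods inherited from interface HasLabelCol: getLabelCol, labelCol

Methods inherited from interface HasFeaturesCol: featuresCol, getFeaturesCol

Methods inherited from interface HasPredictionCol: getPredictionCol, predictionCol

Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn

Methods inherited from interface Identifiable: toString

Methods inherited from interface HasRegParam: getRegParam, regParam

Methods inherited from interface HasElasticNetParam: elasticNetParam, getElasticNetParam

Methods inherited from interface HasMaxIter: getMaxIter, maxIter

Methods inherited from interface HasFitIntercept: fitIntercept, getFitIntercept

Methods inherited from interface HasStandardization: getStandardization, standardization

Methods inherited from interface HasWeightCol: getWeightCol, weightCol

Methods inherited from interface HasAggregationDepth: aggregationDepth, getAggregationDepth

Methods inherited from interface DefaultParamsWritable: write

Methods inherited from interface MLWritable: save

Methods inherited from interface Logging: initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
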
public LinearRegression(String uid)
public LinearRegression()
public static LinearRegression load(String path)
public static int MAX_FEATURES_FOR_NORMAL_SOLVER()
When using LinearRegression.solver == "normal", the solver must limit the number of features to at most this number. The entire covariance matrix X^T X will be collected to the driver. This limit helps prevent memory overflow errors.

public static MLReader<T> read()
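The reasoning behind the MAX_FEATURES_FOR_NORMAL_SOLVER limit can be illustrated with a rough size estimate: for d features, X^T X is a d x d matrix, so its size grows quadratically with d. The sketch below is only an approximation under assumptions: the class and method names are hypothetical, the matrix is assumed to be stored densely as 8-byte doubles, and the feature counts in `main` are arbitrary examples rather than the actual value of the limit.

```java
// Hypothetical helper, for illustration only; not part of Spark.
final class NormalSolverFootprint {

    /** Computes the d x d matrix X^T X for an n x d matrix X (n >= 1). */
    static double[][] gram(double[][] x) {
        int d = x[0].length;
        double[][] g = new double[d][d];
        for (double[] row : x) {
            for (int j = 0; j < d; j++) {
                for (int k = 0; k < d; k++) {
                    g[j][k] += row[j] * row[k];
                }
            }
        }
        return g;
    }

    /** Bytes needed for a dense d x d matrix of doubles (assumed layout). */
    static long denseGramBytes(int numFeatures) {
        return (long) numFeatures * numFeatures * Double.BYTES;
    }

    public static void main(String[] args) {
        // Arbitrary example feature counts, not the library's limit.
        for (int d : new int[] {1_000, 10_000, 100_000}) {
            System.out.println(d + " features -> " + denseGramBytes(d) + " bytes");
        }
    }
}
```

Note that the size of X^T X depends only on the number of features, not on the number of training rows, which is why the limit is expressed in features.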
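
public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
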
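public LinearRegression setRegParam(double value)
Set the regularization parameter.
Parameters: value - (undocumented)

public LinearRegression setFitIntercept(boolean value)
Set if we should fit the intercept.
Parameters: value - (undocumented)

public LinearRegression setStandardization(boolean value)
Whether to standardize the training features before fitting the model.
Parameters: value - (undocumented)

public LinearRegression setElasticNetParam(double value)
Set the ElasticNet mixing parameter.
Note: Fitting with huber loss supports only none and L2 regularization, so an exception is thrown if this param is set to a non-zero value.
Parameters: value - (undocumented)

public LinearRegression setMaxIter(int value)
Set the maximum number of iterations.
Parameters: value - (undocumented)

public LinearRegression setTol(double value)
Set the convergence tolerance of iterations.
Parameters: value - (undocumented)

public LinearRegression setWeightCol(String value)
Whether to over-/under-sample training instances according to the given weights in weightCol.
Parameters: value - (undocumented)

public LinearRegression setSolver(String value)
Set the solver algorithm used for optimization.
- "normal" uses the Normal Equations solver, which is limited to at most LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER features.
- "auto" (default) means that the solver algorithm is selected automatically. The Normal Equations solver will be used when possible, but this will automatically fall back to iterative optimization methods when needed.

Note: Fitting with huber loss does not support the normal solver, so an exception is thrown if this param is set to "normal".
Parameters: value - (undocumented)

public LinearRegression setAggregationDepth(int value)
Suggested depth for treeAggregate (greater than or equal to 2).
Parameters: value - (undocumented)

public LinearRegression setLoss(String value)
Sets the value of param loss. Default is "squaredError".
Parameters: value - (undocumented)

public LinearRegression setEpsilon(double value)
Sets the value of param epsilon. Default is 1.35.
Parameters: value - (undocumented)

public LinearRegression copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,LinearRegression,LinearRegressionModel>
Parameters: extra - (undocumented)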