public class BisectingKMeans extends Estimator<BisectingKMeansModel>
A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques"
by Steinbach, Karypis, and Kumar, with modification to fit Spark.
The algorithm starts from a single cluster that contains all points.
Iteratively, it finds divisible clusters on the bottom level and bisects each of them using
k-means, until there are k leaf clusters in total or no leaf clusters are divisible.
The bisecting steps of clusters on the same level are grouped together to increase parallelism.
If bisecting all divisible clusters on the bottom level would result in more than k
leaf clusters, larger clusters get higher priority.
http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf
Steinbach, Karypis, and Kumar, A comparison of document clustering techniques,
KDD Workshop on Text Mining, 2000.
Serialized Form
Constructor and Description |
---|
BisectingKMeans() |
BisectingKMeans(java.lang.String uid) |
Modifier and Type | Method and Description |
---|---|
protected static <T> T |
$(Param<T> param) |
static Params |
clear(Param<?> param) |
BisectingKMeans |
copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
|
protected static <T extends Params> |
copyValues(T to,
ParamMap extra) |
protected static <T extends Params> |
copyValues$default$2() |
protected static <T extends Params> |
defaultCopy(ParamMap extra) |
static java.lang.String |
explainParam(Param<?> param) |
static java.lang.String |
explainParams() |
static ParamMap |
extractParamMap() |
static ParamMap |
extractParamMap(ParamMap extra) |
static Param<java.lang.String> |
featuresCol() |
BisectingKMeansModel |
fit(Dataset<?> dataset)
Fits a model to the input data.
|
static <T> scala.Option<T> |
get(Param<T> param) |
static <T> scala.Option<T> |
getDefault(Param<T> param) |
static java.lang.String |
getFeaturesCol() |
static int |
getK() |
int |
getK() |
static int |
getMaxIter() |
static double |
getMinDivisibleClusterSize() |
double |
getMinDivisibleClusterSize() |
static <T> T |
getOrDefault(Param<T> param) |
static Param<java.lang.Object> |
getParam(java.lang.String paramName) |
static java.lang.String |
getPredictionCol() |
static long |
getSeed() |
static <T> boolean |
hasDefault(Param<T> param) |
static boolean |
hasParam(java.lang.String paramName) |
protected static void |
initializeLogIfNecessary(boolean isInterpreter) |
static boolean |
isDefined(Param<?> param) |
static boolean |
isSet(Param<?> param) |
protected static boolean |
isTraceEnabled() |
static IntParam |
k() |
IntParam |
k()
Set the number of clusters to create (k).
|
static BisectingKMeans |
load(java.lang.String path) |
protected static org.slf4j.Logger |
log() |
protected static void |
logDebug(scala.Function0<java.lang.String> msg) |
protected static void |
logDebug(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logError(scala.Function0<java.lang.String> msg) |
protected static void |
logError(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static java.lang.String |
logName() |
protected static void |
logTrace(scala.Function0<java.lang.String> msg) |
protected static void |
logTrace(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
static IntParam |
maxIter() |
static DoubleParam |
minDivisibleClusterSize() |
DoubleParam |
minDivisibleClusterSize() |
static Param<?>[] |
params() |
static Param<java.lang.String> |
predictionCol() |
static void |
save(java.lang.String path) |
static LongParam |
seed() |
static <T> Params |
set(Param<T> param,
T value) |
protected static Params |
set(ParamPair<?> paramPair) |
protected static Params |
set(java.lang.String param,
java.lang.Object value) |
protected static <T> Params |
setDefault(Param<T> param,
T value) |
protected static Params |
setDefault(scala.collection.Seq<ParamPair<?>> paramPairs) |
BisectingKMeans |
setFeaturesCol(java.lang.String value) |
BisectingKMeans |
setK(int value) |
BisectingKMeans |
setMaxIter(int value) |
BisectingKMeans |
setMinDivisibleClusterSize(double value) |
BisectingKMeans |
setPredictionCol(java.lang.String value) |
BisectingKMeans |
setSeed(long value) |
static java.lang.String |
toString() |
StructType |
transformSchema(StructType schema)
:: DeveloperApi ::
|
java.lang.String |
uid()
An immutable unique ID for the object and its derivatives.
|
protected static StructType |
validateAndTransformSchema(StructType schema) |
StructType |
validateAndTransformSchema(StructType schema)
Validates and transforms the input schema.
|
static void |
validateParams() |
static MLWriter |
write() |
transformSchema
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
toString
public BisectingKMeans(java.lang.String uid)
public BisectingKMeans()
public static BisectingKMeans load(java.lang.String path)
public static java.lang.String toString()
public static Param<?>[] params()
public static void validateParams()
public static java.lang.String explainParam(Param<?> param)
public static java.lang.String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(java.lang.String paramName)
public static Param<java.lang.Object> getParam(java.lang.String paramName)
protected static final Params set(java.lang.String param, java.lang.Object value)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
protected static final <T> T $(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
public static final IntParam maxIter()
public static final int getMaxIter()
public static final Param<java.lang.String> featuresCol()
public static final java.lang.String getFeaturesCol()
public static final LongParam seed()
public static final long getSeed()
public static final Param<java.lang.String> predictionCol()
public static final java.lang.String getPredictionCol()
public static final IntParam k()
public static int getK()
public static final DoubleParam minDivisibleClusterSize()
public static double getMinDivisibleClusterSize()
protected static StructType validateAndTransformSchema(StructType schema)
public static void save(java.lang.String path) throws java.io.IOException
java.io.IOException
public static MLWriter write()
public java.lang.String uid()
Identifiable
uid
in interface Identifiable
public BisectingKMeans copy(ParamMap extra)
Params
copy
in interface Params
copy
in class Estimator<BisectingKMeansModel>
extra
- (undocumented)
See also: defaultCopy()
public BisectingKMeans setFeaturesCol(java.lang.String value)
public BisectingKMeans setPredictionCol(java.lang.String value)
public BisectingKMeans setK(int value)
public BisectingKMeans setMaxIter(int value)
public BisectingKMeans setSeed(long value)
public BisectingKMeans setMinDivisibleClusterSize(double value)
public BisectingKMeansModel fit(Dataset<?> dataset)
Estimator
fit
in class Estimator<BisectingKMeansModel>
dataset
- (undocumented)
public StructType transformSchema(StructType schema)
PipelineStage
Derives the output schema from the input schema.
transformSchema
in class PipelineStage
schema
- (undocumented)
public IntParam k()
public int getK()
public DoubleParam minDivisibleClusterSize()
public double getMinDivisibleClusterSize()
public StructType validateAndTransformSchema(StructType schema)
schema
- input schema