public class SGD extends RandomizableClassifier implements UpdateableClassifier, OptionHandler, Aggregateable<SGD>
Valid options are:
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
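The options above map onto plain settings with documented defaults. The following minimal sketch parses an option array of this shape into those defaults. It is a hypothetical stand-in (the class SgdSettings and its parse method are illustrative names, not Weka's option parser) and covers only -F, -L, -R, -E, -C, -N, -M and -S, rejecting anything else.

```java
/**
 * Hypothetical holder for the SGD options listed above, initialised to the
 * documented defaults. Not Weka's own option parser.
 */
final class SgdSettings {
    int lossFunction = 0;               // -F: 0 = hinge, 1 = log, 2 = squared, 3 = epsilon-insensitive, 4 = Huber
    double learningRate = 0.01;         // -L
    double lambda = 0.0001;             // -R
    int epochs = 500;                   // -E (batch learning only)
    double epsilon = 1e-3;              // -C
    boolean dontNormalize = false;      // -N
    boolean dontReplaceMissing = false; // -M
    int seed = 1;                       // -S

    /** Parses flags such as {"-F", "1", "-L", "0.0001", "-N"}; unsupported flags are rejected. */
    static SgdSettings parse(String[] options) {
        SgdSettings s = new SgdSettings();
        for (int i = 0; i < options.length; i++) {
            String flag = options[i];
            switch (flag) {
                case "-N": s.dontNormalize = true; break;
                case "-M": s.dontReplaceMissing = true; break;
                case "-F": s.lossFunction = Integer.parseInt(valueAt(options, ++i, flag)); break;
                case "-L": s.learningRate = Double.parseDouble(valueAt(options, ++i, flag)); break;
                case "-R": s.lambda = Double.parseDouble(valueAt(options, ++i, flag)); break;
                case "-E": s.epochs = Integer.parseInt(valueAt(options, ++i, flag)); break;
                case "-C": s.epsilon = Double.parseDouble(valueAt(options, ++i, flag)); break;
                case "-S": s.seed = Integer.parseInt(valueAt(options, ++i, flag)); break;
                default: throw new IllegalArgumentException("Unsupported option: " + flag);
            }
        }
        if (s.lossFunction < 0 || s.lossFunction > 4) {
            throw new IllegalArgumentException("-F must be between 0 and 4");
        }
        return s;
    }

    /** Returns the argument following a flag, or fails if the array ends first. */
    private static String valueAt(String[] options, int i, String flag) {
        if (i >= options.length) throw new IllegalArgumentException(flag + " requires a value");
        return options[i];
    }

    public static void main(String[] args) {
        SgdSettings s = parse(new String[] {"-F", "1", "-L", "0.0001", "-N"});
        System.out.println("loss=" + s.lossFunction + " lr=" + s.learningRate + " epochs=" + s.epochs);
    }
}
```

Flags that are not given keep their documented defaults, so an empty array yields hinge loss, a learning rate of 0.01 and 500 epochs.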
| Modifier and Type | Field and Description |
|---|---|
| static int | EPSILON_INSENSITIVE: The epsilon-insensitive loss function. |
| static int | HINGE: The hinge loss function. |
| static int | HUBER: The Huber loss function. |
| static int | LOGLOSS: The log loss function. |
| static int | SQUAREDLOSS: The squared loss function. |
| static Tag[] | TAGS_SELECTION: Loss functions to choose from. |
Fields inherited from class AbstractClassifier: BATCH_SIZE_DEFAULT, NUM_DECIMAL_PLACES_DEFAULT

| Constructor and Description |
|---|
| SGD() |
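The integer constants in the field summary are the same codes accepted by -F. A small, hypothetical enum (LossKind is an illustrative name, not part of Weka) makes that mapping explicit and records which losses the -F help marks as regression losses.

```java
/** Hypothetical mirror of the loss-function constants; codes follow the -F option. */
enum LossKind {
    HINGE(0, false),
    LOGLOSS(1, false),
    SQUAREDLOSS(2, true),
    EPSILON_INSENSITIVE(3, true),
    HUBER(4, true);

    final int code;           // value accepted by -F
    final boolean regression; // true where the -F help says "(regression)"

    LossKind(int code, boolean regression) {
        this.code = code;
        this.regression = regression;
    }

    /** Looks up a loss by its -F code, rejecting unknown codes. */
    static LossKind fromCode(int code) {
        for (LossKind k : values()) {
            if (k.code == code) return k;
        }
        throw new IllegalArgumentException("No loss function with code " + code);
    }

    public static void main(String[] args) {
        for (LossKind k : values()) {
            System.out.println(k.code + " " + k + (k.regression ? " (regression)" : ""));
        }
    }
}
```

With Weka itself, a constant is passed to setLossFunction wrapped in a SelectedTag, for example new SelectedTag(SGD.LOGLOSS, SGD.TAGS_SELECTION).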
| Modifier and Type | Method and Description |
|---|---|
| SGD | aggregate(SGD toAggregate): Aggregate an object with this one. |
| void | buildClassifier(Instances data): Method for building the classifier. |
| double[] | distributionForInstance(Instance inst): Computes the distribution for a given instance. |
| java.lang.String | dontNormalizeTipText(): Returns the tip text for this property. |
| java.lang.String | dontReplaceMissingTipText(): Returns the tip text for this property. |
| java.lang.String | epochsTipText(): Returns the tip text for this property. |
| java.lang.String | epsilonTipText(): Returns the tip text for this property. |
| void | finalizeAggregation(): Call to complete the aggregation process. |
| Capabilities | getCapabilities(): Returns default capabilities of the classifier. |
| boolean | getDontNormalize(): Get whether normalization has been turned off. |
| boolean | getDontReplaceMissing(): Get whether global replacement of missing values has been disabled. |
| int | getEpochs(): Get the current number of epochs. |
| double | getEpsilon(): Get the epsilon threshold on the error for the epsilon-insensitive and Huber loss functions. |
| double | getLambda(): Get the current value of lambda. |
| double | getLearningRate(): Get the learning rate. |
| SelectedTag | getLossFunction(): Get the current loss function. |
| java.lang.String[] | getOptions(): Gets the current settings of the classifier. |
| java.lang.String | getRevision(): Returns the revision string. |
| double[] | getWeights() |
| java.lang.String | globalInfo(): Returns a string describing this classifier. |
| java.lang.String | lambdaTipText(): Returns the tip text for this property. |
| java.lang.String | learningRateTipText(): Returns the tip text for this property. |
| java.util.Enumeration<Option> | listOptions(): Returns an enumeration describing the available options. |
| java.lang.String | lossFunctionTipText(): Returns the tip text for this property. |
| static void | main(java.lang.String[] args): Main method for testing this class. |
| void | reset(): Reset the classifier. |
| void | setDontNormalize(boolean m): Turn normalization off/on. |
| void | setDontReplaceMissing(boolean m): Turn global replacement of missing values off/on. |
| void | setEpochs(int e): Set the number of epochs to use. |
| void | setEpsilon(double e): Set the epsilon threshold on the error for the epsilon-insensitive and Huber loss functions. |
| void | setLambda(double lambda): Set the value of lambda to use. |
| void | setLearningRate(double lr): Set the learning rate. |
| void | setLossFunction(SelectedTag function): Set the loss function to use. |
| void | setOptions(java.lang.String[] options): Parses a given list of options. |
| java.lang.String | toString(): Prints out the classifier. |
| void | updateClassifier(Instance instance): Updates the classifier with the given instance. |
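aggregate and finalizeAggregation split model merging into two phases. The sketch below assumes the merge averages the weight vectors of separately trained linear models, summing during aggregate and dividing during finalizeAggregation. WeightAverager is a hypothetical class working on plain arrays, not Weka's implementation.

```java
import java.util.Arrays;

/** Hypothetical two-phase averaging of linear-model weight vectors. */
final class WeightAverager {
    private double[] sum;
    private int count;
    private boolean finalized;

    /** Phase 1: accumulate one model's weights (all vectors must have the same length). */
    WeightAverager aggregate(double[] weights) {
        if (finalized) throw new IllegalStateException("already finalized");
        if (sum == null) {
            sum = new double[weights.length];
        } else if (sum.length != weights.length) {
            throw new IllegalArgumentException("weight vectors differ in length");
        }
        for (int i = 0; i < weights.length; i++) sum[i] += weights[i];
        count++;
        return this;
    }

    /** Phase 2: turn the running sum into the element-wise mean. */
    double[] finalizeAggregation() {
        if (count == 0) throw new IllegalStateException("nothing aggregated");
        finalized = true;
        double[] mean = new double[sum.length];
        for (int i = 0; i < sum.length; i++) mean[i] = sum[i] / count;
        return mean;
    }

    public static void main(String[] args) {
        double[] avg = new WeightAverager()
                .aggregate(new double[] {1.0, 2.0, 0.5})
                .aggregate(new double[] {3.0, 0.0, 1.5})
                .finalizeAggregation();
        System.out.println(Arrays.toString(avg)); // [2.0, 1.0, 1.0]
    }
}
```

Deferring the division to the final call lets models be merged in any order and in any number without repeatedly rescaling partial averages.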
Methods inherited from class RandomizableClassifier: getSeed, seedTipText, setSeed
Methods inherited from class AbstractClassifier: batchSizeTipText, classifyInstance, debugTipText, distributionsForInstances, doNotCheckCapabilitiesTipText, forName, getBatchSize, getDebug, getDoNotCheckCapabilities, getNumDecimalPlaces, implementsMoreEfficientBatchPrediction, makeCopies, makeCopy, numDecimalPlacesTipText, postExecution, preExecution, run, runClassifier, setBatchSize, setDebug, setDoNotCheckCapabilities, setNumDecimalPlaces
public static final int HINGE
public static final int LOGLOSS
public static final int SQUAREDLOSS
public static final int EPSILON_INSENSITIVE
public static final int HUBER
public static final Tag[] TAGS_SELECTION
public Capabilities getCapabilities()
Specified by: getCapabilities in interface Classifier
Specified by: getCapabilities in interface CapabilitiesHandler
Overrides: getCapabilities in class AbstractClassifier
See Also: Capabilities
public java.lang.String epsilonTipText()
public void setEpsilon(double e)
Parameters: e - the value of epsilon to use
public double getEpsilon()
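The epsilon threshold only matters for the epsilon-insensitive and Huber losses. Assuming the textbook definitions on the residual r = prediction - target (Weka's exact formulation is not shown in this documentation), the two losses and their derivatives look like this. ThresholdLosses is an illustrative name.

```java
/** Hypothetical epsilon-insensitive and Huber losses on the residual r = prediction - target. */
final class ThresholdLosses {
    /** max(0, |r| - eps): errors within eps cost nothing. */
    static double epsInsensitive(double r, double eps) {
        return Math.max(0.0, Math.abs(r) - eps);
    }

    /** Derivative w.r.t. the prediction: 0 inside the tube, sign(r) outside. */
    static double epsInsensitiveGrad(double r, double eps) {
        return Math.abs(r) > eps ? Math.signum(r) : 0.0;
    }

    /** Quadratic for |r| <= eps, linear beyond; value and slope match at |r| = eps. */
    static double huber(double r, double eps) {
        double a = Math.abs(r);
        return a <= eps ? 0.5 * r * r : eps * (a - 0.5 * eps);
    }

    /** Derivative w.r.t. the prediction: r inside, clipped to +/-eps outside. */
    static double huberGrad(double r, double eps) {
        return Math.abs(r) <= eps ? r : eps * Math.signum(r);
    }

    public static void main(String[] args) {
        double eps = 1e-3; // the documented default for -C
        for (double r : new double[] {0.0005, -0.01, 0.5}) {
            System.out.printf("r=%.4f eps-ins=%.6f huber=%.6f%n",
                    r, epsInsensitive(r, eps), huber(r, eps));
        }
    }
}
```

Both losses grow only linearly for large residuals, which is why they are less sensitive to outliers than the squared loss; the threshold decides where that linear regime begins.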
public java.lang.String lambdaTipText()
public void setLambda(double lambda)
Parameters: lambda - the value of lambda to use
public double getLambda()
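lambda weights the L2 penalty. As an illustration, here is one regularized SGD step for the hinge loss, assuming the standard per-example objective lambda/2 * ||w||^2 + max(0, 1 - y * (w.x + b)) with y in {-1, +1} and an unregularized bias. HingeStep is hypothetical, and Weka's update may differ in detail (for example in learning-rate scheduling or bias handling).

```java
/** Hypothetical single SGD step for an L2-regularised linear SVM (hinge loss). */
final class HingeStep {
    /**
     * Updates w in place and returns the new bias. y must be -1 or +1.
     * Per-example objective: lambda/2 * ||w||^2 + max(0, 1 - y * (w.x + b)).
     */
    static double step(double[] w, double b, double[] x, int y, double learningRate, double lambda) {
        double z = b;
        for (int i = 0; i < w.length; i++) z += w[i] * x[i];
        boolean violated = y * z < 1.0; // inside the margin: the hinge term is active
        for (int i = 0; i < w.length; i++) {
            // weight decay from the penalty, plus the hinge subgradient when active
            double grad = lambda * w[i] - (violated ? y * x[i] : 0.0);
            w[i] -= learningRate * grad;
        }
        // the bias is left unregularised in this sketch
        return violated ? b + learningRate * y : b;
    }

    public static void main(String[] args) {
        double[] w = {0.0, 0.0};
        double b = step(w, 0.0, new double[] {1.0, 2.0}, 1, 0.1, 0.0001);
        System.out.println(w[0] + " " + w[1] + " " + b);
    }
}
```

Even when an example is classified with enough margin, the penalty term still shrinks every weight by a factor of (1 - learningRate * lambda), which is how lambda limits the size of the model.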
public void setLearningRate(double lr)
Parameters: lr - the learning rate to use
public double getLearningRate()
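The -L help notes that the default learning rate assumes normalized data. A rough way to see why: a feature whose values run into the thousands scales every gradient step on its weight by the same factor. The sketch below rescales each numeric column to [0, 1] using training minima and maxima. That min-max scaling is assumed here to be what the normalization amounts to, and MinMaxScaler is a hypothetical name.

```java
import java.util.Arrays;

/** Hypothetical per-column min-max scaling to [0, 1], fitted on training rows. */
final class MinMaxScaler {
    private final double[] min;
    private final double[] max;

    MinMaxScaler(double[][] rows) {
        int d = rows[0].length;
        min = new double[d];
        max = new double[d];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] r : rows) {
            for (int j = 0; j < d; j++) {
                min[j] = Math.min(min[j], r[j]);
                max[j] = Math.max(max[j], r[j]);
            }
        }
    }

    /** Maps x into the training range; constant columns map to 0 and values are not clipped. */
    double[] transform(double[] x) {
        double[] out = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            double range = max[j] - min[j];
            out[j] = range == 0.0 ? 0.0 : (x[j] - min[j]) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        MinMaxScaler sc = new MinMaxScaler(new double[][] {{0.0, 100.0}, {10.0, 3000.0}});
        System.out.println(Arrays.toString(sc.transform(new double[] {5.0, 1550.0})));
    }
}
```

Because the scaler needs the whole training set to find each range, it fits batch training naturally, which is consistent with the -L help saying normalization is turned off for streaming data.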
public java.lang.String learningRateTipText()
public java.lang.String epochsTipText()
public void setEpochs(int e)
Parameters: e - the number of epochs to use
public int getEpochs()
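In batch mode, -E controls how many passes are made over the training data, whereas updateClassifier consumes one instance at a time; that is why epochs apply to batch learning only. This hypothetical sketch (EpochTrainer, squared loss, with the bias stored as the last weight) shuffles once with a seeded java.util.Random and then runs the requested number of passes. Whether Weka reshuffles between epochs is not stated in this documentation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical batch trainer: 'epochs' passes of per-example SGD on the squared loss. */
final class EpochTrainer {
    /** One incremental update, the analogue of updateClassifier; w[x.length] is the bias. */
    static void update(double[] w, double[] x, double y, double lr, double lambda) {
        double pred = w[x.length];
        for (int j = 0; j < x.length; j++) pred += w[j] * x[j];
        double err = pred - y; // derivative of 0.5 * (pred - y)^2 w.r.t. pred
        for (int j = 0; j < x.length; j++) w[j] -= lr * (lambda * w[j] + err * x[j]);
        w[x.length] -= lr * err; // unregularised bias
    }

    /** Shuffles once (reproducibly, from the seed), then repeats the pass 'epochs' times. */
    static double[] train(double[][] xs, double[] ys, int epochs, double lr, double lambda, long seed) {
        double[] w = new double[xs[0].length + 1];
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) order.add(i);
        Collections.shuffle(order, new Random(seed));
        for (int e = 0; e < epochs; e++) {
            for (int i : order) update(w, xs[i], ys[i], lr, lambda);
        }
        return w;
    }

    public static void main(String[] args) {
        // noise-free data from y = 2x + 1 on x in [0, 1]
        double[][] xs = {{0.0}, {0.25}, {0.5}, {0.75}, {1.0}};
        double[] ys = {1.0, 1.5, 2.0, 2.5, 3.0};
        double[] w = train(xs, ys, 500, 0.1, 0.0, 1L);
        System.out.printf("slope=%.3f bias=%.3f%n", w[0], w[1]);
    }
}
```

A streaming learner would call update once per arriving instance and never revisit it, which corresponds to a single epoch in this sketch.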
public void setDontNormalize(boolean m)
Parameters: m - true if normalization is to be disabled
public boolean getDontNormalize()
public java.lang.String dontNormalizeTipText()
public void setDontReplaceMissing(boolean m)
Parameters: m - true if global replacement of missing values is to be turned off
public boolean getDontReplaceMissing()
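Global replacement of missing values fills each gap from statistics computed over the whole training set. Assuming column means for numeric attributes (nominal attributes, typically filled with the mode, are left out) and NaN as the missing-value marker, a minimal sketch looks like this. MeanImputer is a hypothetical name.

```java
/** Hypothetical global mean imputation for numeric columns; NaN marks a missing value. */
final class MeanImputer {
    /** Returns a copy of rows with each NaN replaced by its column mean (0 if a column is all NaN). */
    static double[][] impute(double[][] rows) {
        int d = rows[0].length;
        double[] sum = new double[d];
        int[] seen = new int[d];
        for (double[] r : rows) {
            for (int j = 0; j < d; j++) {
                if (!Double.isNaN(r[j])) {
                    sum[j] += r[j];
                    seen[j]++;
                }
            }
        }
        double[][] out = new double[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            out[i] = rows[i].clone(); // leave the caller's data untouched
            for (int j = 0; j < d; j++) {
                if (Double.isNaN(out[i][j])) out[i][j] = seen[j] == 0 ? 0.0 : sum[j] / seen[j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] out = impute(new double[][] {{1.0, Double.NaN}, {3.0, 4.0}, {Double.NaN, 8.0}});
        System.out.println(out[0][1] + " " + out[2][0]);
    }
}
```

The linear model needs a finite value for every attribute to compute w.x, so some replacement is required whenever -M is not set and the data contain gaps.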
public java.lang.String dontReplaceMissingTipText()
public void setLossFunction(SelectedTag function)
Parameters: function - the loss function to use
public SelectedTag getLossFunction()
public java.lang.String lossFunctionTipText()
public java.util.Enumeration<Option> listOptions()
Specified by: listOptions in interface OptionHandler
Overrides: listOptions in class RandomizableClassifier
public void setOptions(java.lang.String[] options) throws java.lang.Exception
Parses a given list of options. Valid options are:
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
Specified by: setOptions in interface OptionHandler
Overrides: setOptions in class RandomizableClassifier
Parameters: options - the list of options as an array of strings
Throws: java.lang.Exception - if an option is not supported
public java.lang.String[] getOptions()
Specified by: getOptions in interface OptionHandler
Overrides: getOptions in class RandomizableClassifier
public java.lang.String globalInfo()
public void reset()
public void buildClassifier(Instances data) throws java.lang.Exception
Specified by: buildClassifier in interface Classifier
Parameters: data - the set of training instances
Throws: java.lang.Exception - if the classifier can't be built successfully
public void updateClassifier(Instance instance) throws java.lang.Exception
Specified by: updateClassifier in interface UpdateableClassifier
Parameters: instance - the new training instance to include in the model
Throws: java.lang.Exception - if the instance could not be incorporated in the model
public double[] distributionForInstance(Instance inst) throws java.lang.Exception
Specified by: distributionForInstance in interface Classifier
Overrides: distributionForInstance in class AbstractClassifier
Parameters: inst - the instance for which the distribution is computed
Throws: java.lang.Exception - if the distribution can't be computed successfully
public double[] getWeights()
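For the log loss, a linear model's score can be turned into a two-class distribution of the kind distributionForInstance returns. This sketch assumes a logistic link and treats class index 1 as the positive class. LogisticDistribution is hypothetical and takes the weights as a plain array rather than via getWeights, whose layout this documentation does not describe.

```java
/** Hypothetical two-class distribution for a logistic-loss linear model. */
final class LogisticDistribution {
    /** Returns {P(class 0), P(class 1)}, treating class index 1 as the positive class. */
    static double[] distribution(double[] w, double bias, double[] x) {
        double z = bias;
        for (int j = 0; j < w.length; j++) z += w[j] * x[j];
        double p1 = 1.0 / (1.0 + Math.exp(-z)); // logistic sigmoid of the linear score
        return new double[] {1.0 - p1, p1};
    }

    public static void main(String[] args) {
        double[] d = distribution(new double[] {1.0, -0.5}, 0.2, new double[] {2.0, 1.0});
        System.out.println(d[0] + " " + d[1]);
    }
}
```

A score of 0 sits exactly on the decision boundary and yields {0.5, 0.5}; larger positive scores push probability toward class 1.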
public java.lang.String toString()
Overrides: toString in class java.lang.Object
public java.lang.String getRevision()
Specified by: getRevision in interface RevisionHandler
Overrides: getRevision in class AbstractClassifier
public SGD aggregate(SGD toAggregate) throws java.lang.Exception
Specified by: aggregate in interface Aggregateable<SGD>
Parameters: toAggregate - the object to aggregate
Throws: java.lang.Exception - if the supplied object can't be aggregated for some reason
public void finalizeAggregation() throws java.lang.Exception
Specified by: finalizeAggregation in interface Aggregateable<SGD>
Throws: java.lang.Exception - if the aggregation can't be finalized for some reason
public static void main(java.lang.String[] args)