weka.classifiers.meta
Class ThresholdSelector

java.lang.Object
  extended by weka.classifiers.Classifier
      extended by weka.classifiers.meta.ThresholdSelector
All Implemented Interfaces:
java.lang.Cloneable, Drawable, OptionHandler, java.io.Serializable

public class ThresholdSelector
extends Classifier
implements OptionHandler, Drawable

Class for selecting a threshold on a probability output by a distribution classifier. The threshold is set so that a given performance measure is optimized. Currently this is the F-measure. Performance is measured either on the training data, a hold-out set or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range).
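The threshold search described above can be illustrated with a small, self-contained sketch. It is not Weka's implementation. The class and method names (ThresholdSearchSketch, bestThreshold) are hypothetical. The sketch assumes two things: the candidate thresholds are the observed scores for the designated class, and an instance is predicted as the designated class when its score is at least the threshold. The 0.5 fallback mirrors the MIN_VALUE behaviour documented for findThreshold. The minimum is passed in as a parameter because its actual value is not given here.

```java
import java.util.Arrays;

// Hypothetical sketch (not Weka's implementation): pick the probability
// threshold for the designated class that maximizes the F-measure.
public class ThresholdSearchSketch {

    /**
     * scores[i]   - predicted probability of the designated class for instance i
     * positive[i] - true if instance i actually belongs to the designated class
     * minValue    - if no threshold yields an F-measure above this, 0.5 is returned
     */
    static double bestThreshold(double[] scores, boolean[] positive, double minValue) {
        int n = scores.length;
        int totalPos = 0;
        for (boolean p : positive) {
            if (p) totalPos++;
        }

        // Visit instances from the highest score to the lowest.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Double.compare(scores[b], scores[a]));

        double bestF = minValue;
        double bestT = 0.5; // default threshold if nothing beats minValue
        int tp = 0, fp = 0;
        for (int k = 0; k < n; k++) {
            int i = order[k];
            if (positive[i]) tp++; else fp++;
            // Score a candidate only after every instance with this score is counted,
            // so that "score >= threshold" is applied consistently to ties.
            if (k + 1 < n && scores[order[k + 1]] == scores[i]) continue;

            double precision = (double) tp / (tp + fp);
            double recall = totalPos == 0 ? 0.0 : (double) tp / totalPos;
            double f = precision + recall == 0.0
                    ? 0.0
                    : 2 * precision * recall / (precision + recall);
            if (f > bestF) {
                bestF = f;
                bestT = scores[i];
            }
        }
        return bestT;
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2};
        boolean[] positive = {true, true, false, true, false, false, true, false};
        System.out.println(bestThreshold(scores, positive, 0.0)); // 0.6
    }
}
```

With the sample data, the best F-measure (0.75) is reached at a threshold of 0.6. Sorting once and sweeping the scores keeps the search at O(n log n). Re-scoring every candidate threshold from scratch would cost O(n^2).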

Valid options are:

-C num
The class for which the threshold is determined. Valid values are: 1, 2 (for the first and second classes, respectively), 3 (for whichever class is least frequent), 4 (for whichever class is most frequent), and 5 (for the first class named any of "yes", "pos(itive)", or "1", or method 3 if there is no match) (default 5).

-W classname
Specify the full class name of the base classifier.

-X num
Number of folds used for cross-validation. If just a hold-out set is used, this determines the size of the hold-out set (default 3).

-R integer
Sets whether confidence range correction is applied. This can be used to ensure the confidences range from 0 to 1. Use 0 for no range correction, 1 for correction based on the min/max values seen during threshold selection (default 0).

-S seed
Random number seed (default 1).

-E integer
Sets the evaluation mode. Use 0 for evaluation using cross-validation, 1 for evaluation using hold-out set, and 2 for evaluation on the training data (default 1).

Options after -- are passed to the designated sub-classifier.
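The mapping from the -C values to a designated class index can be sketched as follows. This is an illustration only. DesignatedClassSketch and designatedClass are hypothetical names. Two behaviours are assumptions rather than documented behaviour: ties are broken toward the earlier class, and mode 5 matches "yes", "pos", "positive" and "1" exactly, ignoring case.

```java
import java.util.List;

// Hypothetical sketch of how the -C option (values 1 to 5) could pick the designated class.
public class DesignatedClassSketch {

    /**
     * mode   - the -C value: 1 = first class, 2 = second class, 3 = least frequent,
     *          4 = most frequent, 5 = first class named "yes"/"pos"/"positive"/"1", else mode 3
     * names  - the nominal class values, in declaration order
     * counts - the number of training instances per class value
     */
    static int designatedClass(int mode, List<String> names, int[] counts) {
        switch (mode) {
            case 1: return 0;
            case 2: return 1;
            case 3: return argExtreme(counts, true);
            case 4: return argExtreme(counts, false);
            case 5:
                for (int i = 0; i < names.size(); i++) {
                    // Assumption: exact, case-insensitive match on the listed names.
                    String v = names.get(i).toLowerCase();
                    if (v.equals("yes") || v.equals("pos") || v.equals("positive") || v.equals("1")) {
                        return i;
                    }
                }
                return argExtreme(counts, true); // no match: fall back to mode 3
            default:
                throw new IllegalArgumentException("-C must be between 1 and 5, got " + mode);
        }
    }

    /** Index of the smallest (least == true) or largest count; ties go to the earlier index. */
    static int argExtreme(int[] counts, boolean least) {
        int best = 0;
        for (int i = 1; i < counts.length; i++) {
            if (least ? counts[i] < counts[best] : counts[i] > counts[best]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> names = List.of("no", "yes");
        int[] counts = {80, 20};
        System.out.println(designatedClass(5, names, counts)); // 1 ("yes")
        System.out.println(designatedClass(4, names, counts)); // 0 (most frequent)
    }
}
```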

Version:
$Revision: 1.29 $
Author:
Eibe Frank (eibe@cs.waikato.ac.nz)
See Also:
Serialized Form

Field Summary
static int EVAL_CROSS_VALIDATION
           
static int EVAL_TRAINING_SET
           
static int EVAL_TUNED_SPLIT
           
protected  double m_BestThreshold
          The threshold that led to the best performance
protected  double m_BestValue
          The best value that has been observed
protected  Classifier m_Classifier
          The generated base classifier
protected  int m_ClassMode
          Method to determine which class to optimize for
protected  int m_DesignatedClass
          Designated class value, determined during building
protected  int m_EvalMode
          The evaluation mode
protected  double m_HighThreshold
          The upper threshold used as the basis of correction
protected  double m_LowThreshold
          The lower threshold used as the basis of correction
protected  int m_NumXValFolds
          The number of folds used in cross-validation
protected  int m_RangeMode
          The range correction mode
protected  int m_Seed
          Random number seed
protected static double MIN_VALUE
          The minimum value for the criterion.
static int OPTIMIZE_0
           
static int OPTIMIZE_1
           
static int OPTIMIZE_LFREQ
           
static int OPTIMIZE_MFREQ
           
static int OPTIMIZE_POS_NAME
           
static int RANGE_BOUNDS
           
static int RANGE_NONE
           
static Tag[] TAGS_EVAL
           
static Tag[] TAGS_OPTIMIZE
           
static Tag[] TAGS_RANGE
           
 
Fields inherited from class weka.classifiers.Classifier
m_Debug
 
Fields inherited from interface weka.core.Drawable
BayesNet, NOT_DRAWABLE, TREE
 
Constructor Summary
ThresholdSelector()
           
 
Method Summary
 void buildClassifier(Instances instances)
          Generates the classifier.
private  boolean checkForInstance(Instances data)
          Checks whether an instance of the designated class is in the subset.
 java.lang.String classifierTipText()
           
 java.lang.String designatedClassTipText()
           
 double[] distributionForInstance(Instance instance)
          Calculates the class membership probabilities for the given test instance.
 java.lang.String evaluationModeTipText()
           
protected  void findThreshold(FastVector predictions)
          Finds the best threshold; this implementation searches for the highest F-measure.
 Classifier getClassifier()
          Gets the base classifier.
protected  java.lang.String getClassifierSpec()
          Gets the classifier specification string, which contains the class name of the classifier and any options to the classifier.
 SelectedTag getDesignatedClass()
          Gets the method to determine which class value to optimize.
 SelectedTag getEvaluationMode()
          Gets the evaluation mode used.
 int getNumXValFolds()
          Get the number of folds used for cross-validation.
 java.lang.String[] getOptions()
          Gets the current settings of the Classifier.
protected  FastVector getPredictions(Instances instances, int mode, int numFolds)
          Collects the classifier predictions using the specified evaluation method.
 SelectedTag getRangeCorrection()
          Gets the confidence range correction mode used.
 int getSeed()
          Gets the random number seed.
 java.lang.String globalInfo()
           
 java.lang.String graph()
          Returns graph describing the classifier (if possible).
 int graphType()
          Returns the type of graph this classifier represents.
 java.util.Enumeration listOptions()
          Returns an enumeration describing the available options.
static void main(java.lang.String[] argv)
          Main method for testing this class.
 java.lang.String numXValFoldsTipText()
           
 java.lang.String rangeCorrectionTipText()
           
 java.lang.String seedTipText()
           
 void setClassifier(Classifier newClassifier)
          Sets the base classifier whose output probabilities are thresholded.
 void setDesignatedClass(SelectedTag newMethod)
          Sets the method to determine which class value to optimize.
 void setEvaluationMode(SelectedTag newMethod)
          Sets the evaluation mode used.
 void setNumXValFolds(int newNumFolds)
          Set the number of folds used for cross-validation.
 void setOptions(java.lang.String[] options)
          Parses a given list of options.
 void setRangeCorrection(SelectedTag newMethod)
          Sets the confidence range correction mode used.
 void setSeed(int seed)
          Sets the seed for random number generation.
 java.lang.String toString()
          Returns a description of the threshold selector and its base classifier.
 
Methods inherited from class weka.classifiers.Classifier
classifyInstance, debugTipText, forName, getDebug, makeCopies, setDebug
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

RANGE_NONE

public static final int RANGE_NONE
See Also:
Constant Field Values

RANGE_BOUNDS

public static final int RANGE_BOUNDS
See Also:
Constant Field Values

TAGS_RANGE

public static final Tag[] TAGS_RANGE

EVAL_TRAINING_SET

public static final int EVAL_TRAINING_SET
See Also:
Constant Field Values

EVAL_TUNED_SPLIT

public static final int EVAL_TUNED_SPLIT
See Also:
Constant Field Values

EVAL_CROSS_VALIDATION

public static final int EVAL_CROSS_VALIDATION
See Also:
Constant Field Values

TAGS_EVAL

public static final Tag[] TAGS_EVAL

OPTIMIZE_0

public static final int OPTIMIZE_0
See Also:
Constant Field Values

OPTIMIZE_1

public static final int OPTIMIZE_1
See Also:
Constant Field Values

OPTIMIZE_LFREQ

public static final int OPTIMIZE_LFREQ
See Also:
Constant Field Values

OPTIMIZE_MFREQ

public static final int OPTIMIZE_MFREQ
See Also:
Constant Field Values

OPTIMIZE_POS_NAME

public static final int OPTIMIZE_POS_NAME
See Also:
Constant Field Values

TAGS_OPTIMIZE

public static final Tag[] TAGS_OPTIMIZE

m_Classifier

protected Classifier m_Classifier
The generated base classifier


m_HighThreshold

protected double m_HighThreshold
The upper threshold used as the basis of correction


m_LowThreshold

protected double m_LowThreshold
The lower threshold used as the basis of correction


m_BestThreshold

protected double m_BestThreshold
The threshold that led to the best performance


m_BestValue

protected double m_BestValue
The best value that has been observed


m_NumXValFolds

protected int m_NumXValFolds
The number of folds used in cross-validation


m_Seed

protected int m_Seed
Random number seed


m_DesignatedClass

protected int m_DesignatedClass
Designated class value, determined during building


m_ClassMode

protected int m_ClassMode
Method to determine which class to optimize for


m_EvalMode

protected int m_EvalMode
The evaluation mode


m_RangeMode

protected int m_RangeMode
The range correction mode


MIN_VALUE

protected static final double MIN_VALUE
The minimum value for the criterion. If threshold adjustment yields less than that, the default threshold of 0.5 is used.

See Also:
Constant Field Values
Constructor Detail

ThresholdSelector

public ThresholdSelector()
Method Detail

getPredictions

protected FastVector getPredictions(Instances instances,
                                    int mode,
                                    int numFolds)
                             throws java.lang.Exception
Collects the classifier predictions using the specified evaluation method.

Parameters:
instances - the set of Instances to generate predictions for.
mode - the evaluation mode.
numFolds - the number of folds to use if not evaluating on the full training set.
Returns:
a FastVector containing the predictions.
Throws:
java.lang.Exception - if an error occurs generating the predictions.
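For the hold-out mode, a minimal sketch of how numFolds (-X) and the seed (-S) could determine the evaluation set is shown below. It rests on several assumptions. The data are shuffled with the seed and split into numFolds contiguous folds. The first fold is used for which both the held-out fold and the remaining training data contain the designated class, which is the kind of check checkForInstance describes. HoldOutSketch and holdOutIndices are hypothetical names. Stratification is omitted, and the fold sizes are an assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch (not Weka's implementation) of choosing a hold-out fold.
public class HoldOutSketch {

    /**
     * Returns the instance indices of the hold-out (evaluation) fold; all other
     * indices form the training part. isDesignated[i] marks instances of the
     * designated class.
     */
    static List<Integer> holdOutIndices(boolean[] isDesignated, int numFolds, long seed) {
        if (numFolds < 2) throw new IllegalArgumentException("need at least 2 folds");
        int n = isDesignated.length;
        int totalDesignated = 0;
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            idx.add(i);
            if (isDesignated[i]) totalDesignated++;
        }
        Collections.shuffle(idx, new Random(seed)); // the seed makes the split repeatable

        List<Integer> first = null;
        for (int fold = 0; fold < numFolds; fold++) {
            // Assumed fold layout: the first (n % numFolds) folds get one extra instance.
            int start = fold * (n / numFolds) + Math.min(fold, n % numFolds);
            int size = n / numFolds + (fold < n % numFolds ? 1 : 0);
            List<Integer> eval = new ArrayList<>(idx.subList(start, start + size));
            if (first == null) first = eval;

            int inEval = 0;
            for (int i : eval) {
                if (isDesignated[i]) inEval++;
            }
            // Accept the fold if both parts contain the designated class.
            if (inEval > 0 && totalDesignated - inEval > 0) return eval;
        }
        return first; // no fold passed the check; fall back to the first fold
    }

    public static void main(String[] args) {
        boolean[] designated = {true, false, true, false, true, false, true, false, true, false};
        List<Integer> eval = holdOutIndices(designated, 3, 1L);
        System.out.println("hold-out fold: " + eval + " (size " + eval.size() + " of 10)");
    }
}
```

With the default of 3 folds, the hold-out set in this sketch is one fold, roughly a third of the data.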

findThreshold

protected void findThreshold(FastVector predictions)
Finds the best threshold; this implementation searches for the highest F-measure. If no F-measure higher than MIN_VALUE is found, the default threshold of 0.5 is used.

Parameters:
predictions - a FastVector containing the predictions.

listOptions

public java.util.Enumeration listOptions()
Returns an enumeration describing the available options.

Specified by:
listOptions in interface OptionHandler
Overrides:
listOptions in class Classifier
Returns:
an enumeration of all the available options.

setOptions

public void setOptions(java.lang.String[] options)
                throws java.lang.Exception
Parses a given list of options. Valid options are:

-C num
The class for which the threshold is determined. Valid values are: 1, 2 (for the first and second classes, respectively), 3 (for whichever class is least frequent), 4 (for whichever class is most frequent), and 5 (for the first class named any of "yes", "pos(itive)", or "1", or method 3 if there is no match) (default 5).

-W classname
Specify the full class name of the base classifier.

-X num
Number of folds used for cross-validation. If just a hold-out set is used, this determines the size of the hold-out set (default 3).

-R integer
Sets whether confidence range correction is applied. This can be used to ensure the confidences range from 0 to 1. Use 0 for no range correction, 1 for correction based on the min/max values seen during threshold selection (default 0).

-S seed
Random number seed (default 1).

-E integer
Sets the evaluation mode. Use 0 for evaluation using cross-validation, 1 for evaluation using hold-out set, and 2 for evaluation on the training data (default 1).

Options after -- are passed to the designated sub-classifier.

Specified by:
setOptions in interface OptionHandler
Overrides:
setOptions in class Classifier
Parameters:
options - the list of options as an array of strings
Throws:
java.lang.Exception - if an option is not supported

getOptions

public java.lang.String[] getOptions()
Gets the current settings of the Classifier.

Specified by:
getOptions in interface OptionHandler
Overrides:
getOptions in class Classifier
Returns:
an array of strings suitable for passing to setOptions

buildClassifier

public void buildClassifier(Instances instances)
                     throws java.lang.Exception
Generates the classifier.

Specified by:
buildClassifier in class Classifier
Parameters:
instances - set of instances serving as training data
Throws:
java.lang.Exception - if the classifier has not been generated successfully

checkForInstance

private boolean checkForInstance(Instances data)
                          throws java.lang.Exception
Checks whether an instance of the designated class is in the subset.

Throws:
java.lang.Exception

distributionForInstance

public double[] distributionForInstance(Instance instance)
                                 throws java.lang.Exception
Calculates the class membership probabilities for the given test instance.

Overrides:
distributionForInstance in class Classifier
Parameters:
instance - the instance to be classified
Returns:
predicted class probability distribution
Throws:
java.lang.Exception - if instance could not be classified successfully
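One way to combine range correction with the selected threshold in the output distribution is a piecewise-linear remapping. It sends the lower bound to 0, the chosen threshold to 0.5 and the upper bound to 1, so that the usual "pick the most probable class" rule reproduces the thresholded decision. This is a hedged sketch for a two-class problem, not Weka's code. WarpSketch, warp and distribution are hypothetical names. The bounds are also assumptions. Under RANGE_NONE, low = 0 and high = 1. Under RANGE_BOUNDS, low and high are the minimum and maximum probabilities seen during threshold selection.

```java
// Hypothetical sketch (not Weka's implementation): remap the designated-class
// probability so the chosen threshold maps to 0.5, then build a two-class distribution.
public class WarpSketch {

    /** Piecewise-linear map: low -> 0, threshold -> 0.5, high -> 1, clamped to [0, 1]. */
    static double warp(double p, double low, double threshold, double high) {
        double q;
        if (p > threshold) {
            double span = high - threshold;
            q = span > 0 ? 0.5 + (p - threshold) / (2 * span) : 1.0; // degenerate span: top
        } else {
            double span = threshold - low;
            q = span > 0 ? (p - low) / (2 * span) : 0.0; // degenerate span: bottom
        }
        return Math.max(0.0, Math.min(1.0, q));
    }

    /** designated is 0 or 1: the index of the class whose probability was thresholded. */
    static double[] distribution(double p, int designated,
                                 double low, double threshold, double high) {
        double q = warp(p, low, threshold, high);
        double[] dist = new double[2];
        dist[designated] = q;
        dist[1 - designated] = 1.0 - q;
        return dist;
    }

    public static void main(String[] args) {
        // No range correction assumed: low = 0, high = 1; selected threshold 0.3.
        double[] dist = distribution(0.4, 1, 0.0, 0.3, 1.0);
        System.out.println(dist[0] + " " + dist[1]);
    }
}
```

With the values in main, the designated class receives about 0.57 and is predicted, even though its raw probability of 0.4 is below 0.5.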

globalInfo

public java.lang.String globalInfo()
Returns:
a description of the classifier suitable for displaying in the explorer/experimenter gui

designatedClassTipText

public java.lang.String designatedClassTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

getDesignatedClass

public SelectedTag getDesignatedClass()
Gets the method to determine which class value to optimize. Will be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ, OPTIMIZE_POS_NAME.

Returns:
the class selection mode.

setDesignatedClass

public void setDesignatedClass(SelectedTag newMethod)
Sets the method to determine which class value to optimize. Should be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ, OPTIMIZE_POS_NAME.

Parameters:
newMethod - the new class selection mode.

evaluationModeTipText

public java.lang.String evaluationModeTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setEvaluationMode

public void setEvaluationMode(SelectedTag newMethod)
Sets the evaluation mode used. Should be one of EVAL_TRAINING_SET, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION.

Parameters:
newMethod - the new evaluation mode.

getEvaluationMode

public SelectedTag getEvaluationMode()
Gets the evaluation mode used. Will be one of EVAL_TRAINING_SET, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION.

Returns:
the evaluation mode.

rangeCorrectionTipText

public java.lang.String rangeCorrectionTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setRangeCorrection

public void setRangeCorrection(SelectedTag newMethod)
Sets the confidence range correction mode used. Should be one of RANGE_NONE or RANGE_BOUNDS.

Parameters:
newMethod - the new correction mode.

getRangeCorrection

public SelectedTag getRangeCorrection()
Gets the confidence range correction mode used. Will be one of RANGE_NONE or RANGE_BOUNDS.

Returns:
the confidence correction mode.

seedTipText

public java.lang.String seedTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setSeed

public void setSeed(int seed)
Sets the seed for random number generation.

Parameters:
seed - the random number seed

getSeed

public int getSeed()
Gets the random number seed.

Returns:
the random number seed

numXValFoldsTipText

public java.lang.String numXValFoldsTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

getNumXValFolds

public int getNumXValFolds()
Get the number of folds used for cross-validation.

Returns:
the number of folds used for cross-validation.

setNumXValFolds

public void setNumXValFolds(int newNumFolds)
Set the number of folds used for cross-validation.

Parameters:
newNumFolds - the number of folds used for cross-validation.

classifierTipText

public java.lang.String classifierTipText()
Returns:
tip text for this property suitable for displaying in the explorer/experimenter gui

setClassifier

public void setClassifier(Classifier newClassifier)
Sets the base classifier whose output probabilities are thresholded.

Parameters:
newClassifier - the Classifier to use.

getClassifier

public Classifier getClassifier()
Gets the base classifier.

Returns:
the base classifier

getClassifierSpec

protected java.lang.String getClassifierSpec()
Gets the classifier specification string, which contains the class name of the classifier and any options to the classifier.

Returns:
the classifier string.

graphType

public int graphType()
Returns the type of graph this classifier represents.

Specified by:
graphType in interface Drawable
Returns:
the type of graph representing the object

graph

public java.lang.String graph()
                       throws java.lang.Exception
Returns graph describing the classifier (if possible).

Specified by:
graph in interface Drawable
Returns:
the graph of the classifier in dotty format
Throws:
java.lang.Exception - if the classifier cannot be graphed

toString

public java.lang.String toString()
Returns a description of the threshold selector and its base classifier.

Returns:
a description of the threshold selector as a string

main

public static void main(java.lang.String[] argv)
Main method for testing this class.

Parameters:
argv - the options