weka.classifiers.trees.lmt
Class LMTNode

java.lang.Object
  extended by weka.classifiers.Classifier
      extended by weka.classifiers.trees.lmt.LogisticBase
          extended by weka.classifiers.trees.lmt.LMTNode
All Implemented Interfaces:
java.lang.Cloneable, OptionHandler, java.io.Serializable, WeightedInstancesHandler

public class LMTNode
extends LogisticBase

Class for logistic model tree structure.

Version:
$Revision: 1.1 $
Author:
Niels Landwehr
See Also:
Serialized Form

Field Summary
 double m_alpha
          Alpha-value (for pruning) at the node
protected  boolean m_fastRegression
          Use the heuristic that determines the number of LogitBoost iterations only once, at the beginning?
protected  SimpleLinearRegression[][] m_higherRegressions
          Simple regression functions fit by LogitBoost at higher levels in the tree
protected  int m_id
          Node id
protected  boolean m_isLeaf
          True if node is leaf
protected  int m_leafModelNum
          ID of logistic model at leaf
protected  ClassifierSplitModel m_localModel
          The ClassifierSplitModel (for splitting)
protected  int m_minNumInstances
          Minimum number of instances at which a node is considered for splitting
protected  ModelSelection m_modelSelection
          ModelSelection object (for splitting)
protected  NominalToBinary m_nominalToBinary
          Filter to convert nominal attributes to binary
protected static int m_numFoldsPruning
          Number of folds for CART pruning
protected  int m_numHigherRegressions
          Number of simple regression functions fit by LogitBoost at higher levels in the tree
 double m_numIncorrectModel
          Weighted number of training examples currently misclassified by the logistic model at the node
 double m_numIncorrectTree
          Weighted number of training examples currently misclassified by the subtree rooted at the node
protected  int m_numInstances
          Number of instances at the node
protected  LMTNode[] m_sons
          Array of children of the node
protected  double m_totalInstanceWeight
          Total (weighted) number of training instances.
 
Fields inherited from class weka.classifiers.trees.lmt.LogisticBase
m_errorOnProbabilities, m_fixedNumIterations, m_heuristicStop, m_maxIterations, m_numClasses, m_numericData, m_numericDataHeader, m_numFoldsBoosting, m_numRegressions, m_regressions, m_train, m_useCrossValidation, Z_MAX
 
Fields inherited from class weka.classifiers.Classifier
m_Debug
 
Constructor Summary
LMTNode(ModelSelection modelSelection, int numBoostingIterations, boolean fastRegression, boolean errorOnProbabilities, int minNumInstances)
          Constructor for logistic model tree node.
 
Method Summary
 int assignIDs(int lastID)
          Assigns unique IDs to all nodes in the tree
 int assignLeafModelNumbers(int leafCounter)
          Assigns numbers to the logistic regression models at the leaves of the tree
 void buildClassifier(Instances data)
          Method for building a logistic model tree (only called for the root node).
 void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, double totalInstanceWeight)
          Method for building the tree structure.
 void calculateAlphas()
          Updates the alpha field for all nodes.
 void cleanup()
          Cleans up in order to save memory.
 double[] distributionForInstance(Instance instance)
          Returns the class probabilities for an instance given by the logistic model tree.
protected  void dumpTree(int depth, java.lang.StringBuffer text)
          Helper method for printing the tree structure.
protected  double[][] getCoefficients()
          Returns an array containing the coefficients of the logistic regression function at this node.
protected  double[] getFs(Instance instance)
          Computes the F-values of LogitBoost for an instance from the current logistic model at the node. Note that this also takes into account the (partial) logistic model fit at higher levels in the tree.
 java.lang.String getModelParameters()
          Returns a string describing the number of LogitBoost iterations performed at this node, the total number of LogitBoost iterations performed (including iterations at higher levels in the tree), and the number of training examples at this node.
 java.util.Vector getNodes()
          Returns a list of all inner nodes in the tree
 void getNodes(java.util.Vector nodeList)
          Fills a list with all inner nodes in the tree
protected  Instances getNumericData(Instances train)
          Returns a numeric version of a set of instances.
 int getNumInnerNodes()
          Method to count the number of inner nodes in the tree
 int getNumLeaves()
          Returns the number of leaves in the tree.
 java.lang.String graph()
          Returns a graph describing the tree.
private  void graphTree(java.lang.StringBuffer text)
          Helper method for the graph description of the tree
 boolean hasModels()
          Returns true if the logistic regression model at this node has changed compared to the one at the parent node.
protected  SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1, SimpleLinearRegression[][] a2)
          Merges two arrays of regression functions into one
 double[] modelDistributionForInstance(Instance instance)
          Returns the class probabilities for an instance according to the logistic model at the node.
 void modelErrors()
          Updates the numIncorrectModel field for all nodes.
 java.lang.String modelsToString()
          Returns a string describing the logistic regression function at the node.
 int numLeaves()
          Returns the number of leaves (normal count: every leaf is counted, unlike getNumLeaves()).
 int numNodes()
          Returns the number of nodes.
 void prune(double alpha)
          Prunes a logistic model tree using the CART pruning scheme, given a cost-complexity parameter alpha.
 int prune(double[] alphas, double[] errors, Instances test)
          Method for performing one fold in the cross-validation of the cost-complexity parameter.
 java.lang.String toString()
          Returns a description of the logistic model tree (tree structure and logistic models)
 void treeErrors()
          Updates the numIncorrectTree field for all nodes.
protected  int tryLogistic(Instances data)
          Determines the optimum number of LogitBoost iterations to perform by building a standalone logistic regression function on the training data.
protected  void unprune()
          Method to "unprune" a logistic model tree.
 
Methods inherited from class weka.classifiers.trees.lmt.LogisticBase
getBestIteration, getErrorRate, getFs, getMaxIterations, getMeanAbsoluteError, getNumRegressions, getProbs, getUsedAttributes, getWs, getYs, getZ, getZs, initRegressions, logLikelihood, percentAttributesUsed, performBoosting, performBoosting, performBoosting, performBoostingCV, performIteration, probs, selectRegressions, setHeuristicStop, setMaxIterations
 
Methods inherited from class weka.classifiers.Classifier
classifyInstance, debugTipText, forName, getDebug, getOptions, listOptions, makeCopies, setDebug, setOptions
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

m_totalInstanceWeight

protected double m_totalInstanceWeight
Total (weighted) number of training instances.


m_id

protected int m_id
Node id


m_leafModelNum

protected int m_leafModelNum
ID of logistic model at leaf


m_alpha

public double m_alpha
Alpha-value (for pruning) at the node


m_numIncorrectModel

public double m_numIncorrectModel
Weighted number of training examples currently misclassified by the logistic model at the node


m_numIncorrectTree

public double m_numIncorrectTree
Weighted number of training examples currently misclassified by the subtree rooted at the node


m_minNumInstances

protected int m_minNumInstances
Minimum number of instances at which a node is considered for splitting


m_modelSelection

protected ModelSelection m_modelSelection
ModelSelection object (for splitting)


m_nominalToBinary

protected NominalToBinary m_nominalToBinary
Filter to convert nominal attributes to binary


m_higherRegressions

protected SimpleLinearRegression[][] m_higherRegressions
Simple regression functions fit by LogitBoost at higher levels in the tree


m_numHigherRegressions

protected int m_numHigherRegressions
Number of simple regression functions fit by LogitBoost at higher levels in the tree


m_numFoldsPruning

protected static int m_numFoldsPruning
Number of folds for CART pruning


m_fastRegression

protected boolean m_fastRegression
Use the heuristic that determines the number of LogitBoost iterations only once, at the beginning?


m_numInstances

protected int m_numInstances
Number of instances at the node


m_localModel

protected ClassifierSplitModel m_localModel
The ClassifierSplitModel (for splitting)


m_sons

protected LMTNode[] m_sons
Array of children of the node


m_isLeaf

protected boolean m_isLeaf
True if node is leaf

Constructor Detail

LMTNode

public LMTNode(ModelSelection modelSelection,
               int numBoostingIterations,
               boolean fastRegression,
               boolean errorOnProbabilities,
               int minNumInstances)
Constructor for logistic model tree node.

Parameters:
modelSelection - selection method for local splitting model
numBoostingIterations - sets the numBoostingIterations parameter
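fastRegression - sets the fastRegression parameter
errorOnProbabilities - use the error on probabilities, instead of the misclassification error, as the stopping criterion for LogitBoost
minNumInstances - minimum number of instances at which a node is considered for splitting
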
Method Detail

buildClassifier

public void buildClassifier(Instances data)
                     throws java.lang.Exception
Method for building a logistic model tree (only called for the root node). Grows an initial logistic model tree and prunes it back using the CART pruning scheme.

Overrides:
buildClassifier in class LogisticBase
Parameters:
data - the training data
Throws:
java.lang.Exception - if something goes wrong

buildTree

public void buildTree(Instances data,
                      SimpleLinearRegression[][] higherRegressions,
                      double totalInstanceWeight)
               throws java.lang.Exception
Method for building the tree structure. Builds a logistic model, splits the node and recursively builds tree for child nodes.

Parameters:
data - the training data passed on to this node
higherRegressions - An array of regression functions produced by LogitBoost at higher levels in the tree. They represent a logistic regression model that is refined locally at this node.
totalInstanceWeight - the total (weighted) number of training examples
Throws:
java.lang.Exception - if something goes wrong

prune

public void prune(double alpha)
           throws java.lang.Exception
Prunes a logistic model tree using the CART pruning scheme, given a cost-complexity parameter alpha.

Parameters:
alpha - the cost-complexity measure
Throws:
java.lang.Exception
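The CART pruning step can be illustrated with a minimal weakest-link sketch. PruneNode and its members are hypothetical stand-ins, not Weka's API: it assumes each node stores the weighted training errors of its own model (the role of m_numIncorrectModel), computes subtree errors recursively (the role of m_numIncorrectTree), and repeatedly collapses the inner node with the smallest alpha while that alpha does not exceed the given threshold.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an LMTNode: only the fields that pruning needs.
// Assumes every inner node has at least two children.
final class PruneNode {
    final double modelError;                 // weighted training errors of this node's own model
    final List<PruneNode> sons = new ArrayList<>();
    boolean isLeaf;

    PruneNode(double modelError, PruneNode... children) {
        this.modelError = modelError;
        for (PruneNode c : children) sons.add(c);
        isLeaf = sons.isEmpty();
    }

    // Errors of the current (possibly pruned) subtree: a leaf contributes its own model's errors.
    double treeError() {
        if (isLeaf) return modelError;
        double sum = 0;
        for (PruneNode s : sons) sum += s.treeError();
        return sum;
    }

    int leaves() {
        if (isLeaf) return 1;
        int n = 0;
        for (PruneNode s : sons) n += s.leaves();
        return n;
    }

    // CART weakest-link value; leaves get +infinity so they are never chosen.
    double alpha(double totalWeight) {
        if (isLeaf) return Double.POSITIVE_INFINITY;
        return (modelError - treeError()) / (totalWeight * (leaves() - 1));
    }

    void collectInnerNodes(List<PruneNode> out) {
        if (isLeaf) return;
        out.add(this);
        for (PruneNode s : sons) s.collectInnerNodes(out);
    }

    // Collapse the inner node with the smallest alpha until that alpha exceeds the threshold.
    // Tree errors and alphas are recomputed on every pass because collapsing changes them.
    static void prune(PruneNode root, double threshold, double totalWeight) {
        while (true) {
            List<PruneNode> inner = new ArrayList<>();
            root.collectInnerNodes(inner);
            if (inner.isEmpty()) return;
            PruneNode weakest = inner.get(0);
            for (PruneNode n : inner) {
                if (n.alpha(totalWeight) < weakest.alpha(totalWeight)) weakest = n;
            }
            if (weakest.alpha(totalWeight) > threshold) return;
            weakest.isLeaf = true;           // children are kept in this sketch
        }
    }

    public static void main(String[] args) {
        PruneNode a = new PruneNode(4, new PruneNode(1), new PruneNode(1));
        PruneNode root = new PruneNode(10, a, new PruneNode(3));
        prune(root, 0.021, 100);
        System.out.println("leaves after pruning: " + root.leaves());
    }
}
```

Recomputing every alpha after each collapse is quadratic in the number of nodes, but it keeps the sketch obviously correct: collapsing a node changes the subtree errors and leaf counts of all its ancestors.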

prune

public int prune(double[] alphas,
                 double[] errors,
                 Instances test)
          throws java.lang.Exception
Method for performing one fold in the cross-validation of the cost-complexity parameter. Generates a sequence of alpha-values with error estimates for the corresponding (partially pruned) trees, given the test set of that fold.

Parameters:
alphas - array to hold the generated alpha-values
errors - array to hold the corresponding error estimates
test - test set of that fold (to obtain error estimates)
Throws:
java.lang.Exception - if something goes wrong
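After one such call per fold, a cost-complexity value still has to be chosen. The sketch below shows one CART-style way to do that and is not taken from this class: it assumes each fold's alphas are ascending, that errors[k] applies to every alpha in [alphas[k], alphas[k+1]), and that a hypothetical counts[f] gives the number of filled entries per fold. The class and method names are illustrative only.

```java
// Hypothetical helper, not Weka's API: pick the alpha with the lowest mean error over folds.
final class AlphaSelection {

    // Error of one fold's pruning sequence at the given alpha (piecewise-constant lookup).
    static double errorAt(double[] alphas, double[] errors, int n, double alpha) {
        int k = 0;
        while (k + 1 < n && alphas[k + 1] <= alpha) k++;
        return errors[k];
    }

    // Candidates are assumed ascending, so the first minimum found is the least-pruned tree.
    static double bestAlpha(double[] candidates, double[][] alphas, double[][] errors, int[] counts) {
        double best = candidates[0];
        double bestErr = Double.POSITIVE_INFINITY;
        for (double c : candidates) {
            double sum = 0;
            for (int f = 0; f < alphas.length; f++) {
                sum += errorAt(alphas[f], errors[f], counts[f], c);
            }
            double mean = sum / alphas.length;
            if (mean < bestErr) {
                bestErr = mean;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] alphas = {{0, 0.02, 0.05}, {0, 0.03}};
        double[][] errors = {{0.30, 0.25, 0.40}, {0.28, 0.35}};
        double best = bestAlpha(new double[]{0, 0.025, 0.06}, alphas, errors, new int[]{3, 2});
        System.out.println("chosen alpha: " + best);
    }
}
```

Classical CART evaluates candidates at geometric means of consecutive alpha-values of the full tree; here the candidate list is simply passed in, so that choice is left to the caller.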

unprune

protected void unprune()
Method to "unprune" a logistic model tree. Sets all leaf-fields to false. Faster than re-growing the tree because the logistic models do not have to be fit again.
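Unpruning is cheap because pruning only needs to flip the leaf flag while keeping the children and their fitted models. The sketch below assumes exactly that; UnpruneNode is a hypothetical type, and it assumes that nodes without children stay leaves, since they have no subtree to restore.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical node type: pruning sets isLeaf but keeps the children.
final class UnpruneNode {
    final List<UnpruneNode> sons = new ArrayList<>();
    boolean isLeaf;

    UnpruneNode(UnpruneNode... children) {
        for (UnpruneNode c : children) sons.add(c);
        isLeaf = sons.isEmpty();
    }

    // Restore the fully grown tree without refitting anything.
    void unprune() {
        if (sons.isEmpty()) return;          // a genuine leaf has nothing to re-expand
        isLeaf = false;
        for (UnpruneNode s : sons) s.unprune();
    }

    int countLeaves() {
        if (isLeaf) return 1;
        int n = 0;
        for (UnpruneNode s : sons) n += s.countLeaves();
        return n;
    }

    public static void main(String[] args) {
        UnpruneNode root = new UnpruneNode(new UnpruneNode(new UnpruneNode(), new UnpruneNode()), new UnpruneNode());
        root.isLeaf = true;                  // simulate pruning back to the root
        root.unprune();
        System.out.println("leaves after unpruning: " + root.countLeaves());
    }
}
```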


tryLogistic

protected int tryLogistic(Instances data)
                   throws java.lang.Exception
Determines the optimum number of LogitBoost iterations to perform by building a standalone logistic regression function on the training data. Used for the heuristic that avoids cross-validating this number again at every node.

Parameters:
data - training instances for the logistic model
Throws:
java.lang.Exception - if something goes wrong

getNumInnerNodes

public int getNumInnerNodes()
Method to count the number of inner nodes in the tree

Returns:
the number of inner nodes

getNumLeaves

public int getNumLeaves()
Returns the number of leaves in the tree. Leaves are only counted if their logistic model has changed compared to that of the parent node.

Returns:
the number of leaves

modelErrors

public void modelErrors()
                 throws java.lang.Exception
Updates the numIncorrectModel field for all nodes. This is needed for calculating the alpha-values.

Throws:
java.lang.Exception

treeErrors

public void treeErrors()
Updates the numIncorrectTree field for all nodes. This is needed for calculating the alpha-values.


calculateAlphas

public void calculateAlphas()
                     throws java.lang.Exception
Updates the alpha field for all nodes.

Throws:
java.lang.Exception
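For reference, the standard CART weakest-link value can be written with this class's field names: R(t) is m_numIncorrectModel, R(T_t) is m_numIncorrectTree, W is m_totalInstanceWeight, and |T̃_t| is the number of leaves in the subtree rooted at t. That calculateAlphas() normalizes by W and assigns an infinite alpha to leaves is an assumption of this sketch, not something stated above.

```latex
\alpha(t) =
\begin{cases}
\dfrac{R(t) - R(T_t)}{W \left( |\widetilde{T}_t| - 1 \right)} & \text{if } t \text{ is an inner node,} \\[1ex]
\infty & \text{if } t \text{ is a leaf.}
\end{cases}
```

For example, a node whose own model misclassifies 4 training examples while its two-leaf subtree misclassifies 2, with W = 100, gets α = (4 − 2) / (100 · 1) = 0.02. The inner node with the smallest α is the weakest link and is pruned first.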

mergeArrays

protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1,
                                                 SimpleLinearRegression[][] a2)
Merges two arrays of regression functions into one

Parameters:
a1 - one array
a2 - the other array
Returns:
an array that contains all entries from both input arrays
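A minimal sketch of such a merge, assuming (as for getCoefficients) that the first array dimension is the class and that both inputs have one row per class. It is generic so that plain strings can stand in for SimpleLinearRegression objects; the class name MergeArrays is hypothetical, and keeping a1's entries first is a choice of this sketch.

```java
import java.util.Arrays;

final class MergeArrays {
    // Concatenates the functions of each class: a1's entries first, then a2's.
    // The inputs are not modified.
    static <T> T[][] merge(T[][] a1, T[][] a2) {
        if (a1.length != a2.length) {
            throw new IllegalArgumentException("both arrays need one row per class");
        }
        T[][] result = Arrays.copyOf(a1, a1.length);   // new outer array of the same runtime type
        for (int c = 0; c < a1.length; c++) {
            result[c] = Arrays.copyOf(a1[c], a1[c].length + a2[c].length);
            System.arraycopy(a2[c], 0, result[c], a1[c].length, a2[c].length);
        }
        return result;
    }

    public static void main(String[] args) {
        String[][] merged = merge(new String[][]{{"a"}, {"b"}}, new String[][]{{"c", "d"}, {"e", "f"}});
        System.out.println(Arrays.deepToString(merged));
    }
}
```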

getNodes

public java.util.Vector getNodes()
Returns a list of all inner nodes in the tree

Returns:
the list of nodes

getNodes

public void getNodes(java.util.Vector nodeList)
Fills a list with all inner nodes in the tree

Parameters:
nodeList - the list to be filled

getNumericData

protected Instances getNumericData(Instances train)
                            throws java.lang.Exception
Returns a numeric version of a set of instances. All nominal attributes are replaced by binary ones, and the class variable is replaced by a pseudo-class variable that is used by LogitBoost.

Overrides:
getNumericData in class LogisticBase
Throws:
java.lang.Exception

getFs

protected double[] getFs(Instance instance)
                  throws java.lang.Exception
Computes the F-values of LogitBoost for an instance from the current logistic model at the node. Note that this also takes into account the (partial) logistic model fit at higher levels in the tree.

Overrides:
getFs in class LogisticBase
Parameters:
instance - the instance
Returns:
the array of F-values
Throws:
java.lang.Exception
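The F-values can be sketched as a sum over boosting iterations, assuming the J-class LogitBoost update of Friedman, Hastie and Tibshirani, in which each iteration's per-class outputs are centred and scaled by (J − 1)/J. FValues and SimpleReg are hypothetical stand-ins (SimpleReg predicts intercept + slope · x[attribute]). Functions fit at higher levels in the tree are simply extra columns of the same array, so passing a merged array accounts for them.

```java
// Hypothetical sketch of LogitBoost F-values; not Weka's API.
final class FValues {

    // Stand-in for one simple linear regression on a single attribute.
    static final class SimpleReg {
        final int attribute;
        final double slope, intercept;

        SimpleReg(int attribute, double slope, double intercept) {
            this.attribute = attribute;
            this.slope = slope;
            this.intercept = intercept;
        }

        double predict(double[] x) {
            return intercept + slope * x[attribute];
        }
    }

    // regs[j][m] is the m-th function for class j; all rows must have the same length.
    static double[] fs(SimpleReg[][] regs, double[] x) {
        int numClasses = regs.length;
        double[] f = new double[numClasses];
        double[] pred = new double[numClasses];
        for (int m = 0; m < regs[0].length; m++) {
            double mean = 0;
            for (int j = 0; j < numClasses; j++) {
                pred[j] = regs[j][m].predict(x);
                mean += pred[j];
            }
            mean /= numClasses;
            for (int j = 0; j < numClasses; j++) {
                // centred and scaled contribution of iteration m to class j
                f[j] += (pred[j] - mean) * (numClasses - 1) / numClasses;
            }
        }
        return f;
    }

    public static void main(String[] args) {
        SimpleReg[][] regs = {{new SimpleReg(0, 2, 1)}, {new SimpleReg(0, 0, -1)}};
        System.out.println(java.util.Arrays.toString(fs(regs, new double[]{3})));
    }
}
```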

hasModels

public boolean hasModels()
Returns true if the logistic regression model at this node has changed compared to the one at the parent node.

Returns:
whether it has changed

modelDistributionForInstance

public double[] modelDistributionForInstance(Instance instance)
                                      throws java.lang.Exception
Returns the class probabilities for an instance according to the logistic model at the node.

Parameters:
instance - the instance
Returns:
the array of probabilities
Throws:
java.lang.Exception
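Turning F-values into class probabilities uses the multi-logit (softmax) link of LogitBoost, p_j = exp(F_j) / Σ_k exp(F_k). The sketch below assumes that link; ModelProbs is a hypothetical name. It subtracts the largest F-value before exponentiating, which leaves the result unchanged but avoids overflow.

```java
// Hypothetical sketch: class probabilities from LogitBoost F-values.
final class ModelProbs {
    static double[] probs(double[] fs) {
        double max = Double.NEGATIVE_INFINITY;
        for (double f : fs) max = Math.max(max, f);
        double[] p = new double[fs.length];
        double sum = 0;
        for (int j = 0; j < fs.length; j++) {
            p[j] = Math.exp(fs[j] - max);   // shift by the maximum for numerical stability
            sum += p[j];
        }
        for (int j = 0; j < p.length; j++) p[j] /= sum;
        return p;
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(probs(new double[]{2, -2})));
    }
}
```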

distributionForInstance

public double[] distributionForInstance(Instance instance)
                                 throws java.lang.Exception
Returns the class probabilities for an instance given by the logistic model tree.

Overrides:
distributionForInstance in class LogisticBase
Parameters:
instance - the instance
Returns:
the array of probabilities
Throws:
java.lang.Exception - if distribution can't be computed successfully
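Prediction with the whole tree can be sketched as routing the instance to a leaf and then applying that leaf's logistic model. This is a simplification: RouteNode is hypothetical, it assumes binary splits of the form x[attribute] <= threshold (the real split logic lives in the ClassifierSplitModel), and each leaf model is any function from an instance to a probability array.

```java
import java.util.function.Function;

// Hypothetical tree: inner nodes hold a binary numeric split, leaves hold a model.
final class RouteNode {
    final int attribute;
    final double threshold;
    final RouteNode left, right;
    final Function<double[], double[]> leafModel;

    private RouteNode(int attribute, double threshold, RouteNode left, RouteNode right,
                      Function<double[], double[]> leafModel) {
        this.attribute = attribute;
        this.threshold = threshold;
        this.left = left;
        this.right = right;
        this.leafModel = leafModel;
    }

    static RouteNode leaf(Function<double[], double[]> model) {
        return new RouteNode(-1, 0, null, null, model);
    }

    static RouteNode split(int attribute, double threshold, RouteNode left, RouteNode right) {
        return new RouteNode(attribute, threshold, left, right, null);
    }

    // Descend to the leaf that the instance falls into, then ask its model.
    double[] distribution(double[] x) {
        if (leafModel != null) return leafModel.apply(x);
        RouteNode next = x[attribute] <= threshold ? left : right;
        return next.distribution(x);
    }

    public static void main(String[] args) {
        RouteNode tree = split(0, 5.0, leaf(x -> new double[]{0.9, 0.1}), leaf(x -> new double[]{0.2, 0.8}));
        System.out.println(java.util.Arrays.toString(tree.distribution(new double[]{7})));
    }
}
```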

numLeaves

public int numLeaves()
Returns the number of leaves (normal count: every leaf is counted, unlike getNumLeaves()).

Returns:
the number of leaves

numNodes

public int numNodes()
Returns the number of nodes.

Returns:
the number of nodes
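The plain counts are simple recursions. In the sketch below, CountNode is hypothetical and treats a node as a leaf exactly when it has no children. It counts every leaf, as numLeaves() does, rather than applying the changed-model rule of getNumLeaves(). The inner-node count is then always numNodes() minus numLeaves().

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical node type for counting; a node is a leaf iff it has no children.
final class CountNode {
    final List<CountNode> sons = new ArrayList<>();

    CountNode(CountNode... children) {
        for (CountNode c : children) sons.add(c);
    }

    boolean isLeaf() {
        return sons.isEmpty();
    }

    int numNodes() {
        int n = 1;                                   // this node
        for (CountNode s : sons) n += s.numNodes();
        return n;
    }

    int numLeaves() {
        if (isLeaf()) return 1;
        int n = 0;
        for (CountNode s : sons) n += s.numLeaves();
        return n;
    }

    int numInnerNodes() {
        if (isLeaf()) return 0;
        int n = 1;
        for (CountNode s : sons) n += s.numInnerNodes();
        return n;
    }

    public static void main(String[] args) {
        CountNode root = new CountNode(new CountNode(new CountNode(), new CountNode()), new CountNode());
        System.out.println(root.numNodes() + " nodes, " + root.numLeaves() + " leaves");
    }
}
```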

toString

public java.lang.String toString()
Returns a description of the logistic model tree (tree structure and logistic models)

Overrides:
toString in class LogisticBase
Returns:
the describing string

getModelParameters

public java.lang.String getModelParameters()
Returns a string describing the number of LogitBoost iterations performed at this node, the total number of LogitBoost iterations performed (including iterations at higher levels in the tree), and the number of training examples at this node.

Returns:
the describing string

dumpTree

protected void dumpTree(int depth,
                        java.lang.StringBuffer text)
                 throws java.lang.Exception
Helper method for printing the tree structure.

Throws:
java.lang.Exception - if something goes wrong

assignIDs

public int assignIDs(int lastID)
Assigns unique IDs to all nodes in the tree


assignLeafModelNumbers

public int assignLeafModelNumbers(int leafCounter)
Assigns numbers to the logistic regression models at the leaves of the tree
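Both numbering passes can be sketched as depth-first traversals that thread a counter through the recursion. IdNode is hypothetical, and two behaviours are assumptions not stated above: IDs are assigned in pre-order starting at lastID + 1, and both methods return the last number they used, with inner nodes getting leaf-model number 0.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical node type illustrating pre-order IDs and left-to-right leaf numbering.
final class IdNode {
    final List<IdNode> sons = new ArrayList<>();
    int id;
    int leafModelNum;   // 0 for inner nodes in this sketch

    IdNode(IdNode... children) {
        for (IdNode c : children) sons.add(c);
    }

    // Give this node the next ID, then number its subtrees; returns the last ID used.
    int assignIds(int lastId) {
        id = lastId + 1;
        int last = id;
        for (IdNode s : sons) last = s.assignIds(last);
        return last;
    }

    // Number the leaves from left to right; returns the updated leaf counter.
    int assignLeafModelNumbers(int leafCounter) {
        if (sons.isEmpty()) {
            leafModelNum = ++leafCounter;
            return leafCounter;
        }
        leafModelNum = 0;
        for (IdNode s : sons) leafCounter = s.assignLeafModelNumbers(leafCounter);
        return leafCounter;
    }

    public static void main(String[] args) {
        IdNode root = new IdNode(new IdNode(new IdNode(), new IdNode()), new IdNode());
        System.out.println("last id: " + root.assignIds(0) + ", leaves: " + root.assignLeafModelNumbers(0));
    }
}
```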


getCoefficients

protected double[][] getCoefficients()
Returns an array containing the coefficients of the logistic regression function at this node.

Overrides:
getCoefficients in class LogisticBase
Returns:
the array of coefficients, first dimension is the class, second the attribute.
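Because every simple regression function is linear in one attribute, the sum of all boosted functions collapses into one linear function per class. The sketch below assumes the centred J-class LogitBoost update (each function k enters class j with weight ((J − 1)/J) · ([j = k] − 1/J)). It uses a hypothetical flat encoding regs[j][m] = {attributeIndex, slope, intercept} and places the intercept at index 0 with attribute a at index a + 1. Both the encoding and the index layout are assumptions; only "first dimension is the class, second the attribute" comes from the text above.

```java
// Hypothetical sketch: collapse boosted simple regressions into a coefficient matrix.
final class Coefficients {
    static double[][] coefficients(double[][][] regs, int numAttributes) {
        int numClasses = regs.length;
        double factor = (numClasses - 1) / (double) numClasses;
        double[][] coef = new double[numClasses][numAttributes + 1];
        for (int m = 0; m < regs[0].length; m++) {
            for (int k = 0; k < numClasses; k++) {
                int attr = (int) regs[k][m][0];
                double slope = regs[k][m][1];
                double intercept = regs[k][m][2];
                for (int j = 0; j < numClasses; j++) {
                    // function k contributes to class j with the centred, scaled weight
                    double w = factor * ((j == k ? 1.0 : 0.0) - 1.0 / numClasses);
                    coef[j][0] += w * intercept;
                    coef[j][attr + 1] += w * slope;
                }
            }
        }
        return coef;
    }

    public static void main(String[] args) {
        double[][][] regs = {{{0, 2, 1}}, {{0, 0, -1}}};
        System.out.println(java.util.Arrays.deepToString(coefficients(regs, 1)));
    }
}
```

Evaluating coef[j][0] + Σ_a coef[j][a + 1] · x[a] reproduces the centred F-value of class j, so the matrix describes the same model as the boosted functions.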

modelsToString

public java.lang.String modelsToString()
Returns a string describing the logistic regression function at the node.


graph

public java.lang.String graph()
                       throws java.lang.Exception
Returns a graph describing the tree.

Throws:
java.lang.Exception - if something goes wrong

graphTree

private void graphTree(java.lang.StringBuffer text)
                throws java.lang.Exception
Helper method for the graph description of the tree

Throws:
java.lang.Exception - if something goes wrong

cleanup

public void cleanup()
Cleans up in order to save memory.

Overrides:
cleanup in class LogisticBase