What's the idea of Decision Tree Classifier?
The basic intuition behind a decision tree is to map out all possible decision paths in the form of a tree. It can be used for classification and regression (note). In this post, let's try to understand the classifier.
Suppose that we have a dataset like in the figure below[ref, Table 1.2],
An example of dataset $S$.
A decision tree we want.
There are many algorithms that can help us build a tree like the one above. In Machine Learning, we usually use:
- ID3 (Iterative Dichotomiser): uses information gain / entropy.
- CART (Classification And Regression Tree): uses Gini impurity.
Some basic concepts
- Splitting: the process of dividing a node into two or more sub-nodes.
- Pruning: the process of removing sub-nodes of a decision node.
- Parent node and child node: a node that is divided into sub-nodes is called the parent node of those sub-nodes, whereas the sub-nodes are the children of the parent node.
ID3 algorithm
1. To check the disorder at the current node (let's say $S$, the parent node), we calculate its entropy with,
   $$H(S) = -\sum_{i=1}^{C} p_i \log_2 p_i,$$
   where $C$ is the number of classes and $p_i$ is the probability of class $i$ in $S$.
2. If this node is pure (there is only 1 class or the majority belongs to 1 class) or it meets the stopping conditions, we stop splitting at this node. Otherwise, go to the next step.
3. Calculate the information gain (IG) after splitting node $S$ on each attribute (for example, consider attribute $O$). The attribute with the biggest IG will be chosen,
   $$IG(S, O) = H(S) - \bar{H}(S, O), \qquad \bar{H}(S, O) = \sum_{j=1}^{K} \frac{|S_j|}{|S|} H(S_j),$$
   where $K$ is the number of different properties in $O$ and $\frac{|S_j|}{|S|}$ is the proportion of property $j$ in $S$.
4. After splitting, we have new child nodes. Each of them becomes a new parent node in the next step. Go back to step 1.
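To make these steps concrete, here is a minimal sketch of the entropy and information-gain computations in Python (the helper names `entropy` and `information_gain` are mine, not from any library):

```python
from collections import Counter
from math import log2

def entropy(labels):
    """H(S) = -sum(p_i * log2(p_i)) over the classes present in this node."""
    total = len(labels)
    return -sum((c / total) * log2(c / total) for c in Counter(labels).values())

def information_gain(rows, labels, attribute):
    """IG(S, A) = H(S) - weighted sum of H(S_j) over the values of attribute A."""
    total = len(labels)
    weighted = 0.0
    for value in set(row[attribute] for row in rows):
        child = [y for row, y in zip(rows, labels) if row[attribute] == value]
        weighted += (len(child) / total) * entropy(child)
    return entropy(labels) - weighted
```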
How do we know we can split the dataset based on the Outlook attribute instead of the others (Temperature, Humidity, Windy)? We calculate the information gain after splitting on each attribute. It's the information which can increase the level of certainty after splitting. The highest one will be chosen (after this section, you will see that the Outlook attribute has the highest information gain).
In order to calculate the information gain, we need "entropy" which is the amount of information disorder or the amount of randomness in the data.
At the beginning, the entropy before split ($H(S)$) shows us the disorder status of the whole dataset $S$. If $S$ contains only Yes, $S$ has no disorder or it's pure ($H(S) = 0$). If the amount of Yes and No in $S$ is equal, $S$ has the highest disorder ($H(S) = 1$).
An illustration of entropy with different proportions of Yes/No in $S$.
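As a quick sanity check of the two extreme cases with 2 classes (using the convention $0 \log_2 0 = 0$),

$$H_{\text{pure}} = -1\cdot\log_2 1 - 0\cdot\log_2 0 = 0, \qquad H_{\text{50/50}} = -\tfrac{1}{2}\log_2\tfrac{1}{2} - \tfrac{1}{2}\log_2\tfrac{1}{2} = 1.$$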
At each node, we need to calculate its entropy again (corresponding to the number of Yes and No in this node). We prefer the lowest entropy, of course! How can we calculate the entropy of each node? More specifically, how to calculate $H(S)$?

$$H(S) = -\sum_{i=1}^{C} p_i \log_2 p_i,$$

where $C$ is the number of classes (node $S$ has 2 classes, Yes and No) and $p_i$ is the probability of class $i$ in $S$.
Graph of $H$ in the case of 2 classes. Max is 1.
In this case we use $\log_2$ (binary logarithm) so that the maximum is 1, and we also use the convention $0 \log_2 0 = 0$. There are other documents using $\ln$ (natural logarithm) instead.
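If you want to reproduce the curve above, here is a minimal sketch (assuming numpy and matplotlib are installed):

```python
import numpy as np
import matplotlib.pyplot as plt

# Binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p); p is the proportion of Yes
p = np.linspace(1e-9, 1 - 1e-9, 500)   # avoid log2(0) at the endpoints
H = -p * np.log2(p) - (1 - p) * np.log2(1 - p)

plt.plot(p, H)
plt.xlabel("proportion of Yes in S")
plt.ylabel("entropy H(S)")
plt.title("Maximum entropy is 1 at p = 0.5")
plt.show()
```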
On node $S$, we have,

$$H(S) = -\frac{9}{14}\log_2\frac{9}{14} - \frac{5}{14}\log_2\frac{5}{14} \approx 0.94.$$

We see that $S$ is not pure, but it's also not totally disordered.
The frequency of classes in S.
Because we are considering splitting on $O$ (Outlook), and $O$ has 3 different properties which are Sunny, Overcast and Rainy, corresponding to these properties we have different sizes of Yes/No (different nodes have different sizes of data, but their total is equal to the size of $S$, which is their "parent" node). That's why we need to calculate the weighted entropy after split, $\bar{H}(S, O)$,

$$\bar{H}(S, O) = \sum_{j=1}^{K} \frac{|S_j|}{|S|} H(S_j),$$

where $K$ is the number of different properties in $O$ and $\frac{|S_j|}{|S|}$ is the proportion of property $j$ in $S$. Therefore, the information gain if we split $S$ on $O$ is,

$$IG(S, O) = H(S) - \bar{H}(S, O).$$
If we split S on Outlook (O), there will be 3 branches.
For example, we consider branch $S_{sunny}$ (Sunny); it has 2 Yes and 3 No, and the entropy at this node, $H(S_{sunny})$, is calculated as

$$H(S_{sunny}) = -\frac{2}{5}\log_2\frac{2}{5} - \frac{3}{5}\log_2\frac{3}{5} \approx 0.971.$$

Only consider branch Sunny ($S_{sunny}$).

Thus, the information gain after splitting $S$ on $O$ is,

$$IG(S, O) = H(S) - \bar{H}(S, O) = 0.94 - \left( \frac{5}{14}\times 0.971 + \frac{4}{14}\times 0 + \frac{5}{14}\times 0.971 \right) \approx 0.247.$$
With the same method, we can calculate the information gain after splitting on the other attributes (Temperature, Windy, Humidity) and get,

$$IG(S, T) \approx 0.029, \qquad IG(S, H) \approx 0.152, \qquad IG(S, W) \approx 0.048.$$
Dataset $S$ is split in different ways.
We can see that the winner is Outlook, with the highest information gain. We split on that attribute first!
Dataset $S$ is split on Outlook.
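For a quick numerical check of these information gains, here is a sketch that reuses the `entropy` / `information_gain` helpers from earlier in this post on a hand-typed version of the classic weather table (double-check the rows against your own copy of the data):

```python
# Each row is (Outlook, Temperature, Humidity, Windy); the label is Play.
rows = [
    ("Sunny", "Hot", "High", False),      ("Sunny", "Hot", "High", True),
    ("Overcast", "Hot", "High", False),   ("Rainy", "Mild", "High", False),
    ("Rainy", "Cool", "Normal", False),   ("Rainy", "Cool", "Normal", True),
    ("Overcast", "Cool", "Normal", True), ("Sunny", "Mild", "High", False),
    ("Sunny", "Cool", "Normal", False),   ("Rainy", "Mild", "Normal", False),
    ("Sunny", "Mild", "Normal", True),    ("Overcast", "Mild", "High", True),
    ("Overcast", "Hot", "Normal", False), ("Rainy", "Mild", "High", True),
]
labels = ["No", "No", "Yes", "Yes", "Yes", "No", "Yes",
          "No", "Yes", "Yes", "Yes", "Yes", "Yes", "No"]

for name, col in [("Outlook", 0), ("Temperature", 1), ("Humidity", 2), ("Windy", 3)]:
    print(name, round(information_gain(rows, labels, col), 3))
# Outlook 0.247, Temperature 0.029, Humidity 0.152, Windy 0.048 -> Outlook wins
```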
How about the 3 remaining attributes (Temperature, Humidity, Windy); which one should be chosen next? Especially on branches Sunny and Rainy, because on branch Overcast the node is pure (all are Yes), so we don't need to split it any more.
The remaining attributes are Temperature, Humidity and Windy. Which attribute will be chosen next?
We repeat the steps again. For example, on the branch $S_{sunny}$ (Sunny), we calculate the IG after splitting on each attribute Temperature (T), Humidity (H) or Windy (W). In other words, we need to calculate $IG(S_{sunny}, T)$, $IG(S_{sunny}, H)$ and $IG(S_{sunny}, W)$ and then compare them to find the best one. Let's consider $H$ (Humidity) as an example,

$$IG(S_{sunny}, H) = H(S_{sunny}) - \left( \tfrac{3}{5} H(S_{high}) + \tfrac{2}{5} H(S_{normal}) \right) = 0.971 - 0 = 0.971.$$

Nodes $S_{high}$ and $S_{normal}$ are pure, that's why their entropies are $0$.
Consider branch $S_{sunny}$ and attribute Windy (W).
Similarly, we calculate $IG(S_{sunny}, T)$ and $IG(S_{sunny}, W)$, and we see that $IG(S_{sunny}, H)$ is the biggest one! So we choose $H$ (Humidity) to split at node $S_{sunny}$. On the branch $S_{rainy}$ (Rainy), the biggest information gain after splitting is on $W$ (Windy).
From now on, if we have a new input which contains information about Outlook, Temperature, Humidity and Windy, we go from the top of the tree and follow the appropriate branches to get the decision Yes or No.
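To see how a prediction walks down the final tree, here is a tiny hand-coded sketch of it (it simply hard-codes the splits found above, for illustration only; attribute values are assumed to be spelled as in the table):

```python
def predict(outlook, temperature, humidity, windy):
    """Walk the learned tree from the root (Outlook) down to a leaf."""
    if outlook == "Overcast":
        return "Yes"                                    # pure node: always play
    if outlook == "Sunny":
        return "Yes" if humidity == "Normal" else "No"  # split on Humidity
    if outlook == "Rainy":
        return "No" if windy else "Yes"                 # split on Windy
    raise ValueError(f"unknown outlook: {outlook}")

print(predict("Rainy", "Mild", "High", True))  # -> No (Temperature is never used)
```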
CART algorithm
The difference between the two algorithms is the difference between the entropy $H$ and the Gini impurity $G$.
1. To check the disorder at the current node (let's say $S$, the parent node), we calculate its Gini Impurity with,
   $$G(S) = 1 - \sum_{i=1}^{C} p_i^2,$$
   where $C$ is the number of classes in $S$ and $p_i$ is the probability of class $i$ in $S$.
2. If this node is pure (there is only 1 class or the majority belongs to 1 class) or it meets the stopping conditions, we stop splitting at this node. Otherwise, go to the next step.
3. Calculate the Gini Gain (GG) after splitting node $S$ on each attribute (for example, consider attribute $O$). The attribute with the biggest GG will be chosen,
   $$GG(S, O) = G(S) - \bar{G}(S, O), \qquad \bar{G}(S, O) = \sum_{j=1}^{K} \frac{|S_j|}{|S|} G(S_j),$$
   where $K$ is the number of different properties in $O$ and $\frac{|S_j|}{|S|}$ is the proportion of property $j$ in $S$.
4. After splitting, we have new child nodes. Each of them becomes a new parent node in the next step. Go back to step 1.
It's quite similar to the ID3 algorithm, except that it's based on the definition of Gini impurity instead of entropy. Gini impurity is a measure of how often a randomly chosen element from the set would be incorrectly labeled if it were randomly labeled according to the distribution of labels in the subset.
At every non-leaf node (which isn't pure), we have to answer the question "Which attribute should we choose to split this node?" We calculate the Gini gain for each split based on the attribute we are going to use. This Gini gain plays the same role as the information gain. The highest one will be chosen.
The Gini Impurity at node $S$ is calculated as,

$$G(S) = 1 - \sum_{i=1}^{C} p_i^2,$$

where $C$ is the number of classes in $S$ and $p_i$ is the probability of class $i$ in $S$. $G(S) = 0$ will be the best!
On node $S$, we have,

$$G(S) = 1 - \left(\tfrac{9}{14}\right)^2 - \left(\tfrac{5}{14}\right)^2 \approx 0.459.$$
The frequency of classes in S.
Similarly to the information gain, we can calculate the Gini Gain ($GG$) after splitting $S$ on the attribute $O$ with,

$$GG(S, O) = G(S) - \bar{G}(S, O) = G(S) - \sum_{j=1}^{K} \frac{|S_j|}{|S|} G(S_j),$$

where $K$ is the number of different properties in $O$ and $\frac{|S_j|}{|S|}$ is the proportion of property $j$ in $S$.
If we split S on Outlook (O), there will be 3 branches.
Applying the above equation, we calculate the Gini impurity of each branch and then the Gini Gain when splitting on Outlook, and get,

$$GG(S, O) = 0.459 - \left( \tfrac{5}{14}\times 0.48 + \tfrac{4}{14}\times 0 + \tfrac{5}{14}\times 0.48 \right) \approx 0.116.$$
The same for $H$ (Humidity), $T$ (Temperature) and $W$ (Windy). Following the same arguments as in the ID3 section, we will get the final tree. The difference between the two algorithms is just the difference between the entropy $H$ and the Gini impurity $G$.
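As with the entropy version, we can sketch the Gini computations in a few lines of Python and check the gains above (this reuses the `rows` and `labels` lists from the information-gain sketch earlier; the helper names are mine):

```python
from collections import Counter

def gini(labels):
    """G(S) = 1 - sum(p_i^2) over the classes present in this node."""
    total = len(labels)
    return 1.0 - sum((c / total) ** 2 for c in Counter(labels).values())

def gini_gain(rows, labels, attribute):
    """GG(S, O) = G(S) - weighted sum of G(S_j) over the values of attribute O."""
    total = len(labels)
    weighted = 0.0
    for value in set(row[attribute] for row in rows):
        child = [y for row, y in zip(rows, labels) if row[attribute] == value]
        weighted += (len(child) / total) * gini(child)
    return gini(labels) - weighted

for name, col in [("Outlook", 0), ("Temperature", 1), ("Humidity", 2), ("Windy", 3)]:
    print(name, round(gini_gain(rows, labels, col), 3))
# Outlook 0.116, Temperature 0.019, Humidity 0.092, Windy 0.031 -> Outlook wins again
```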
Gini Impurity or Entropy?
Some points:[ref]
- Most of the time, they lead to similar trees.[ref]
- Gini impurity is slightly faster.[ref]
- Gini impurity tends to isolate the most frequent class in its own branch of the tree, while entropy tends to produce slightly more balanced trees.
Good / Bad of Decision Tree?
Some highlight advantages of Decision Tree Classifier:[ref]
- Can be used for regression or classification.
- Can be displayed graphically.
- Highly interpretable.
- Can be specified as a series of rules, and more closely approximate human decision-making than other models.
- Prediction is fast.
- Features don't need scaling.
- Automatically learns feature interactions.
- Tends to ignore irrelevant features.
- Non-parametric (will outperform linear models if relationship between features and response is highly non-linear).
Its disadvantages:
- Performance is (generally) not competitive with the best supervised learning methods.
- Can easily overfit the training data (tuning is required).
- Small variations in the data can result in a completely different tree (high variance).
- Recursive binary splitting makes "locally optimal" decisions that may not result in a globally optimal tree.
- Doesn't work well with unbalanced or small datasets.
When to stop?
If the number of features is too large, we'll have a very large tree! It also easily leads to an overfitting problem. How do we avoid this?
- Pruning: removing the branches that make use of features having low importance.
- Set a minimum number of training samples required to split a node; if a node has fewer samples, it becomes a leaf. In scikit-learn, use `min_samples_split` (there is also `min_samples_leaf` for the minimum number of samples allowed in a leaf). See the sketch after this list.
- Set the maximum depth of the tree. In scikit-learn, use `max_depth`.
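These stopping criteria map directly to constructor parameters in scikit-learn; a minimal sketch (the numbers below are arbitrary, just for illustration):

```python
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    max_depth=3,           # stop growing below depth 3
    min_samples_split=10,  # a node needs at least 10 samples to be split
    min_samples_leaf=5,    # every leaf must keep at least 5 samples
)
```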
When do we need to use a Decision Tree?
- When explainability is prioritised over accuracy. Otherwise, we tend to use Random Forest.
- When the data is more non-parametric in nature.
- When we want a simple model.
- When the entire dataset and all features can be used.
- When we have limited computational power.
- When we are not worried about accuracy on future datasets.
Using Decision Tree Classifier with Scikit-learn
Load and create
Load the library,
from sklearn.tree import DecisionTreeClassifier
Create a decision tree (other parameters):
# The Gini impurity (default)
clf = DecisionTreeClassifier() # criterion='gini'
# The information gain (ID3)
clf = DecisionTreeClassifier(criterion='entropy')
An example,
from sklearn import tree
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
# predict the class of a new sample
clf.predict([[2., 2.]])        # => array([1])
# probability of each class
clf.predict_proba([[2., 2.]])  # => array([[0., 1.]])
Plot and Save plots
Plot the tree (you may need to install Graphviz first; don't forget to add its installation folder to `$PATH`),
from IPython.display import Image
import pydotplus
dot_data = tree.export_graphviz(clf, out_file=None,
rounded=True,
filled=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
Save the tree (following the code in "Plot the tree"),
graph.write_pdf("tree.pdf") # to pdf
graph.write_png("thi.png") # to png
References
- Scikit-learn. Decision Tree Classifier official doc.
- Saed Sayad. Decision Tree - Classification.
- Tiep Vu. Bài 34: Decision Trees (1): Iterative Dichotomiser 3.
- Brian Ambielli. Information Entropy and Information Gain.
- Brian Ambielli. Gini Impurity (With Examples).
- Aurélien Géron. Hands-on Machine Learning with Scikit-Learn and TensorFlow, chapter 6.