The Ultimate Guide to Decision Trees for Machine Learning

The decision tree algorithm, often used within an ensemble method such as the random forest, is one of the most widely used machine learning algorithms in real production settings.

1. Introduction to decision trees

Decision trees are one of the most popular algorithms when it comes to data mining, decision analysis, and artificial intelligence. In this guide, we’ll gently introduce you to decision trees and the reasons why they have gained so much popularity. 

1.1 How would decision trees be described in layman’s terms?

Let’s start with a practical example. Imagine that you’re planning next week’s activities. The things that you’ll get up to will pretty much depend on whether your friends have time and what the weather is like outside. You come up with the following chart:

This chart sets out simple decision rules, which help you to decide what to do next week based on some other data. In this case, it’s your friends’ availability and the weather conditions.

Decision trees do the same. They build up a set of decision rules in the form of a tree structure, which help you to predict an outcome from the input data.
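Written as code, rules like these are simply nested yes/no questions. The sketch below is a hypothetical version of such a chart in Python. The specific activities and conditions are made up for illustration: each `if` plays the role of a decision node and each `return` is a leaf.

```python
def plan_activity(friends_available: bool, weather: str) -> str:
    """Hand-written decision rules (hypothetical example values)."""
    # Root node: are your friends free?
    if friends_available:
        # Decision node: what is the weather like?
        if weather == "sunny":
            return "go hiking with friends"    # leaf
        return "watch a movie with friends"    # leaf
    # Friends are busy: split on the weather again
    if weather == "sunny":
        return "go for a run"                  # leaf
    return "read a book at home"               # leaf


print(plan_activity(True, "rainy"))  # watch a movie with friends
```

A learned decision tree produces rules of the same shape. The difference is that the algorithm chooses the questions and their order from data instead of having them written by hand.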

1.2 What are the business use cases of decision trees?

Decision trees mimic human decision-making and can therefore be used in a variety of business settings. Companies often use them to predict future outcomes. For instance:

  1. Which customer will stay loyal and which one will churn? (Classification decision tree)
  2. By how much can we upsell a customer, given their product choices? (Regression decision tree)
  3. Which article should I recommend to my blog readers next? (Classification decision tree)

In general, decision trees are extremely useful tools for mimicking human decision-making, and they have a wide range of applications in both business and personal settings.

1.3 What are the advantages of decision trees for real-life applications?

There are multiple reasons why decision trees are one of the go-to machine learning algorithms in real-life applications: 

  1. Intuitive. It’s easy to comprehend how decision trees make their choices via a set of binary splits (yes/no answers). The decision chart above shows that decision trees learn to predict outcomes in a similar way to humans. Because we can visualize the algorithm and its choices, we can easily understand its inner workings and explain them to non-technical people as well.
  2. Informational. Decision trees offer detailed information about how decisions are made. The feature used at the first split is usually the most important one. This feature importance can be used not only to gain insight into our problem area, but also to transform our practices. Let’s consider an illustrative example: imagine you are building a decision tree to help you understand which prospects will convert to customers. The algorithm determines that the question “Will the lead convert to a customer?” is best answered by splitting the leads based on whether or not they have downloaded a technical brochure from our website. Leads who downloaded the brochure are much more likely to convert to customers than those who have not. This gives you business information that you did not possess before. You can reshape your lead nurturing activities around that brochure (e.g. send more visitors to the download landing page) to speed up the conversion from prospect to customer, and also to separate qualified leads (who downloaded the brochure) from unqualified ones more quickly. This also saves your sales agents valuable time.
  3. Scaling. A trained tree makes each prediction by following a single path from the root to a leaf, so predictions stay fast even when the tree was trained on a large dataset, and training remains tractable as data volumes grow. This makes decision trees useful for big-data problems.

2. Machine learning approaches to decision trees

Decision trees belong to a class of supervised machine learning algorithms, which are used for both classification (predicting discrete outcomes) and regression (predicting continuous numeric outcomes).

The goal of the algorithm is to predict a target variable from a set of input variables and their attributes. The approach builds a tree structure through a series of binary (yes/no) splits, starting at the root node and passing along the branches through several decision nodes (internal nodes) until it reaches the leaf nodes. It is here that the prediction is made. Each split partitions the feature space into regions, which are then split further at the lower levels of the tree.
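The structure described above can be sketched in Python as follows. This is a minimal illustration under our own assumptions (numeric features, `<=` thresholds, a hand-built hypothetical tree), not the representation used by any particular library: each decision node asks one yes/no question, and prediction simply follows the answers until it reaches a leaf.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a binary decision tree.

    Decision (internal) nodes hold a feature index and a threshold;
    leaf nodes hold only a prediction.
    """
    feature: Optional[int] = None        # index of the feature tested at this node
    threshold: Optional[float] = None    # go left ("yes") if x[feature] <= threshold
    left: Optional["Node"] = None        # "yes" branch
    right: Optional["Node"] = None       # "no" branch
    prediction: Optional[object] = None  # set only on leaf nodes

    def is_leaf(self) -> bool:
        return self.prediction is not None


def predict(node: Node, x: Sequence[float]):
    """Walk from the root to a leaf, answering one yes/no question per node."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.prediction


# Hypothetical tree: feature 0 = customer age, feature 1 = income (in thousands)
tree = Node(
    feature=0, threshold=30,
    left=Node(prediction="churn"),
    right=Node(
        feature=1, threshold=50,
        left=Node(prediction="churn"),
        right=Node(prediction="stay"),
    ),
)
print(predict(tree, [45, 80]))  # stay
```

Because only one root-to-leaf path is visited, the cost of a prediction depends on the depth of the tree, not on the size of the training data.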

We can visualize the entire tree structure like this:

2.1 The decision tree algorithm(s)

There is no single decision tree algorithm. Instead, multiple algorithms have been proposed to build decision trees:

  1. ID3: Iterative Dichotomiser 3
  2. C4.5: the successor of ID3
  3. CART: Classification And Regression Tree
  4. CHAID: Chi-square automatic interaction detection 
  5. MARS: multivariate adaptive regression splines

Broadly, later algorithms aim to improve upon earlier ones, for example by achieving higher accuracy on noisier or messier data.

In general, we can break down the decision tree algorithm into a series of steps common across different implementations:

  1. Attribute selection - start with the entire dataset and look at every feature (attribute). For each one, consider all of its possible values and pick the attribute and value that best split the dataset into different regions. What constitutes ‘the best split’ depends on whether we are building a regression tree or a classification tree. We’ll expand upon the different methods for finding the best split below.
  2. Split the dataset at the root node of the tree and move to the child nodes in each branch. For each decision node, repeat the attribute selection and the search for the best split value. This is a greedy algorithm: to keep tree building efficient, it only looks for the best local split given the data in the current region, rather than for the globally optimal tree.
  3. Continue iteratively until either:

a) We have grown the tree until its leaf nodes cannot be split any further (there were no stopping criteria). In the extreme case, each leaf node holds a single sample.

b) We reached a stopping criterion. For example, we might have set a maximum depth, which only allows a certain number of splits from the root node to the terminal nodes. Or we might have set a minimum number of samples in each terminal node, to prevent nodes from being split beyond a certain point.
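The three steps above can be sketched as a short recursive function. This is a simplified, CART-style illustration in plain Python written for this guide, not the code of any specific library: it assumes numeric features, stores the tree as nested dictionaries, uses Gini impurity (described in the next section) as the split criterion, and implements the two stopping criteria from step 3 as the hypothetical parameters `max_depth` and `min_samples_leaf`.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0 means perfectly pure)."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y, min_samples_leaf):
    """Greedy step: try every feature and every observed value as a threshold."""
    best = None  # (weighted impurity, feature, threshold)
    n = len(y)
    for feature in range(len(X[0])):
        for threshold in sorted({row[feature] for row in X}):
            left = [lab for row, lab in zip(X, y) if row[feature] <= threshold]
            right = [lab for row, lab in zip(X, y) if row[feature] > threshold]
            # Stopping criterion: never create a branch smaller than min_samples_leaf
            if len(left) < min_samples_leaf or len(right) < min_samples_leaf:
                continue
            cost = len(left) / n * gini(left) + len(right) / n * gini(right)
            if best is None or cost < best[0]:
                best = (cost, feature, threshold)
    return best


def build_tree(X, y, depth=0, max_depth=3, min_samples_leaf=1):
    """Split recursively until a stopping criterion is met; leaves predict the majority class."""
    majority = Counter(y).most_common(1)[0][0]
    # Stop if the node is already pure or the maximum depth is reached
    if gini(y) == 0.0 or depth >= max_depth:
        return {"leaf": majority}
    split = best_split(X, y, min_samples_leaf)
    if split is None:  # no split satisfies min_samples_leaf
        return {"leaf": majority}
    _, feature, threshold = split
    left_idx = [i for i, row in enumerate(X) if row[feature] <= threshold]
    right_idx = [i for i, row in enumerate(X) if row[feature] > threshold]
    return {
        "feature": feature,
        "threshold": threshold,
        "left": build_tree([X[i] for i in left_idx], [y[i] for i in left_idx],
                           depth + 1, max_depth, min_samples_leaf),
        "right": build_tree([X[i] for i in right_idx], [y[i] for i in right_idx],
                            depth + 1, max_depth, min_samples_leaf),
    }


X = [[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]]
y = ["A", "A", "A", "B", "B", "B"]
tree = build_tree(X, y, max_depth=2)
print(tree)  # {'feature': 0, 'threshold': 3.0, 'left': {'leaf': 'A'}, 'right': {'leaf': 'B'}}
```

Real implementations are far more efficient (they sort each feature once and update class counts incrementally), but the greedy, recursive structure is the same.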

2.2 Decision tree splitting method: how do we find the best split?

Determining the best value of an attribute as a splitting point is equivalent to splitting the dataset to minimize a given cost function. The choice of cost function depends on whether we are solving a classification problem or a regression problem. 
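Written out, one common formulation is as follows (the notation here is ours). A node m holds the sample set Q_m of size n_m. A candidate split θ = (j, t) on feature j at threshold t sends samples with x_j ≤ t to the left child and the rest to the right child. The algorithm picks the θ that minimizes G, the size-weighted impurity of the two children, where H is the impurity or cost function (Gini impurity or entropy for classification, variance for regression):

```latex
Q_m^{\text{left}}(\theta) = \{ (x, y) \in Q_m : x_j \le t \}, \qquad
Q_m^{\text{right}}(\theta) = Q_m \setminus Q_m^{\text{left}}(\theta)

G(Q_m, \theta) = \frac{n_m^{\text{left}}}{n_m} \, H\big(Q_m^{\text{left}}(\theta)\big)
               + \frac{n_m^{\text{right}}}{n_m} \, H\big(Q_m^{\text{right}}(\theta)\big)

\theta^{*} = \arg\min_{\theta} \, G(Q_m, \theta)
```

Weighting by branch size matters: without it, a split that isolates one sample into a perfectly pure branch would look better than it really is.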

2.2.1 Metrics for decision tree classifiers

In classification problems, the two most popular metrics for determining the splitting point are Gini impurity and information gain:

  1. Gini impurity. As the name suggests, this measures how ‘impure’ (mixed) the classes in a node are: it is the probability that two samples drawn at random (with replacement) from the node belong to different classes, computed as 1 minus the sum of the squared class proportions. A node containing only one class has a Gini impurity of 0. For example, a node with 80% class A and 20% class B has a Gini impurity of 1 − (0.8² + 0.2²) = 0.32, while a 50/50 node scores 0.5, the maximum for two classes. The algorithm searches for split values that produce the most homogeneous (lowest-impurity) child nodes. (P.S. Do not confuse Gini impurity with the Gini coefficient, also called the Gini index, which is a popular econometric measure of inequality.)
  2. Information gain. Information gain measures how much a split lowers the system’s entropy. Entropy describes how disordered (chaotic) a system is. This might sound abstract, but the concept is rather intuitive. If our decision tree were to split randomly without any structure, we would end up with splits of mixed classes (e.g. 50% class A and 50% class B). Chaos. But if the split sorts the classes into their own branches, we’re left with a more structured and less chaotic system. The logic is very similar to Gini impurity, but disorder is measured with entropy, computed as −Σ p log₂ p over the class proportions p, and a split is scored by how much it reduces entropy: the entropy of the parent node minus the size-weighted average entropy of its child nodes.
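Both metrics are short enough to write out directly. The sketch below is a plain-Python illustration with our own helper names (not a library API), using the standard definitions: Gini impurity is 1 − Σ p² over the class proportions p in a node, entropy is −Σ p log₂ p, and information gain is the parent’s entropy minus the size-weighted entropy of the children. The toy labels are hypothetical.

```python
import math
from collections import Counter


def gini_impurity(labels):
    """1 minus the sum of squared class proportions; 0 means the node is pure."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits: 0 for a pure node, 1 for a 50/50 two-class node."""
    n = len(labels)
    return -sum((count / n) * math.log2(count / n) for count in Counter(labels).values())


def information_gain(parent, children):
    """Entropy of the parent minus the size-weighted entropy of its child nodes."""
    n = len(parent)
    return entropy(parent) - sum(len(c) / n * entropy(c) for c in children)


node = ["A"] * 8 + ["B"] * 2                # 80% class A, 20% class B
print(round(gini_impurity(node), 2))         # 0.32

parent = ["A"] * 5 + ["B"] * 5               # maximally mixed: entropy = 1 bit
sorted_split = [["A"] * 5, ["B"] * 5]        # classes sorted into their own branches
mixed_split = [["A"] * 3 + ["B"] * 3, ["A"] * 2 + ["B"] * 2]  # still 50/50 in each branch
print(round(information_gain(parent, sorted_split), 2))  # 1.0
print(round(information_gain(parent, mixed_split), 2))   # 0.0
```

The split that sorts the classes into separate branches removes all of the parent’s entropy, while a split that leaves both branches 50/50 gains nothing.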

When the algorithm traverses all possible values of an attribute, it calculates either the Gini impurity of the resulting split (weighted by the size of each branch) or its information gain. The value with the lowest weighted Gini impurity, or with the highest information gain, is used as the split.
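The search over all values of a single attribute can be sketched as follows. This is a minimal plain-Python illustration with hypothetical helper names and toy data: it scores every candidate threshold by the size-weighted Gini impurity of the two resulting branches and keeps the lowest. To use information gain instead, you would keep the threshold with the highest gain.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_threshold(values, labels):
    """Try each distinct value of one numeric attribute as a '<=' threshold.

    Returns (threshold, weighted Gini impurity) of the best split, or None
    if the attribute has only one distinct value.
    """
    n = len(labels)
    best = None
    # Skip the largest value: splitting on it would leave the right branch empty
    for t in sorted(set(values))[:-1]:
        left = [lab for v, lab in zip(values, labels) if v <= t]
        right = [lab for v, lab in zip(values, labels) if v > t]
        cost = len(left) / n * gini(left) + len(right) / n * gini(right)
        if best is None or cost < best[1]:
            best = (t, cost)
    return best


# Hypothetical data: days since last login vs. whether the customer churned
days = [2, 5, 7, 20, 25, 40]
churned = ["no", "no", "no", "yes", "no", "yes"]
t, cost = best_threshold(days, churned)
print(t, round(cost, 3))  # 7 0.222
```

Splitting at 7 days puts three pure "no" samples on the left and leaves only one mismatched sample on the right, which gives the lowest weighted impurity of all candidates.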

2.2.2 Metrics for decision tree regressors

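Introduced in the CART algorithm, decision tree regressors use variance reduction as a measure of the best split. Variance measures how homogeneous the target values in a node are: if a node is extremely homogeneous, its variance (and the variance of its child nodes) will be small. Variance reduction is the variance of the parent node minus the size-weighted variance of its child nodes. The formula for variance is:

Var = (1/n) Σᵢ (yᵢ − ȳ)²

where n is the number of samples in the node, yᵢ are their target values, and ȳ is the mean of those values.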

The algorithm traverses different values of the independent variables, then picks the variable and value that produce the largest variance reduction after the split.
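Variance reduction can be computed the same way as the classification metrics, just on numeric targets. A minimal sketch, assuming a single numeric feature, hypothetical helper names and toy data: for every candidate threshold, it subtracts the size-weighted variance of the two children from the parent’s variance and keeps the threshold with the largest reduction.

```python
def variance(values):
    """Population variance: mean squared deviation from the mean."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def variance_reduction(parent, left, right):
    """Parent variance minus the size-weighted variance of the two children."""
    n = len(parent)
    return variance(parent) - (len(left) / n * variance(left)
                               + len(right) / n * variance(right))


def best_regression_split(x, y):
    """Return (threshold, variance reduction) for the best '<=' split on one feature."""
    best = None
    for t in sorted(set(x))[:-1]:  # skip the max so both branches are non-empty
        left = [yi for xi, yi in zip(x, y) if xi <= t]
        right = [yi for xi, yi in zip(x, y) if xi > t]
        gain = variance_reduction(y, left, right)
        if best is None or gain > best[1]:
            best = (t, gain)
    return best


# Hypothetical data: house size (m²) vs. price (in thousands)
size = [50, 60, 70, 120, 130, 140]
price = [150, 160, 155, 300, 310, 305]
t, gain = best_regression_split(size, price)
print(t)  # 70
```

The split at 70 m² separates the small, cheap houses from the large, expensive ones, so each child node is very homogeneous and almost all of the parent’s variance disappears.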

2.3 Advantages of decision trees

Decision trees offer several benefits:

  1. Interpretable. The algorithm is easy to interpret as a binary (yes/no, true/false) decision at each node. You can also visualize the decision tree to inspect which decision rules it has learned. This puts decision trees in the category of white-box models, aka interpretable models. These are unlike black-box models (such as neural networks), which are difficult or even impossible to interpret and understand. Interpretability offers business advantages, because you can translate the tree’s decision rules into business rules to improve your performance.
  2. Little to no data preparation. Unlike other algorithms, such as linear regression or logistic regression, decision trees work well with messy data. There is no need to normalize or scale the data, and depending on the implementation, you may not need to create dummy variables for categorical variables or remove blank or missing values either. This greatly shortens the data cleaning time, which is usually the most time-consuming part of a data science pipeline.
  3. Scale well. Once a tree is trained, the cost of a prediction grows roughly logarithmically with the number of training samples (for a reasonably balanced tree), because each prediction only follows one path from the root to a leaf. Training time grows roughly linearly with the number of features. This means that the algorithm can handle large datasets and scales well with increasing data without incurring prohibitive computational costs.
  4. Handle numerical and categorical data. Machine learning algorithms are usually specialized for either numerical data or categorical data. Decision trees work with both, making them advantageous for multiple real-life production settings where data is usually mixed.
  5. Robust to assumption violations. Linear models perform poorly when their assumption of a linear relationship between the inputs and the target is violated. Decision trees make no such assumption, so they can perform relatively well on data that only partially fulfills the assumptions other models rely on.
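Because a tree is just a set of nested yes/no questions, its rules can be printed in a form a business stakeholder can read. The sketch below assumes the tree is stored as nested dictionaries (a hypothetical format with `feature`, `threshold`, `left`, `right` and `leaf` keys, chosen for illustration) and lists one IF ... THEN rule per leaf. Libraries offer similar functionality; scikit-learn, for example, provides `sklearn.tree.export_text`.

```python
def tree_to_rules(node, feature_names, conditions=()):
    """Turn a nested-dict tree into a list of human-readable IF ... THEN ... rules."""
    if "leaf" in node:
        condition = " AND ".join(conditions) or "always"
        return [f"IF {condition} THEN predict {node['leaf']}"]
    name = feature_names[node["feature"]]
    t = node["threshold"]
    # Left branch = condition is true ("yes"), right branch = condition is false ("no")
    return (tree_to_rules(node["left"], feature_names, conditions + (f"{name} <= {t}",))
            + tree_to_rules(node["right"], feature_names, conditions + (f"{name} > {t}",)))


# A hypothetical fitted tree for the lead-conversion example in section 1.3
tree = {
    "feature": 0, "threshold": 0,
    "left": {"leaf": "unlikely to convert"},
    "right": {
        "feature": 1, "threshold": 2,
        "left": {"leaf": "likely to convert"},
        "right": {"leaf": "very likely to convert"},
    },
}
for rule in tree_to_rules(tree, ["downloaded_brochure", "site_visits"]):
    print(rule)
# IF downloaded_brochure <= 0 THEN predict unlikely to convert
# IF downloaded_brochure > 0 AND site_visits <= 2 THEN predict likely to convert
# IF downloaded_brochure > 0 AND site_visits > 2 THEN predict very likely to convert
```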

2.4 Disadvantages of decision trees

Like most things, decision trees also have a few disadvantages:

  1. Overfitting. Decision trees overfit very quickly. If you let them grow without a stopping mechanism, or without a correction mechanism after the tree has been trained, they can split so many times that each leaf contains a single sample. This means that they’ve literally memorized the training dataset and suffer from high variance (they generalize poorly to novel data). Check the section below for practical advice on correcting overfitting.
  2. Non-robust to input data changes. A small change in training data can result in a completely different tree. The overall accuracy might still be high, but the specific decision splits will be totally different.
  3. Biased towards the dominant class. Classification decision trees tend to favor predicting the dominant class in datasets with class imbalance. 

Below, we offer practical tips on how to improve decision trees to mitigate their weaknesses.

2.5 Beyond decision trees: how to improve the model

There are several ways to improve decision trees, each one addressing a specific shortcoming of this machine learning algorithm.

How to avoid overfitting

  1. Minimum samples for leaf split. Set the minimum number of data points that must be present at each leaf node. If splitting a node would produce a branch with fewer than this minimum number of samples, the node is not split further. This prevents the tree from fitting individual samples too closely.
  2. Maximum depth. This is the maximum number of levels that a tree can have. You can think of it as the largest number of splits a path can pass through on its way from the root to a leaf. Setting a maximum depth lets you determine how shallow or deep a tree can get. Shallower trees tend to be less accurate on the training set but generalize better, while deeper trees fit the training set more closely but tend to generalize poorly to new data.
  3. Pruning. Pruning corrects the tree after it has been fitted to the training dataset (unlike maximum depth and minimum leaf samples, which are set before fitting the tree). In pruning, an algorithm starts at the leaf nodes and removes those branches whose removal does not reduce the tree’s accuracy on a held-out validation dataset. Effectively, this method preserves high tree performance while lowering the complexity (number of branches and splits) of the model.
  4. Ensemble methods: Random forest. Ensemble methods combine multiple trees into an ensemble algorithm. The ensemble uses tree ‘voting’ as a mechanism to determine the final prediction. For classification problems, ensemble algorithms pick the mode, i.e. the class predicted most often by the individual trees. For regression problems, ensemble algorithms take the average of the trees’ predictions. A well-known ensemble method is the random forest, one of the most successful and widely used algorithms in artificial intelligence. We wrote an extensive guide on random forest regression.
  5. Feature selection or dimensionality reduction. Decision trees tend to overfit when data is sparse. To avoid overfitting on sparse data, either select a subset of features or reduce the dimensionality of your dataset with appropriate algorithms (e.g. PCA).
  6. Boosted trees. Boosted decision trees reduce overfitting by using the standard machine learning method of boosting. Build shallow decision trees (e.g. three levels deep), and with each iteration, train a new tree that focuses on the errors made by the trees built so far (for example, by fitting their residuals in gradient boosting, or by giving misclassified samples more weight in AdaBoost). Combining many shallow trees this way improves the overall performance while keeping each individual tree too simple to overfit on its own.
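The voting step of an ensemble is simple to sketch. The helper names below are hypothetical and the per-tree predictions are made-up inputs. In a random forest, each tree is also trained on a different bootstrap sample of the data (and considers a random subset of features at each split), which is what makes the trees’ votes differ; a bootstrap sampler is included to show that first ingredient.

```python
import random
from collections import Counter
from statistics import mean


def bootstrap_sample(X, y, rng):
    """Draw len(X) rows with replacement: the training set for one tree in bagging."""
    idx = [rng.randrange(len(X)) for _ in range(len(X))]
    return [X[i] for i in idx], [y[i] for i in idx]


def vote_classification(predictions):
    """Classification: the mode, i.e. the class predicted by the most trees."""
    return Counter(predictions).most_common(1)[0][0]


def vote_regression(predictions):
    """Regression: the average of the trees' numeric predictions."""
    return mean(predictions)


# Hypothetical predictions of five trees for the same sample
print(vote_classification(["churn", "stay", "churn", "churn", "stay"]))  # churn
print(vote_regression([210.0, 190.0, 205.0, 195.0, 200.0]))              # 200.0
```

Averaging many trees that each overfit in different ways cancels out much of their individual variance, which is why the ensemble generalizes better than any single deep tree.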

How to increase tree robustness? Robustness is hard to improve. That’s because decision trees use a greedy algorithm at each split, which finds local (but not global) optima. As an alternative to the greedy approach, other algorithms have been proposed, such as dual information distance (DID) trees.

How can you correct bias towards the dominant class? To mitigate decision trees’ bias towards predicting the dominant class, make sure to address class imbalance before fitting your model. There are three approaches for tackling class imbalance in the preprocessing (or data cleaning) stage:

  1. Downsample the majority class.
  2. Upsample the minority class.
  3. Collect more data for the minority class.
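The first two approaches can be sketched with Python’s standard `random` module. This is a minimal illustration of random under- and oversampling with a hypothetical helper name and toy data; more advanced resampling techniques exist but are not shown, and collecting more data has no code equivalent. Apply resampling to the training split only, so that the test set still reflects the real class distribution.

```python
import random
from collections import Counter


def rebalance(X, y, strategy="down", seed=0):
    """Randomly down- or upsample so that every class has the same number of rows.

    strategy="down": sample each class without replacement down to the smallest class size.
    strategy="up":   add rows drawn with replacement until each class matches the largest one.
    """
    rng = random.Random(seed)
    by_class = {}
    for row, label in zip(X, y):
        by_class.setdefault(label, []).append(row)
    sizes = [len(rows) for rows in by_class.values()]
    target = min(sizes) if strategy == "down" else max(sizes)

    X_out, y_out = [], []
    for label, rows in by_class.items():
        if strategy == "down":
            chosen = rng.sample(rows, target)                       # without replacement
        else:
            chosen = rows + rng.choices(rows, k=target - len(rows))  # with replacement
        X_out.extend(chosen)
        y_out.extend([label] * len(chosen))
    return X_out, y_out


X = [[i] for i in range(10)]
y = ["stay"] * 8 + ["churn"] * 2
print(dict(Counter(rebalance(X, y, "down")[1])))  # {'stay': 2, 'churn': 2}
print(dict(Counter(rebalance(X, y, "up")[1])))    # {'stay': 8, 'churn': 8}
```

Downsampling throws information away but keeps training fast; upsampling keeps every row but duplicates minority samples, which can make a deep tree overfit them.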

3. Decision trees in practice

The ways in which you use decision trees in practice depend on how much you know about the entire data science process.

We recommend that beginners start by modeling data on datasets that have already been collected and cleaned, while experienced data scientists can scale their operations by choosing the right software for the task at hand.

3.1 Beginner projects to try out decision trees

There are multiple datasets to try out decision trees in practice. Among the best ones are:

  1. The classic Titanic survival dataset. Predict whether a passenger or a crew member would have survived the Titanic’s collision with the iceberg.
  2. Create a decision tree to predict whether an employee will leave their company and determine which factors lead to employee attrition.
  3. Determine the likelihood of a bank customer buying a personal loan.

3.2 Production software for advanced data science

Data scientists are often estimated to spend up to 80% of their time on data collection and cleaning. If you want to speed up the entire data pipeline, use software that automates these tasks to give you more time for data modeling.

Keboola offers a platform for data scientists who want to build their own machine learning models. It comes with Jupyter Notebooks that deploy in one click, in which all of the modeling can be done using Julia, R, or Python.

Deep dive into the data science process with this Jupyter Notebook:

  1. Collect the relevant data.
  2. Explore and clean the data to discover patterns.
  3. Train your decision tree model.
  4. Evaluate the model with a variety of metrics.

Want to take it a step further? Keboola can help you operationalize your entire data operations pipeline.
Being a data-centric platform, Keboola also allows you to build your ETL pipelines and orchestrate tasks to get your data ready for machine learning algorithms. Deploy multiple models with different algorithms to version your work and compare which ones perform best. Start building models today with our free trial.
