1 Introduction
In recent decades, machine learning (ML) has gained significant popularity in various fields. However, the widespread adoption of black-box ML models, such as neural networks and ensemble models, has led to growing concerns about their interpretability. This lack of interpretability has triggered skepticism and criticism, particularly in decision-based applications. For instance, in fields like medical diagnostics or treatment choice, a model without straightforward and concise interpretability could lead to erroneous diagnoses and potentially harmful treatment decisions. Knowledge distillation [2, 8, 23, 20, 22, 1] provides a way to interpret a black-box ML model through a transparent model, following a teacher-student architecture [9]. Knowledge about the data is distilled from the teacher model (black-box ML model) to train the student model (transparent model). As a result, the student model inherits the teacher model's knowledge about the complex structure of the data and the underlying mechanisms of the domain question, enabling it to achieve both high interpretability and strong predictive performance. [17] employed several simple models, including linear models and decision trees, as transparent models. Similarly, [15] utilized kernel methods and local linear approaches to construct the transparent model.

In this study, we focus on the decision tree [11, 6, 4, 14, 21, 5], which is an ideal transparent model for two reasons. First, it is inherently interpretable. Second, it possesses the capacity to capture complex data structures. Several studies have employed decision trees as transparent models alongside knowledge distillation. [6] explored the distillation of a neural network into a soft decision tree. [4] discussed the decision tree model for interpreting deep reinforcement learning models. [19] used decision trees for explaining data in the field of e-commerce.

However, none of these studies considered the stability of decision trees constructed through knowledge distillation. The interpretability of a decision tree relies heavily on the stability of its structure, which may be sensitive to the specific datasets used for training. Interpretations may become questionable if minor changes in the training data significantly affect the tree's structure. Given that the training data is generated randomly through the knowledge distillation process, ensuring the stability of the tree's structure becomes a key challenge to address. [27] explored tree structure stability in knowledge distillation, but their study focused on a single splitting criterion and did not provide conclusive conditions under which the tree structure (or split) converges.
We refer to the decision tree generated from the knowledge distillation process as the "knowledge distillation decision tree" (KDDT). In this paper, we conduct a comprehensive theoretical study of the split stability of KDDT, demonstrating that the split converges in probability at a specific rate under mild assumptions. Our theoretical findings encompass the most commonly used splitting criteria and are applicable to both classification and regression applications. Additionally, we propose and implement algorithms for KDDT induction. Note that KDDT provides a global approximation to the black-box model, meaning it approximates the entire black-box model at once using a single interpretable model. This approach may be less efficient for interpreting very large and complex black-box models, such as deep neural networks, compared to the local approximation models discussed in [26], which approximate the black-box model piecewise. For local approximation models, each segment of the black-box model can be represented by a different local approximation, allowing for more tailored interpretations. However, large-scale black-box models have a large number (e.g., millions) of parameters and require large training datasets, making them unsuitable for small or medium datasets, such as those with at most $O({10^{3}})$ samples. For these datasets, the global approximation provided by KDDT can offer more accurate interpretations of the global effects of covariates than simple linear models. Through a simulation study, we validate KDDT's ability to provide precise interpretations while maintaining a stable structure. We also include real data analyses to demonstrate its practical applicability.
The remainder of the paper is organized as follows. In Section 2, we introduce the concept and stability theory of KDDT. The algorithms for constructing KDDT are proposed in Section 3. Section 4 presents the simulation study. In Section 5, we apply KDDT to real datasets. Finally, Section 6 concludes with a discussion. Theorems and proofs are in Appendix A. Supplementary materials can be found in Appendix B. Additionally, an open-source R implementation of KDDT is accessible on GitHub at https://github.com/lxtpvt/kddt.git.
2 Knowledge Distillation Decision Tree
A knowledge distillation decision tree is essentially a decision tree. Instead of being constructed from real observations, it is generated from the knowledge distillation process.
2.1 Knowledge Distillation Process
A typical knowledge distillation process with the teacher-student architecture is illustrated in Figure 1. The specific components of this process can be adapted based on application requirements. For example, the teacher model can be a Convolutional Neural Network (CNN) [7, 12] or a Large Language Model (LLM) [25], while the student model can be a decision tree [11, 6, 4] or a lightweight neural network [7, 12]. In this paper, we specify the components as follows:
• Data. $D=\{Y,X\}$, where Y is the set of observations of response variable y, X is the set of observations of covariates $\mathbf{x}=({x_{1}},\dots ,{x_{p}})$. Both response and covariates can be categorical or continuous variables.
• Teacher model. $y=f(\mathbf{x})$; we specify it as a small-scale black-box ML model suitable for datasets of size $O({10^{3}})$.
• Knowledge distillation. Includes two steps: 1) random sampling of covariate values on their support, denoted as ${X^{\prime }}$, and 2) generating the corresponding response values ${Y^{\prime }}=f({X^{\prime }})$ through the fitted teacher model f.
• Knowledge. ${D^{\prime }}=\{{Y^{\prime }},{X^{\prime }}\}$, which we call the pseudo-data.
• Student model. A decision tree, which we refer to as the knowledge distillation decision tree (KDDT). The KDDT is constructed from the pseudo-data ${D^{\prime }}$ and maintains a stable structure under the randomness of ${D^{\prime }}$.
The knowledge distillation process can also be viewed as a model approximation process, as illustrated in Figure B.12 in Appendix B. The student model, KDDT, is used to approximate the teacher model through the pseudo-data ${D^{\prime }}=\{{Y^{\prime }},{X^{\prime }}\}$. Figure B.12 also highlights the differences between the model approximation and generalization processes.
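To make this process concrete, the following R sketch distills a random forest teacher into a decision tree student on simulated data. The data-generating function, sample sizes, and the use of the randomForest and rpart packages are illustrative choices for exposition; they are not the settings of the kddt package.

```r
# Minimal sketch of the knowledge distillation process described above.
library(randomForest)
library(rpart)

set.seed(1)
# Original data D = {Y, X}: two continuous covariates, continuous response
n <- 500
X <- data.frame(x1 = runif(n), x2 = runif(n))
Y <- sin(2 * pi * X$x1) + X$x2^2 + rnorm(n, sd = 0.1)
D <- cbind(y = Y, X)

# Teacher model: y = f(x), a black-box ML model fitted to D
teacher <- randomForest(y ~ ., data = D)

# Knowledge distillation:
# 1) uniform random sampling of covariate values on their observed support
n_pseudo <- 10000
X_prime <- data.frame(
  x1 = runif(n_pseudo, min(X$x1), max(X$x1)),
  x2 = runif(n_pseudo, min(X$x2), max(X$x2))
)
# 2) generate the corresponding responses through the fitted teacher
Y_prime <- predict(teacher, newdata = X_prime)
D_prime <- cbind(y = Y_prime, X_prime)   # the "knowledge" (pseudo-data)

# Student model: a decision tree fitted to the pseudo-data
student <- rpart(y ~ ., data = D_prime)
```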
2.2 Tree Structure Stability
In this paper, we focus on the second half of the knowledge distillation process, specifically from the knowledge distillation step to the student model. The teacher model and original data are fixed. Our task is to handle the randomness in the pseudo-data ${D^{\prime }}$ and construct a stable KDDT. Our approach is based on the hypothesis that we can generate an arbitrarily large ${D^{\prime }}$ to achieve the stability of the KDDT. In this section, we prove this hypothesis.
2.2.1 Prerequisites
It is essential to introduce the key concepts and notation of decision trees that will be used in the theoretical study.
(1) Splitting criteria
Typically, different criteria are used for regression and classification trees. In regression, the primary criteria include minimizing the sum of squared errors (SSE) or the mean squared error (MSE) after splitting:

(2.1)
\[ \begin{aligned}{}& \min \Bigg\{{\sum \limits_{i=1}^{{n_{l}}}}{({y_{li}}-{\bar{y}_{l}})^{2}}+{\sum \limits_{j=1}^{{n_{r}}}}{({y_{rj}}-{\bar{y}_{r}})^{2}}\Bigg\},\\ {} & \min \Bigg\{\frac{1}{{n_{l}}}{\sum \limits_{i=1}^{{n_{l}}}}{({y_{li}}-{\bar{y}_{l}})^{2}}+\frac{1}{{n_{r}}}{\sum \limits_{j=1}^{{n_{r}}}}{({y_{rj}}-{\bar{y}_{r}})^{2}}\Bigg\},\end{aligned}\]

where the subscripts l, r represent the left and right nodes of a stump, ${n_{l}}+{n_{r}}=n$, ${\bar{y}_{l}}=\frac{1}{{n_{l}}}{\textstyle\sum _{i=1}^{{n_{l}}}}{y_{li}}$ and ${\bar{y}_{r}}=\frac{1}{{n_{r}}}{\textstyle\sum _{j=1}^{{n_{r}}}}{y_{rj}}$.

In classification, the criterion for selecting the best split is to maximize the reduction of impurity after splitting:

\[ \max \Bigg\{E-\Big(\frac{{n_{l}}}{n}{E_{l}}+\frac{{n_{r}}}{n}{E_{r}}\Big)\Bigg\},\]

where E is the total impurity before splitting, and ${E_{l}}$ and ${E_{r}}$ are the left and right child impurities, respectively, after splitting. Since the split does not impact E, the above criterion can be simplified as follows:

(2.2)
\[ \min \Bigg\{\frac{{n_{l}}}{n}{E_{l}}+\frac{{n_{r}}}{n}{E_{r}}\Bigg\}.\]

The well-known impurity measures include Shannon entropy, gain ratio, and Gini index [16, 3]. [24] proposed the Tsallis entropy in (2.3) to unify these measures in a single framework:

(2.3)
\[ E={S_{q}}(Y)=\frac{1}{1-q}\Bigg({\sum \limits_{i=1}^{C}}p{({y_{i}})^{q}}-1\Bigg),\hspace{14.22636pt}q\in \mathbb{R},\]

where Y is a random variable that takes values in $\{{y_{1}},\dots ,{y_{C}}\}$, $p({y_{i}})$ is the corresponding probability of ${y_{i}}$, $i=1,\dots ,C$, and q is an adjustable parameter.
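As a quick numerical illustration of (2.3), the R snippet below evaluates the Tsallis entropy for a small probability vector; as q approaches 1 it recovers the Shannon entropy, and q = 2 yields the Gini index. The function name is ours.

```r
# Numeric check of the Tsallis entropy in (2.3).
tsallis <- function(p, q) {
  if (abs(q - 1) < 1e-8) return(-sum(p * log(p)))  # limiting case q -> 1: Shannon entropy
  (sum(p^q) - 1) / (1 - q)
}

p <- c(0.5, 0.3, 0.2)   # class probabilities p(y_i)
tsallis(p, q = 1)       # Shannon entropy: about 1.0297
tsallis(p, q = 2)       # Gini index: 1 - sum(p^2) = 0.62
```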
(2) Split search algorithm
The most commonly used split search algorithm is the greedy search algorithm, which makes a locally optimal choice at each stage in a heuristic manner, aiming to find the global optimum. The algorithm involves three steps: (a) for each split, searching through all covariates and their observed values; (b) for each candidate (covariate, value) pair, calculating the loss (or gain) defined by the splitting criterion; and (c) identifying the best split by minimizing the loss (or maximizing the gain). Although the greedy search algorithm does not guarantee the global optimum, finding which is an NP-hard problem [10], we choose it for our study due to its simplicity and popularity in practice.
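For concreteness, a bare-bones R version of this greedy search for a regression stump under the SSE criterion in (2.1) might look as follows. The function and variable names are ours, and practical implementations such as rpart use far more efficient scans.

```r
# Illustrative greedy split search for a regression stump (SSE criterion).
greedy_split <- function(data, response, covariates) {
  best <- list(var = NA, value = NA, sse = Inf)
  y <- data[[response]]
  for (v in covariates) {                       # (a) search all covariates
    for (cut in sort(unique(data[[v]]))) {      # ... and all observed values
      left  <- y[data[[v]] <  cut]
      right <- y[data[[v]] >= cut]
      if (length(left) == 0 || length(right) == 0) next
      # (b) loss defined by the splitting criterion (here: SSE)
      sse <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
      # (c) keep the split minimizing the loss
      if (sse < best$sse) best <- list(var = v, value = cut, sse = sse)
    }
  }
  best
}

# Example: best first split on the pseudo-data from the sketch in Section 2.1
# greedy_split(D_prime, "y", c("x1", "x2"))
```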
2.2.2 Split Convergence
As discussed in the introduction, studying the stability of the entire tree is challenging. A practical approach is to focus on individual splits: if all splits are stable, the entire tree is naturally stable. We say a split achieves stability when it converges to a unique optimal split. The concept of an optimal split is defined as follows.
Definition 1 (Optimal split).
Let Ω be the support of univariate x, and ${z_{i}^{l}}(x)$ and ${z_{i}^{r}}(x)$ be functions $\Omega \to \mathbb{R}$, where $i=1,\dots ,C$ and C is a constant in ${\mathbb{N}^{+}}$. Let $g({z_{1}^{t}}(x),\dots ,{z_{C}^{t}}(x))$ be a continuous function ${\mathbb{R}^{C}}\to \mathbb{R}$, where $t=l$ or r. Then, the optimal split ${x_{s}}\in \Omega $ is defined as follows.
Definition 1 is somewhat abstract. To clarify, we provide two examples to illustrate this definition in both regression and classification contexts.
• Classification: we assume that $y\in \{{y_{1}},\dots ,{y_{C}}\}$ is a categorical variable with C categories, x is a continuous variable, and the split criterion is defined by (2.2) using the Tsallis entropy as in (2.3). The components in Definition 1 are specified as follows:

\[ g\big({z_{1}^{l}}(x),\dots ,{z_{C}^{l}}(x)\big)=\frac{1}{1-q}\Bigg({\sum \limits_{i=1}^{C}}{z_{i}^{l}}{(x)^{q}}-1\Bigg),\]

\[ g\big({z_{1}^{r}}(x),\dots ,{z_{C}^{r}}(x)\big)=\frac{1}{1-q}\Bigg({\sum \limits_{i=1}^{C}}{z_{i}^{r}}{(x)^{q}}-1\Bigg),\]

\[ {z_{i}^{l}}(x)={\int _{a}^{x}}\frac{1}{x-a}\ast {I_{{y_{i}}}}\big(f(t)\big)\hspace{0.1667em}dt,\]

\[ {z_{i}^{r}}(x)={\int _{x}^{b}}\frac{1}{b-x}\ast {I_{{y_{i}}}}\big(f(t)\big)\hspace{0.1667em}dt,\]

where ${I_{{y_{i}}}}(f(x))$ is an indicator function that equals 1 when $f(x)={y_{i}}$ and 0 otherwise.
The concept of split convergence can be defined based on the definition of an optimal split as follows.
Definition 2 (Split convergence).
A split ${x_{s}^{n}}$ is estimated via the greedy search algorithm on the sampled data ${D^{\prime }}=\{{Y^{\prime }},{X^{\prime }}\}$ of size n. Let ${x_{s}}$ be the unique optimal split on Ω. If ${x_{s}^{n}}$ converges to ${x_{s}}$ in probability as $n\to \infty $, we refer to this case as split convergence and to ${x_{s}^{n}}$ as a convergent split.
Our theoretical study demonstrates that split convergence can be guaranteed under three assumptions: (1) the existence of a unique optimal split; (2) uniform random sampling of the pseudo-data ${D^{\prime }}=\{{Y^{\prime }},{X^{\prime }}\}$; and (3) the use of the greedy search algorithm. Since continuous and categorical response variables correspond to different split criteria (regression versus classification), and different types of covariates require distinct treatments in the proof, we divide the theory into four theorems, each corresponding to one of the combinations of variable types listed in Table 1. The details of all theorems, lemmas, and proofs can be found in Appendix A.
Table 1
Theorems classified based on the combinations of variable types.
x \ y | y: Continuous | y: Categorical
x: Continuous | Theorem 1 | Theorem 2
x: Categorical | Theorem 3 | Theorem 4
Fair assumptions help establish a theory with a solid foundation and broad applicability. Regarding the greedy search assumption, as discussed in Section 2.2.1, it has the advantages of simplicity and popularity in practice. The uniform random sampling assumption ensures the sampling space covers the teacher model and simplifies theoretical proofs. However, it may lead to efficiency issues as the dimension of covariates increases. For problems with modest dimensions (i.e., fewer than 20 continuous variables), uniform random sampling works well (see real data analysis in Section 5). For high-dimensional problems, non-uniform random sampling strategies may be more appealing and worth investigating. As for the unique optimal split assumption, let us consider its opposite first: assume there are multiple optimal splits. We can define the concept of split oscillation in Definition 3. Although split oscillation may occur in theory, it rarely happens in practice. Let us consider a scenario where two optimal splits exist. When applying the greedy search algorithm with real data, the likelihood of two splits yielding identical numerical results (e.g., impurity reduction) will be extremely low. Even if such a rare situation arises, it is not a significant concern. It simply indicates that the two splits are equivalent, and selecting either of them is reasonable.
Definition 3 (Split oscillation).
A split ${x_{s}^{n}}$ is estimated via the greedy search algorithm on the sampled data ${D^{\prime }}=\{{Y^{\prime }},{X^{\prime }}\}$ of size n. If ${x_{s}^{n}}$ has multiple limits as $n\to \infty $, we refer to this case as split oscillation and to the split as an oscillating split.
2.2.3 Measure of Split Stability
In practice, the pseudo-data must be finite. Therefore, we need a way to measure the split stability with finite data. Motivated by the greedy search algorithm, we propose the two-level split stability (see Figure 2) as follows.
Figure 2
Two-level split stability. First-level stability is represented by the pmf of choosing a split variable. Second-level stability is represented by either a pdf or a pmf conditional on the selected split variable.
• First-level stability. It is defined as a discrete distribution with the probability mass function $p(k)$, which quantifies the stability of selecting the k-th covariate ${x_{k}}$, ($k=1,\dots ,p$) as the splitting variable.
• Second-level stability. It is defined as the conditional distribution of the splitting value given the covariate to split on. The second-level stability can be either a probability density function (e.g., $p(x|i)$) or a probability mass function (e.g., $p(x|j)$), depending on the type of selected splitting variable.
Since the two-level stability is difficult to calculate analytically, we use Monte Carlo simulation for its estimation.
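The following R sketch illustrates such a Monte Carlo estimate for the root split, reusing the fitted teacher and the two uniform covariates from the sketch in Section 2.1. The use of rpart stumps and all object names are our own illustrative choices.

```r
# Monte Carlo estimation of the two-level split stability for the root split,
# assuming `teacher` from the Section 2.1 sketch is available.
library(rpart)

N <- 100                                   # number of repetitions
n_pseudo <- 5000                           # pseudo-data size per repetition
splits <- data.frame(var = character(N), value = numeric(N))

for (k in 1:N) {
  X_prime <- data.frame(x1 = runif(n_pseudo), x2 = runif(n_pseudo))
  D_prime <- cbind(y = predict(teacher, newdata = X_prime), X_prime)
  stump <- rpart(y ~ ., data = D_prime, maxdepth = 1)  # root split only
  s <- stump$splits
  splits$var[k]   <- rownames(s)[1]        # chosen split variable (primary split)
  splits$value[k] <- s[1, "index"]         # chosen split value
}

prop.table(table(splits$var))              # first-level stability: pmf over variables
# second-level stability: pdf of split values given the most frequent variable
plot(density(splits$value[splits$var == names(which.max(table(splits$var)))]))
```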
3 Algorithms for Constructing KDDT
There are two fundamental distinctions between the construction of a KDDT and that of an ordinary decision tree (ODT). First, an ODT is built directly from a limited dataset, whereas a KDDT is constructed using (in theory) unlimited pseudo-data. Second, an ODT aims to best fit the dataset, whereas a KDDT aims to best approximate the teacher model. These distinctions call for a construction algorithm for KDDT that differs from that of ODT.
3.1 KDDT Induction Algorithm
We first introduce the concept of the sampling region, which will be utilized in the KDDT induction algorithm.
Definition 4 (Sampling region).
Let $\mathbb{S}$ be the bounded space defined by the observed data. For node i in a KDDT, its ancestors define a subspace of $\mathbb{S}$. We denote it as ${R_{i}}$ and refer to it as the sampling region of node i. The sampling region of any inner node is exactly the union of the sampling regions of its two child nodes. (Note: since the observed data have a finite range, $\mathbb{S}$ is bounded.)
As an extension of the sampling region, the concept of the sampling path will also be used later in this paper.
Definition 5 (Sampling path).
A sampling path is a series of nested sampling regions defined by the nodes in a KDDT path. The sampling path ${P_{i,j}}$ starts from sampling region ${R_{i}}$ and ends at sampling region ${R_{j}}$, i.e., ${P_{i,j}}=\{({R_{i}},\dots ,{R_{j}})|{R_{i}}\supset \cdots \supset {R_{j}}\}$. Two sampling paths intersect if there exists a sampling region in one sampling path that includes any sampling region in the other sampling path.
Figure 3
Examples of the dependency chain and variance propagation. (a) Two dependency chains in a tree. (b) The corresponding variance propagation. Note: ${R_{i}}$ denotes sampling region i. ${R_{2}}|({R_{1}},{x_{s1}})\to {x_{s2}}$ indicates that ${x_{s2}}$ is determined by ${R_{2}}$, given ${R_{1}}$ and ${x_{s1}}$. The variance of ${x_{si}}$ is denoted by $\Delta {x_{si}}$.
The most commonly used induction algorithm for constructing an ODT is a top-down recursive approach [18], referred to as the ODT induction algorithm in this paper. It starts with the entire input dataset in the root node, where a locally optimal split is identified using the greedy search algorithm, and conditional branches based on the split are created. This process is repeated in the generated nodes until the stopping criterion is met. A naive approach to constructing a KDDT is to directly apply the ODT induction algorithm on a large pseudo-dataset. However, this method may not perform well in practice. The pseudo-data introduces variation (uncertainty) due to the random sampling process. This variation propagates in the constructed tree along dependency chains created by the top-down induction strategy. For instance, as illustrated in panel (a) of Figure 3, the split ${x_{s4}}$ depends on $({R_{2}},{x_{s2}})$, which, in turn, depends on $({R_{1}},{x_{s1}})$. The propagation of split variance follows the inverse direction of these dependencies. We denote the variance of ${x_{si}}$ as $\Delta {x_{si}}$. In panel (b) of Figure 3, the variance $\Delta {x_{s1}}$ will affect $({x_{s2}},\Delta {x_{s2}})$ and $({x_{s3}},\Delta {x_{s3}})$, and subsequently, $\Delta {x_{s2}}$ will impact $({x_{s4}},\Delta {x_{s4}})$. This results in rapid inflation of variance as it propagates to deeper levels. For example, a small $\Delta {x_{s1}}$ may lead to a substantial $\Delta {x_{s4}}$ or even a change in the split variable.
Incorporating the two-level stability, we propose a KDDT induction algorithm that avoids the variance inflation issue of the ODT algorithm. As shown in the steps of the KDDT induction algorithm, for a given node i, we measure its split ${N_{i}}$ times. Using these repeated measurements, denoted ${X_{s}}$, we can estimate the two-level stability and choose a split value with the lowest variance. The first-level stability reduces the variance in selecting the split variable, while the second-level stability reduces the variance in identifying the split value. For example, if the split variable is continuous and we consider the mean ${\bar{x}_{s}}$ of all fitted split values, then by the central limit theorem the variance of ${\bar{x}_{s}}$ decreases at a rate of ${N_{i}^{-1}}$. Instead of the mean of all fitted split values, the two-level stability approach chooses the mode, which not only reduces variance but also mitigates the influence of outliers. Choosing the mode also aligns with the likelihood principle, as it corresponds to selecting the value with maximal likelihood, given that the two-level stability is defined by a probability mass/density function. By repeating this process at each split, we can construct a KDDT with a stable structure.

In practice, it is common to choose a reasonably large value for ${N_{i}}$ (${N_{i}}=100$ works well in our study). The sample size ${n_{i}}$ of the pseudo-data can be set proportional to the potential explanation index (see Definition 6). In practice, if the number of nodes is small (see interpretable nodes in Section 3.2), we can simply set all ${n_{i}}$ at the same tree level to be equal and assign them 90% of the corresponding value at the preceding upper level; we use 90% to maintain a large amount of pseudo-data and ensure a stable estimate of the split value. We determine the value of ${x_{s}^{\ast }}$ by selecting the mode of the second-level stability (pmf/pdf). The stopping criterion can be defined as the ratio of prediction accuracy (e.g., MSE or C-index) between the teacher model and the KDDT on the observed data, evaluated through cross-validation.
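Continuing the Monte Carlo sketch from Section 2.2.3, the stable split at a node can then be chosen by taking the mode at each level. The helper below is our own illustration of this rule, not the kddt implementation.

```r
# Sketch: choosing a stable split from the N_i repeated measurements collected
# in the Section 2.2.3 sketch (the `splits` data frame).
choose_stable_split <- function(splits) {
  # First level: the split variable with the highest pmf
  var_star <- names(which.max(table(splits$var)))
  vals <- splits$value[splits$var == var_star]
  # Second level: the mode of the estimated pdf of split values given var_star
  d <- density(vals)
  list(variable = var_star, value = d$x[which.max(d$y)])
}

# choose_stable_split(splits)
```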
3.2 Hybrid Induction Algorithm and Hybrid KDDT
We assume that the teacher model is well-defined, meaning it is well-fitted to the observed data without overfitting. Since KDDT aims to optimize its approximation to the teacher model, we can ignore any overfitting concerns for KDDT in relation to the observed data. Thus, we can focus solely on achieving a balance between the degree of approximation and computational efficiency during KDDT construction. The KDDT induction algorithm requires repeated sampling and fitting, performed ${N_{i}}$ times, to identify each split. Each time, the pseudo-data need to have a sufficient size, leading to a high computational load. Furthermore, to achieve a high-quality approximation, the tree needs to grow to a large size. Consequently, growing a large tree solely with the KDDT induction algorithm is often computationally infeasible.
Typically, in most real-world applications, only a small set of splits is needed for interpretation purposes. We refer to these splits as interpretable nodes (splits), while all other nodes, i.e., the terminal nodes of the interpretable part, are called predictive nodes. Since the ODT induction algorithm is much more efficient at constructing large trees than the KDDT induction algorithm, it is reasonable to combine the two into a hybrid induction algorithm. Specifically, we apply the KDDT induction algorithm to the interpretable nodes, ensuring their two-level stability, which is crucial for interpretation. Then, we employ the ODT induction algorithm to construct large sub-trees at the predictive nodes, maintaining a good approximation to the teacher model and increasing computational efficiency. We refer to this tree as a hybrid KDDT. For instance, in Figure 4, the interpretable nodes are $\{1,2,3,4,7,14\}$. We use the KDDT induction algorithm to identify stable splits for these nodes and then employ the ODT induction algorithm to grow the large sub-trees $\{{T_{5}},{T_{6}},{T_{8}},{T_{9}},{T_{15}},{T_{28}},{T_{29}}\}$ at the respective predictive nodes.
Figure 4
An example of the hybrid knowledge distillation decision tree: the nodes $\{1,2,3,4,7,14\}$ are interpretable nodes, while all other nodes are predictive nodes.
The small set of interpretable nodes enhances the simplicity of model interpretation. Meanwhile, the complexity necessary to ensure a prediction accuracy comparable to that of the teacher model is achieved through the construction of large sub-trees at the predictive nodes. This decoupling between interpretability and complexity offers the potential for hybrid KDDT to strike a balance between prediction accuracy and interpretability. For the sake of simplicity, we use the term “KDDT” to refer to hybrid KDDT in the remainder of this paper.
Figure 5
The effectiveness of KDDT in revealing and explaining complex data structures. (a) The true function $y=f({x_{1}},{x_{2}})$ and its partition with 9 splits. (b) 50 random samples from the true function. (c) The tree structure and splits for defining the partitions in (a), (d), and (f). (d) ODT is fitted based on 50 samples. (e) RF is fitted based on the 50 samples. (f) KDDT is constructed based on the RF.
The informativeness of the interpretation can vary across interpretable nodes. Measuring and reporting these differences is crucial for interpretation based on these nodes. To address this, we introduce the explanation index (XI) for interpretable nodes, and a similar index, the potential explanation index (PXI), for the predictive nodes.
Definition 6 (Explanation Index and Potential Explanation Index).
The explanation index of interpretable node (split) i, denoted as $X{I_{i}}$, and the potential explanation index of predictive node j, denoted as $PX{I_{j}}$, are defined as follows:

(3.1)
\[ X{I_{i}}=\frac{\frac{{n_{i}}}{n}\ast {\Delta _{{S_{i}}}}}{{\Delta _{KDDT}}}\ast 100\% ,\hspace{8.53581pt}PX{I_{j}}=\frac{\frac{{n_{j}}}{n}\ast {\Delta _{{T_{j}}}}}{{\Delta _{KDDT}}}\ast 100\% ,\]

where ${n_{i}}$, ${n_{j}}$, and n denote the number of observations in node i, node j, and the entire dataset, respectively. ${\Delta _{{S_{i}}}}$ and ${\Delta _{{T_{j}}}}$ represent the change in the measure defined by the split criterion (e.g., impurity or MSE reduction) after fitting split i or the subtree at node j, respectively, and ${\Delta _{KDDT}}={\textstyle\sum _{i}}\frac{{n_{i}}}{n}\ast {\Delta _{{S_{i}}}}+{\textstyle\sum _{j}}\frac{{n_{j}}}{n}\ast {\Delta _{{T_{j}}}}$.

Based on Definition 6, it is straightforward to verify that ${\textstyle\sum _{i}}X{I_{i}}+{\textstyle\sum _{j}}PX{I_{j}}=100\% $. $X{I_{i}}$ and $PX{I_{j}}$ can be regarded as the information contained in interpretable node i and predictive node j, respectively. Furthermore, we can extend the concept of XI to apply to a path in a KDDT as follows.
Definition 7 (Path Explanation Index).
\[ X{I_{i,j}}={\sum \limits_{k\in {S_{ij}}}}X{I_{k}},\]

where node i is an interpretable node, node j is a descendant of node i, and ${S_{ij}}$ is the set of node IDs comprising the nodes on the path from node i to the parent of node j.
With the above indices, we can identify the desired hybrid KDDT with an appropriate number of interpretable nodes. For instance, if we want more than 70% of the information in the data to be explained by the interpretable nodes, the stopping criterion is ${\textstyle\sum _{i}}X{I_{i}}\gt 70\% $, i.e., ${\textstyle\sum _{j}}PX{I_{j}}\lt 30\% $. Examples demonstrating their applications can be found in Section 5. The process of constructing a desired hybrid KDDT with an appropriate number of interpretable nodes is illustrated by the example shown in Figure B.14 in Appendix B. The potential explanation index of a predictive node can also be used to determine the size of the pseudo-data in its sampling region, which could be proportional to its PXI, because a higher PXI indicates more unexplained information and thus requires a larger pseudo-data size.
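As a small worked example of Definition 6, the following R snippet computes XI and PXI from hypothetical node fractions and criterion reductions; by construction, the indices sum to 100%.

```r
# Illustrative computation of XI and PXI from Definition 6, using hypothetical
# per-node sample fractions n_i/n and criterion reductions Delta.
xi_pxi <- function(frac_interp, delta_interp, frac_pred, delta_pred) {
  denom <- sum(frac_interp * delta_interp) + sum(frac_pred * delta_pred)  # Delta_KDDT
  list(XI  = 100 * frac_interp * delta_interp / denom,
       PXI = 100 * frac_pred   * delta_pred   / denom)
}

# Hypothetical example with two interpretable and two predictive nodes
res <- xi_pxi(frac_interp = c(1.0, 0.6), delta_interp = c(30, 12),
              frac_pred   = c(0.4, 0.6), delta_pred   = c(8, 5))
sum(res$XI) + sum(res$PXI)   # equals 100 by construction
```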
4 Simulation Study
The simulation study has three primary objectives: (1) to demonstrate the effectiveness of KDDT in revealing intricate structures of the data, (2) to validate the interpretability of KDDT, and (3) to illustrate the stability of interpretable splits (nodes).
Figure 6
The comparison of interpretations and simulation results. (a) The true partition is obtained from the ODT fitted on the entire dataset of the true function. (b-1) The ODT fitted on 50 samples. (b-2) The result of $|{\hat{Y}_{true}}-{\hat{Y}_{ODT}}|$. (c-1) The KDDT built from the RF. (c-2) The result of $|{\hat{Y}_{true}}-{\hat{Y}_{KDDT}}|$. (d) MSE comparison of ODT and KDDT over 100 simulation replications.
To facilitate a clear and intuitive discussion, we introduce a two-dimensional function $y=f({x_{1}},{x_{2}})$, consisting of 2601 generated data points, as illustrated in panel (a) of Figure 5. This function exhibits high non-linearity and intricate interactions, making it well-suited for our purposes. Let us assume that $y=f({x_{1}},{x_{2}})$ is unknown. We can gain insights about it by analyzing the sampled observations. In panel (b), we have 50 observations randomly sampled from the true function. We want to compare KDDT with other interpretable models. The well-known ones include linear regression and decision tree (ODT). As linear regression is not suitable for this data, we opt for ODT. The ODT fitted from 50 observations is presented in panel (d). In comparison to the true function, the ODT estimation is coarse and unable to capture the interaction structure within the area marked by the green rectangle. In contrast, the random forest model provides a refined and precise estimation, as shown in panel (e). The KDDT presented in panel (f), as a close approximation of its teacher, maintains a high-quality (resolution) estimation. This highlights the ability of the KDDT to reveal intricate structures in the data.
For the function $y=f({x_{1}},{x_{2}})$, effective interpretation is visually demonstrated through a suitable partition of the response values y based on the covariates ${x_{1}}$ and ${x_{2}}$, as shown in panel (a) of Figure 5. This partition comprises nine splits generated by the ODT in panel (c), which is fitted using the entire dataset of 2601 data points. We refer to this partition as the true partition, representing the optimal interpretation. Although the random forest model provides an accurate estimation of the true function, it cannot generate a partition for interpretation. The ODT (fitted with 50 samples) is interpretable, but its interpretation (partition) is not accurate. In contrast, the KDDT's interpretation closely approximates the optimal one (true partition), which is better than the ODT's. This claim relies on visual inspection, which is a qualitative approach. Figure 6 presents a quantitative method for comparing the quality of interpretation between ODT and KDDT. Panel (a) displays the true partition (optimal interpretation). The partitions of ODT and KDDT are depicted in panels (b-1) and (c-1), respectively. Panels (b-2) and (c-2) illustrate the absolute errors of ODT and KDDT compared to the truth. Visual inspection still leads to the same conclusion that KDDT's interpretation (partition) is superior to ODT's. More importantly, we can quantify this difference using MSE. In this example, KDDT's MSE is 6.95, significantly smaller than ODT's MSE of 14.56. Furthermore, we repeat this comparison 100 times. The result in panel (d) demonstrates that, in general, KDDT outperforms ODT in terms of interpretation quality measured by MSE. Note that the medians of MSE are 20.32 and 12.14 for ODT and KDDT, respectively. The corresponding means of MSE are 22.39 and 12.17. KDDT thus yields a 40.3% reduction in the median and a 45.6% reduction in the mean compared to ODT. The maximum MSE of KDDT is 20.93, corresponding to the 53rd percentile of ODT's MSE distribution. The maximum MSE of ODT is 54.06, which is more than 2.5 times the KDDT's.
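For readers who wish to reproduce the spirit of this comparison, the R sketch below uses a hypothetical nonlinear surface in place of the (unspecified) true function $f({x_{1}},{x_{2}})$, so the resulting MSE values will differ from those reported above.

```r
# Compressed sketch of the quantitative comparison in Figure 6, with a
# hypothetical true surface: fit an ODT on 50 samples, fit a random forest
# teacher, distill a KDDT, and compare both trees to the truth on a dense grid.
library(rpart); library(randomForest)

set.seed(3)
f <- function(x1, x2) 10 * sin(pi * x1) * cos(pi * x2) + 5 * (x1 > 0.5) * x2  # hypothetical
grid <- expand.grid(x1 = seq(0, 1, length.out = 51), x2 = seq(0, 1, length.out = 51))
grid$y <- f(grid$x1, grid$x2)

obs <- grid[sample(nrow(grid), 50), ]                        # 50 observed samples
odt <- rpart(y ~ x1 + x2, data = obs)                        # interpretable baseline
rf  <- randomForest(y ~ x1 + x2, data = obs)                 # black-box teacher

pseudo <- data.frame(x1 = runif(20000), x2 = runif(20000))
pseudo$y <- predict(rf, newdata = pseudo)
kddt <- rpart(y ~ x1 + x2, data = pseudo)                    # distilled student

c(ODT  = mean((grid$y - predict(odt,  grid))^2),
  KDDT = mean((grid$y - predict(kddt, grid))^2))
```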
Examining the KDDT in panel (c) of Figure 5 in more detail, we find that it contains nine interpretable nodes (splits). The first-level and second-level stability can be found in Figure 7. Except for split 12, which still maintains an impressive first-level stability of 97%, all other splits exhibit a first-level stability of 100%. Regarding second-level stability, each density function is tightly concentrated within a narrow interval and displays a sharp peak. Consequently, we can confidently assert that the interpretable splits within the KDDT are stable.
5 Applications
When should we use KDDT? Two fundamental conditions should be met.
• Demands for understanding or explanation: We need to understand or explain the data, either to gain personal insight or to communicate findings to others.
• Good prediction accuracy: The black-box ML model, which KDDT aims to approximate, should outperform simple interpretable models, such as linear regression or ODT, in predicting the data. This suggests that the black-box model may have a better understanding of the data and the potential to offer a more accurate interpretation than the simple models.
Considering these conditions, we discuss two real applications of KDDT in this section.
5.1 Example for Model Interpretation
In the application of model interpretation, we use the Boston Housing dataset, which comprises 506 observations of 14 variables. The description of the variables can be found in Table B.2 in Appendix B. Our goal is to understand the effects of covariates on the price of houses in Boston (in 1970). To check the second condition, we select the linear regression model (LM) and ODT as simple interpretable models, and the random forest (RF) and SVM as two candidate black-box models. A five-fold cross-validation was conducted to compare their prediction accuracy. The MSEs (mean squared errors) on testing data are LM: 23.2, ODT: 24.9, RF: 10.9, SVM: 13.4, and KDDT(RF): 14.9. More details of the comparison can be found in Figure B.13 in Appendix B. From these results, the ML models outperform the simple interpretable models, and RF performs better than SVM. Hence, we choose RF as the teacher model. The student model KDDT(RF) outperforms the simple interpretable models and performs similarly to SVM, indicating that KDDT(RF) may offer a more accurate interpretation than the simple interpretable models.
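A rough R sketch of this five-fold comparison is given below, using the Boston Housing data from the MASS package. Tuning, the fold assignment, and the construction of KDDT(RF) itself are omitted, so the resulting MSEs will only be in the same ballpark as the numbers above.

```r
# Five-fold cross-validated test MSE for LM, ODT, RF, and SVM on Boston Housing.
library(MASS); library(rpart); library(randomForest); library(e1071)

set.seed(2024)
data(Boston)
folds <- sample(rep(1:5, length.out = nrow(Boston)))
mse <- function(y, yhat) mean((y - yhat)^2)
res <- matrix(NA, 5, 4, dimnames = list(NULL, c("LM", "ODT", "RF", "SVM")))

for (k in 1:5) {
  train <- Boston[folds != k, ]; test <- Boston[folds == k, ]
  res[k, "LM"]  <- mse(test$medv, predict(lm(medv ~ ., train), test))
  res[k, "ODT"] <- mse(test$medv, predict(rpart(medv ~ ., train), test))
  res[k, "RF"]  <- mse(test$medv, predict(randomForest(medv ~ ., train), test))
  res[k, "SVM"] <- mse(test$medv, predict(svm(medv ~ ., train), test))
}
colMeans(res)   # average testing MSE per model
```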
Figure 8
Model interpretation through KDDT. (a) The interpretation of RF using KDDT(RF). (b) The interpretation of SVM using KDDT(SVM). Note that the left node (yes) and the right node (no) indicate whether the split condition is met or not, respectively. Note: the process of constructing KDDT(RF) in panel (a) can be found in the Figure B.14 in Appendix B.
Panel (a) of Figure 8 illustrates the interpretation provided by KDDT(RF) for its teacher model RF. Since KDDT(RF) is essentially a decision tree, identifying the important variables is straightforward. The three most important variables are lstat, rm, and nox, related to social status, house size, and the natural environment, respectively. This is consistent with the corresponding results of the teacher model RF (see Figure B.15 in Appendix B), which is evidence that KDDT(RF) can provide an accurate interpretation of its teacher model. More detailed and specific interpretations can be obtained by examining the interpretable splits (nodes) shown in panel (a). For example, if a house has seven or more rooms and is situated in an affluent community where the percentage of the population with lower social status (lstat) is less than 4.71%, it is likely to have a high value, averaging $36,600. Additionally, for potential buyers, an intriguing insight emerges: they might acquire a larger house with seven or more rooms in a less affluent community with lstat $\ge 9.8\% $, priced around $25,500, which is cheaper than a smaller house that could cost around $26,300 in a community with lstat $\le 9.7\% $. These specific insights are exemplified by nodes 5 and 6 in the tree. The stability of the interpretable splits shown in Figure B.16 in Appendix B ensures the credibility of these interpretations.
Figure 9
The interpretation of the super learner through KDDT. (a) The framework of the super learner algorithm/model. The numbers 1, ..., K refer to different cross-validation folds; the gray folds denote testing data. (b) The estimated weights and the super learner. (c) The comparison of prediction accuracy between the super learner and its base models. (d) The variable importance of KDDT(SL). (e) The tree structure of KDDT(SL).
In panel (a) of Figure 8, the XI and PXI associated with the interpretable and predictive nodes provide the relative importance information for their interpretation. For instance, $X{I_{1}}=29.8\% $ for the split rm<6.97 indicates that whether a house has seven or more rooms is crucial for assessing its value. Moreover, these indices can serve as a stopping criterion for identifying the set of interpretable nodes. For example, we can identify the KDDT(RF) interpretable nodes by the criterion that the sum of PXI is less than 30%. This criterion ensures that the predictive nodes do not contain substantial information. Furthermore, we can interpret any prediction of KDDT by using the path explanation index in Definition 7. For example, if a prediction is made through predictive node 9 (see panel (a)), its XI can be calculated as $X{I_{1,9}}=X{I_{1}}+X{I_{2}}+X{I_{4}}=51.7\% $. Then, with $PX{I_{9}}=3.9\% $, we obtain $(\frac{X{I_{1,9}}}{X{I_{1,9}}+PX{I_{9}}},\frac{PX{I_{9}}}{X{I_{1,9}}+PX{I_{9}}})=(93\% ,7\% )$. This indicates that the prediction can be interpreted to a degree of 93% using the chain of decision rules $\{rm\lt 6.97\xrightarrow{}lstat\ge 9.75\xrightarrow{}nox\ge 0.669\}$.
Last but not least, the percentage of observed data in each node also plays a pivotal role in understanding the interpretation of KDDT. This percentage serves as crucial evidence of how strongly the interpretation of a particular node is supported by the observed data. Given that KDDT is not a direct interpretation of the observed data but rather of the teacher model, support from the observed data is pivotal for the interpretation's practical significance. Even a node (split) with a high XI may lack practical relevance if the percentage of observed data associated with it (or its children) is exceedingly low. For instance, consider node 6 (split 3). Although it has $X{I_{1,6}}=X{I_{1}}+X{I_{3}}=48.7\% $ ($X{I_{3}}=18.9\% $), its left child comprises a mere 1.2% of the observed data. This suggests that the interpretation of this node (split) might not carry much practical importance; in other words, the chance of purchasing a larger house at a lower price is not zero, but it is very low in practice. Consequently, it is imperative to take into account both the XI and the percentage of observed data when interpreting KDDT. As an example, we are confident in the interpretation of predictions made through node 9, because this node not only has a high path XI of $X{I_{1,9}}=51.7\% $, interpretable to a degree of 93%, but also enjoys strong practical support from a large share (38.1%) of the observed data.
As demonstrated in panel (b) of Figure 8, KDDT can also provide an interpretation for SVM, which differs from the one for RF. In KDDT(SVM), the top three important variables are lstat, crim, and rm, related to social status, security, and house size, respectively. This indicates that, beyond social status and house size, the SVM's explanation emphasizes security, in contrast to RF's emphasis on the natural environment. Regarding the interpretable splits, the sum of their XIs in KDDT(SVM) is 26.3%, which is smaller than the 74.5% in KDDT(RF). This suggests that the interpretable node set of KDDT(SVM) has less interpretability than its counterpart in KDDT(RF). The comparison shown at the bottom of Figure 8 provides an intuitive illustration supporting this assertion, demonstrating that more variation in the data is explained by KDDT(RF) than by KDDT(SVM). Another issue with KDDT(SVM) is that splits 3 and 7 have child nodes 6 and 14, respectively, which do not include any observed data. To address this, we can omit these two branches (red dashed lines) and focus solely on node 15. The path explanation index from node 1 to node 15 can then be calculated as $X{I_{1,15}}=X{I_{1}}+X{I_{3}}+X{I_{7}}=18.8\% $. In sum, through KDDT, SVM offers a different interpretation than RF, but the interpretable splits in KDDT(SVM) do not perform as effectively as their counterparts in KDDT(RF).
KDDT can also be valuable in interpreting the model that is ensembled from other models. One typical example is the Super Learner introduced by [13]. As depicted in panel (a) of Figure 9, the Super Learner employs cross-validation to estimate the performance of multiple base models. Subsequently, it constructs an optimal weighted average of these models based on their testing performance. This approach has been proven to yield predictions that are asymptotically as good as or even better than any single model within the ensemble. In this example, we introduced eight base models and estimated their weights in the Super Learner, as shown in panel (b). Evaluated through a 10-fold cross-validation, the result presented in panel (c) demonstrates that the Super Learner outperforms all its base models in terms of prediction accuracy, which satisfies the second condition for applying KDDT.
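As a hedged sketch, the Super Learner teacher can be fitted with the SuperLearner R package and then distilled as in Section 2.1. The library of base learners below is illustrative and does not correspond to the eight base models in panel (b).

```r
# Fit a Super Learner teacher on the Boston data (illustrative base learners),
# then distill it through pseudo-data as in Section 2.1.
library(SuperLearner); library(MASS)

data(Boston)
sl <- SuperLearner(Y = Boston$medv, X = Boston[, -14], family = gaussian(),
                   SL.library = c("SL.glm", "SL.randomForest", "SL.rpart"))
sl$coef                                    # estimated weights of the base models

# Pseudo-responses from the Super Learner for a uniformly sampled X' (sketch):
# Y_prime <- predict(sl, newdata = X_prime)$pred
```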
Compared to the base models, the ensemble nature of the Super Learner renders it an even more opaque black-box model, which makes interpretation more challenging. KDDT can provide a solution. Panel (d) of Figure 9 presents the variable importance of KDDT(SL), which remarkably resembles that of the RF model shown in panel (a) of Figure 8. In panel (e) of Figure 9, the interpretable splits (nodes) were selected based on the criterion that the sum of PXI is less than 30%. The sum of XIs is 75.2%, indicating that the interpretable nodes of KDDT(SL) offer substantial interpretability. An interesting observation emerges when comparing KDDT(RF) and KDDT(SL): the predictions and interpretations of nodes 4 and 5 in KDDT(RF) closely resemble those of nodes 4 and 6 in KDDT(SL). In panel (a) of Figure 8 and panel (e) of Figure 9, the corresponding paths are highlighted in green and purple, respectively. Notably, all of these paths exhibit both high path XI and substantial percentages of observed data. This suggests a strong similarity in interpretation between RF and the Super Learner.
We have three KDDT interpretations, associated with RF, SVM, and the Super Learner. It is important to be aware that all of these interpretations are reasonable and valid; all roads lead to Rome. Which one to choose depends on the application requirements. For example, for a real estate consultant whose client is interested in the natural environment of the house, KDDT(RF)'s explanation would be a good choice. If the client's main concern is the safety of the neighborhood, KDDT(SVM)'s interpretation may be a better choice. Furthermore, if significant splits or paths consistently appear in different KDDT interpretations, this indicates their critical roles in the data. These interpretations have the potential to provide valuable insights or knowledge about the data or application. For example, as discussed in the comparison of KDDT(RF) and KDDT(SL), we can derive the valuable insight that a 10% proportion of lower-status population (lstat) and seven rooms (rm) are two critical thresholds shaping people's evaluations of house prices in Boston.
5.2 Example for Subgroup Discovery
Figure 10
Subgroup discovery and optimal cutoff identification. (a) Variable importance provides information for selecting split variables. (b) Interpretable splits for identifying optimal cutoff and subgroups. (c) The two-level stability of split 2. (d) Log10(p-value) of log-rank tests for validating the optimality of cutoff value. (e) The Kaplan-Meier plot for the identified subgroups. (f) Contingency table depicts the association between age and chf.
With the ability to uncover patterns in complex data and explore non-linear relationships, ML models have gained popularity in data-driven precision medicine, fueled by the rapid expansion in the availability of a wide variety of patient data. In precision medicine, identifying heterogeneity plays a central role, where subgroups of patients are defined based on baseline values of demographic, clinical, genomic, and other covariates, known as biomarkers. Understanding the effects of biomarkers in data analysis models is crucial for subgroup discovery. KDDT can bridge the gap between understanding the role of biomarkers and the lack of interpretability in black-box ML data analysis models. Particularly, as a tree-based approach, KDDT can incorporate information on higher-order interaction effects and be applied to define subgroups based on multiple biomarkers. Moreover, cutoff values do not need to be pre-specified for continuous/ordinal biomarkers. They are automatically estimated from the process of constructing KDDT.
In this example, we select the time-to-event dataset WHAS (Worcester Heart Attack Study), which aims to describe factors associated with trends in incidence and survival over time after admission for acute myocardial infarction. This dataset is available in the R package "mlr3proba" and includes 481 observations and 14 variables. Four variables, id (Patient ID), year (Cohort year), yrgrp (Grouped cohort year), and dstat (Discharge status from the hospital: 1 = Dead, 0 = Alive), were excluded as they were not pertinent to the goal of the study. The description of the remaining variables can be found in Table B.3 in Appendix B. We choose the Cox proportional hazards (CoxPH) model as the interpretable model and the Random Survival Forest (RSF) as the black-box teacher model. As in Section 5.1, the comparison of prediction accuracy was conducted with five-fold cross-validation. Instead of MSE, the C-index serves as the criterion, with a higher C-index indicating higher accuracy. The results on testing data are CoxPH: 0.766, RSF: 0.797, and KDDT(RSF): 0.797. More details of the comparison can be found in Figure B.17 in Appendix B. This result demonstrates the superiority of RSF in prediction and suggests that it is worthwhile to leverage KDDT(RSF) for subgroup discovery.
The structure of KDDT(RSF) is depicted in panel (b) of Figure 10. For the first split, sho=0, $X{I_{1}}=84.1\% $ suggests great practical significance for the identified subgroups. Indeed, it is widely recognized that cardiogenic shock is positively associated with an increased risk of death, so this is not a surprising discovery. The researcher's interest may lie more in the subgroups identified among the patients who did not experience cardiogenic shock. The second split, age$\lt 67.02$, reveals two subgroups. Although $X{I_{2}}=3.9\% $ is not high compared to $X{I_{1}}=84.1\% $, it is relatively high among the remaining interpretable nodes: $\frac{X{I_{2}}}{X{I_{2}}+X{I_{4}}+X{I_{5}}+X{I_{10}}}=\frac{3.9\% }{3.9\% +2.9\% +1.5\% +1.1\% }=41.5\% $. Moreover, the observed data in its child nodes are substantial and well-balanced, providing strong support from the observed data and indicating the importance of the subgroups identified by this split. The split stability of the optimal cutoff value is displayed in panel (c). The greedy search algorithm ensures its optimality, which is substantiated by the p-values from log-rank tests across different values in panel (d). Consequently, there is no need to explore multiple cutoff values, thus alleviating multiplicity issues. Panel (e) displays the Kaplan-Meier plot for the two subgroups, illustrating the varying risks associated with each subgroup. Finer subgroups and covariate interactions can be explored by considering deeper splits. Since node 11 contains only 4 observations (0.83% of the data), we can remove it and its parent node 5 (see the red dashed line). This can be achieved by redistributing these 4 observations to nodes 20 and 21 based on their chf values. As a result, four subgroups, along with their patient counts, can be identified in the table in panel (f). Analyzing this table reveals a clear interaction (dependency) between the risk of left heart failure (chf=1) and the age of patients. This relationship can be statistically confirmed through a ${\chi ^{2}}$ test, which yields a p-value of 2.655e-10.
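A hedged R sketch of the subgroup checks in panels (d) and (e) is given below, assuming a data frame `whas` with follow-up time `lenfol`, event status `fstat`, and the covariates `sho` and `age`; these column names are assumptions, since Table B.3 is not reproduced here.

```r
# Kaplan-Meier curves and a log-rank test for the age subgroup defined by the
# KDDT cutoff, restricted to patients without cardiogenic shock (sho == 0).
library(survival)

no_shock <- subset(whas, sho == 0)
no_shock$group <- ifelse(no_shock$age < 67.02, "age < 67.02", "age >= 67.02")

fit <- survfit(Surv(lenfol, fstat) ~ group, data = no_shock)
plot(fit, col = 1:2, xlab = "Follow-up time", ylab = "Survival probability")

# Log-rank test at the KDDT cutoff (cf. panel (d), which scans many cutoffs)
survdiff(Surv(lenfol, fstat) ~ group, data = no_shock)
```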
6 Discussion
KDDT offers a general method for interpreting black-box ML models, enabling the exploration of intricate data structures captured by these models for more precise and detailed interpretations. Essential attributes of good interpretable models include simplicity, stability, and predictivity. Stability is the central focus of this study; the primary challenge lies in constructing a stable KDDT while handling the randomness of the pseudo-data (knowledge) sampled in the knowledge distillation process. We propose a comprehensive theory for split stability and develop efficient algorithms for constructing stable KDDTs. To ensure simplicity, KDDT decouples the tasks of interpretation and prediction, maintaining a concise set of interpretable nodes for the purpose of interpretation. Regarding predictivity, KDDT, as a close approximation of the black-box ML model, retains strong predictive performance comparable to the original black-box model. In conclusion, KDDT is an excellent interpretable model with great potential for practical applications.
In our theory and algorithms, we employed uniform random sampling to generate pseudo-data for constructing KDDT. This approach performed well in the simulation and real data studies. Specifically, when the sample size was less than 60,000, the time required to fit an interpretable node was under one minute. In general, when the number of continuous covariates (${n_{con}}$) is relatively small, typically fewer than 20, a sample size of 60,000 is sufficient. However, when dealing with a larger ${n_{con}}$, a larger sample size is necessary. In such cases, uniform random sampling becomes less efficient, and non-uniform sampling strategies may be more attractive. Two promising strategies are MCMC sampling, which leverages information from the teacher model to enhance sampling efficiency, and PCA-based sampling, which uses dimension reduction for the same purpose. Both are interesting directions for future study.