机器学习算法-回归树(CART)

发布于 2020-10-02  1563 次阅读


回归树(regression tree)

前提知识

决策树

ID3算法中,选择的是信息增益来进行特征选择,信息增益大的特征优先选择。
而在C4.5中,选择的是信息增益比来选择特征,以减少信息增益容易选择特征值多的特征的缺点。
在CART中,是使用基尼系数来进行特征选择
$$
G = 1 - \sum_{i=1}^k p_i^2
$$
从公式中可以看出来,基尼指数的意义是从数据集D中随机抽取两个样本类别标识不一致的概率。
相比于信息增益,信息增益比等作为特征选择方法,基尼指数省略了对数计算,运算量比较小,也比较容易理解,所以CART树选择使用基尼系数用来做特征选择。

CART

CART 全称是 Classification And Regression Tree(分类和回归树),所以正常是可以使用CART进行回归计算,但是在不同类型中CART有一些改变
CART 在分类问题和回归问题中的相同和差异:

  1. 相同:
    • 在分类问题和回归问题中,CART 都是一棵二叉树,除叶子节点外的所有节点都有且仅有两个子节点
    • 所有落在同一片叶子中的输入都有同样的输出
  2. 差异:
    • 在分类问题中,CART 使用基尼指数作为选择特征和划分的依据;在回归问题中,CART 使用总方差进行划分
    • 在分类问题中,CART 的每一片叶子都代表的是一个类型;在回归问题中,CART 的每一片叶子表示的是一个预测值,取值是连续的。

算法简述

之前的线性回归中,求解的问题如下
$$
\min \frac{1}{n} \sum_{i = 1}^{n} (f(x_i) - y_i)^2
$$
用 CART 进行回归,目标自然也是一样的,让均方误差最小

假设一棵构建好的 CART 回归树有 M片叶子,这意味着 CART 将输入空间 x 划分成了 M 个单元 R1,R2,...,RM同时意味着 CART 至多会有 M 个不同的预测值。CART 最小化均方误差公式如下:
$$
\min \frac{1}{n} \sum_{m = 1}^{M}\sum_{x_i \in R_m} (c_m - y_i)^2
$$
其中$c_m$是第M个叶子的预测值
因为是二叉树,所以我们针对第$j$个变量的值$s$进行划分
$$
\begin{split}
R_1{j, s} = {x|x^{(j)} \le s} \
R_2{j, s} = {x|x^{(j)} > s}
\end{split}
$$
CART 选择切分变量 $j$ 和 切分点 $s$ 的公式如下:
$$
\min_{j, s} \left[\min_{c_1} \sum_{x_i \in R_1{j, s}} (y_i - c_1)^2 + \min_{c_2} \sum_{x_i \in R_2{j, s}} (y_i - c_2)^2 \right]
$$
采取遍历的方式,我们可以将$j$和$s$ 找出来

ID3和C4.5回归树

CART有两种评价标准:Variance和Gini系数。而ID3和C4.5的评价基础都是信息熵。信息熵和Gini系数是针对分类任务的指标,而Variance是针对连续值的指标因此可以用来做回归。所以不可以用ID3和C4.5回归

剪枝

通过降低树的复杂度来避免过拟合的过程称为剪枝。对树的剪枝分为预剪枝和后剪枝。一般地,为了寻求最佳模型可以同时使用这两种剪枝技术。

预剪枝:在选择创建树的过程中,我们限制树的迭代次数(即限制树的深度),以及限制叶节点的样本数不要过小,设定这种提前终止条件的方法实际上就是所谓的预剪枝。

后剪枝:使用后剪枝方法需要将数据集分为测试集和训练集。用测试集来判断将这些叶节点合并是否能降低测试误差,如果是的话将合并。

模型树

模型树与回归树的差别在于:回归树的叶节点是节点数据标签值的平均值,而模型树的节点数据是一个线性模型(可用最简单的最小二乘法来构建线性模型),返回线性模型的系数$w$,我们只要将测试数据$x$乘以$w$便可以得到预测值$w$,即$y=w^T*x$,所以该模型是由多个线性片段组成的。

当然回归树和模型树都是用来回归的,对于离散的点可以用回归而对于明显线性的点,可以用模型树,通过调用NumPy库中的corrcoef()来判断哪个更优势

代码

import numpy as np
import matplotlib.pyplot as plt
import copy


def loaddata(filename):
    data = []
    with open(filename) as fr:
        while True:
            line = fr.readline()
            if line:
                linearr = line.strip().split()
                mask_x = [float(x) for x in linearr]
                data.append(mask_x)
            else:
                break
                pass
    return data


def splitdata(dataset, feature, value):
    mat0 = dataset[np.nonzero(dataset[:, feature] > value)[0], :]
    mat1 = dataset[np.nonzero(dataset[:, feature] <= value)[0], :]
    return mat0, mat1


def regleaf(dataset):
    return np.mean(dataset[:, -1])


def regerr(dataset):
    return np.var(dataset[:, -1]) * np.shape(dataset)[0]


def creattree(dataset, leaftype=regleaf, errtype=regerr, ops=(1, 4)):
    feat, val = choosebestsplit(dataset, leaftype, errtype, ops)
    if feat == None: return val
    rettree = {}
    rettree['spind'] = feat
    rettree['spval'] = val
    lset, rset = splitdata(dataset, feat, val)
    rettree['left'] = creattree(lset, leaftype, errtype, ops)
    rettree['right'] = creattree(rset, leaftype, errtype, ops)
    return rettree


def choosebestsplit(dataset, leaftype=regleaf, errtype=regerr, ops=(1, 4)):
    tols = ops[0]
    toln = ops[1]
    if len(set(dataset[:, -1].T.tolist()[0])) == 1:
        return None, leaftype(dataset)
    m, n = np.shape(dataset)
    s = errtype(dataset)
    bests = float('inf')
    bestindex = 0
    bestvalue = 0
    for featindex in range(n - 1):
        for splitval in set((dataset[:, featindex].T.A.tolist())[0]):
            mat0, mat1 = splitdata(dataset, featindex, splitval)
            if np.shape(mat0)[0] < toln or np.shape(mat1)[0] < toln:
                continue
            news = errtype(mat0) + errtype(mat1)
            if news < bests:
                bestindex = featindex
                bestvalue = splitval
                bests = news
    if (s - bests) < tols:
        return None, leaftype(dataset)
    mat0, mat1 = splitdata(dataset, bestindex, bestvalue)
    if np.shape(mat0)[0] < toln or np.shape(mat1)[0] < toln:
        return None, leaftype(dataset)
    return bestindex, bestvalue


def istree(obj):
    return isinstance(obj, dict)


def getmean(tree):
    if istree(tree['right']):
        tree['right'] = getmean(tree['right'])
    if istree(tree['left']):
        tree['left'] = getmean(tree['left'])
    return (tree['left'] + tree['right']) / 2


def prune(tree, testdata):
    if np.shape(testdata)[0] == 0:
        return getmean(tree)
    if istree(tree['right']) or istree(tree['left']):
        lset, rset = splitdata(testdata, tree['spind'], tree['spval'])
        if istree(tree['left']):
            tree['left'] = prune(tree['left'], lset)
        if istree(tree['right']):
            tree['right'] = prune(tree['right'], rset)
    if not istree(tree['right']) and not istree(tree['left']):
        lset, rset = splitdata(testdata, tree['spind'], tree['spval'])
        errorfront = np.sum(np.power(lset[:, -1] - tree['left'], 2)) + np.sum(np.power(rset[:, -1] - tree['right'], 2))
        treemean = getmean(tree)
        errorback = np.sum(np.power(testdata[:-1] - treemean, 2))
        if errorback <= errorfront:
            print('merging')
            return treemean
        else:
            return tree
    else:
        return tree


def plotDataSet(filename, data):
    dataMat = loaddata(filename)
    n = len(dataMat)
    xcord = []
    ycord = []
    for i in range(n):
        xcord.append(dataMat[i][0])
        ycord.append(dataMat[i][-1])
    n = np.shape(data)[0]
    xtest = []
    ytest = []
    for i in range(n):
        xtest.append(data[i, 0])
        ytest.append(data[i, -1])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord, ycord, s=20, c='blue', alpha=.5)
    ax.scatter(xtest, ytest, s=20, c='red', alpha=.5)
    plt.title('DataSet')
    plt.xlabel('X')
    plt.show()


def getvalue(tree, data):
    if isinstance(tree, dict):
        if data[int(tree['spind'])] > float(tree['spval']):
            return getvalue(tree['left'], data)
        else:
            return getvalue(tree['right'], data)
    else:
        return tree


def getlinearvalue(tree, data):
    if isinstance(tree, dict):
        if data[int(tree['spind'])] > float(tree['spval']):
            return getlinearvalue(tree['left'], data)
        else:
            return getlinearvalue(tree['right'], data)
    else:
        inputdata = np.mat(data)
        datamat = np.mat(np.ones((len(data), 1)))
        datamat[1:, :] = inputdata[:, 0:np.shape(inputdata)[1] - 1]
        return tree.T * datamat


def linearsolve(dataset):
    m, n = np.shape(dataset)
    x = np.mat(np.ones((m, n)))
    y = np.mat((np.ones((m, 1))))
    x[:, 1:n] = dataset[:, 0:n - 1]
    y = dataset[:, -1]
    xtx = x.T * x
    if np.linalg.det(xtx) == 0.0:
        print('false')
    ws = xtx.I * (x.T * y)
    return ws, x, y


def modelleaf(dataset):
    ws, x, y = linearsolve(dataset)
    return ws


def modleerr(dataset):
    ws, x, y = linearsolve(dataset)
    yhat = x * ws
    return np.sum(np.power(y - yhat, 2))


def regtreeeval(model, indat):
    return float(model)


if __name__ == "__main__":
    traindata = loaddata('exp2.txt')
    trainmat = np.mat(traindata)
    # plotDataSet('exp2.txt', trainmat)
    tree = creattree(trainmat, modelleaf, modleerr)
    print(getlinearvalue(tree, traindata[1]))
    # traindata = loaddata('ex2.txt')
    # testdata = loaddata('ex2test.txt')
    # mymat = np.mat(traindata)
    # testmat = np.mat(testdata)
    # tree = creattree(mymat)
    # a = copy.deepcopy(tree)
    # prunetree = prune(tree, testmat)
    # print(prunetree == a)
    # datamat = testmat.copy()
    # for i in range(np.shape(datamat)[0]):
    #     datamat[i, -1] = getvalue(prunetree, testdata[i])
    # plotDataSet('ex2test.txt', datamat)

总结

书中的监督学习到这里就结束了,接下来就是没有label的的数据,第一篇应该是K-聚类算法,继续加油