Telco Customer Churn

Each row represents a customer and each column a feature; the raw data contains 7043 rows and 21 columns. "Churn" is the label.

  • customerID : customer ID

  • gender : customer gender

  • SeniorCitizen : whether the customer is a senior citizen (1, 0)

  • Partner : whether the customer has a partner (Yes, No)

  • Dependents : whether the customer has dependents (Yes, No)

  • tenure : number of months the customer has stayed

  • PhoneService : whether the customer has phone service (Yes, No)

  • MultipleLines : whether the customer has multiple lines (Yes, No, No phone service)

  • InternetService : the customer's internet service provider (DSL, Fiber optic, No)

  • OnlineSecurity : whether the customer has online security (Yes, No, No internet service)

  • OnlineBackup : whether the customer has online backup (Yes, No, No internet service)

  • DeviceProtection : whether the customer has device protection (Yes, No, No internet service)

  • TechSupport : whether the customer has tech support (Yes, No, No internet service)

  • StreamingTV : whether the customer has streaming TV (Yes, No, No internet service)

  • StreamingMovies : whether the customer has streaming movies (Yes, No, No internet service)

  • Contract : the customer's contract term (Month-to-month, One year, Two year)

  • PaperlessBilling : whether the customer has paperless billing (Yes, No)

  • PaymentMethod : the customer's payment method (Electronic check, Mailed check, Bank transfer (automatic), Credit card (automatic))

  • MonthlyCharges : the amount charged to the customer monthly

  • TotalCharges : the total amount charged to the customer

  • Churn : whether the customer churned (Yes or No)

1. Load libraries and read the data

1.1. Load libraries

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import scipy.stats as ss
%matplotlib inline
import itertools
import time
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split, RandomizedSearchCV
from sklearn.metrics import auc, precision_score, recall_score, confusion_matrix, roc_curve, precision_recall_curve, accuracy_score, roc_auc_score, classification_report
from datetime import datetime
import lightgbm as lgbm
from scipy.stats import randint as sp_randint
from scipy.stats import uniform as sp_uniform
import warnings
warnings.filterwarnings('ignore')  # ignore warning messages
from contextlib import contextmanager


@contextmanager
def timer(title):
    # Time a block of code and print the elapsed seconds on exit
    t0 = time.time()
    yield
    print("{} - done in {:.0f}s".format(title, time.time() - t0))

1.2. Read the data

data = pd.read_csv(r"./dataset/WA_Fn-UseC_-Telco-Customer-Churn.csv")

1.3. Head, describe, shape and info

display(data.head())
display(data.describe())
display(data.shape)
display(data.info())

customerID gender SeniorCitizen Partner Dependents tenure PhoneService MultipleLines InternetService OnlineSecurity ... DeviceProtection TechSupport StreamingTV StreamingMovies Contract PaperlessBilling PaymentMethod MonthlyCharges TotalCharges Churn
0 7590-VHVEG Female 0 Yes No 1 No No phone service DSL No ... No No No No Month-to-month Yes Electronic check 29.85 29.85 No
1 5575-GNVDE Male 0 No No 34 Yes No DSL Yes ... Yes No No No One year No Mailed check 56.95 1889.5 No
2 3668-QPYBK Male 0 No No 2 Yes No DSL Yes ... No No No No Month-to-month Yes Mailed check 53.85 108.15 Yes
3 7795-CFOCW Male 0 No No 45 No No phone service DSL Yes ... Yes Yes No No One year No Bank transfer (automatic) 42.30 1840.75 No
4 9237-HQITU Female 0 No No 2 Yes No Fiber optic No ... No No No No Month-to-month Yes Electronic check 70.70 151.65 Yes

5 rows × 21 columns

SeniorCitizen tenure MonthlyCharges
count 7043.000000 7043.000000 7043.000000
mean 0.162147 32.371149 64.761692
std 0.368612 24.559481 30.090047
min 0.000000 0.000000 18.250000
25% 0.000000 9.000000 35.500000
50% 0.000000 29.000000 70.350000
75% 0.000000 55.000000 89.850000
max 1.000000 72.000000 118.750000
(7043, 21)


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 21 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   customerID        7043 non-null   object 
 1   gender            7043 non-null   object 
 2   SeniorCitizen     7043 non-null   int64  
 3   Partner           7043 non-null   object 
 4   Dependents        7043 non-null   object 
 5   tenure            7043 non-null   int64  
 6   PhoneService      7043 non-null   object 
 7   MultipleLines     7043 non-null   object 
 8   InternetService   7043 non-null   object 
 9   OnlineSecurity    7043 non-null   object 
 10  OnlineBackup      7043 non-null   object 
 11  DeviceProtection  7043 non-null   object 
 12  TechSupport       7043 non-null   object 
 13  StreamingTV       7043 non-null   object 
 14  StreamingMovies   7043 non-null   object 
 15  Contract          7043 non-null   object 
 16  PaperlessBilling  7043 non-null   object 
 17  PaymentMethod     7043 non-null   object 
 18  MonthlyCharges    7043 non-null   float64
 19  TotalCharges      7043 non-null   object 
 20  Churn             7043 non-null   object 
dtypes: float64(1), int64(2), object(18)
memory usage: 1.1+ MB



None

1.4. Reassign target, encode variables and replace missing values

# Reassign target: map Yes/No to 1/0
data.Churn.replace(to_replace = dict(Yes = 1, No = 0), inplace = True)

# Encode the indicator columns as object type
col_name = ['SeniorCitizen', 'Churn']
data[col_name] = data[col_name].astype(object)

# Encode the charges as float: TotalCharges contains blank strings,
# which are replaced with 0 before the cast
data['TotalCharges'] = data['TotalCharges'].replace(" ", 0).astype('float64')
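As a quick sanity check, a hedged sketch run on a fresh copy of the file (in the standard version of this CSV, the blank TotalCharges entries belong to customers with tenure 0):

raw = pd.read_csv(r"./dataset/WA_Fn-UseC_-Telco-Customer-Churn.csv")
blank = raw['TotalCharges'] == " "
print(blank.sum())                            # 11 blank strings in the standard file
print((raw.loc[blank, 'tenure'] == 0).all())  # True: all of them have tenure 0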

2. Exploratory Data Analysis (EDA)

churn = data[(data['Churn'] != 0)]     # all rows for churned customers
no_churn = data[(data['Churn'] == 0)]  # all rows for retained customers

2.1. Target distribution (number and %)

ax = sns.catplot(y="Churn", kind="count", data=data, height=2.6, aspect=2.5, palette ={0 : 'lightblue', 1 : 'gold'}, orient='h')

[Figure: count plot of the Churn target]

ax = (data['Churn'].value_counts()*100.0 /len(data))\
    .plot.pie(autopct='%.1f%%', labels = ['No', 'Yes'], figsize =(5,5), colors=['lightblue', 'gold'], fontsize = 12)
ax.yaxis.set_major_formatter(mtick.PercentFormatter())
ax.set_ylabel('Churn', fontsize = 12)
ax.set_title('% of Churn', fontsize = 12)

[Figure: pie chart of churn percentage]

2.2. Numeric features : Seaborn

def plot_distribution_num(data_select):
    sns.set_style("ticks")
    s = sns.FacetGrid(data, hue = 'Churn', aspect = 2.5, palette ={0 : 'lightblue', 1 : 'gold'})
    s.map(sns.kdeplot, data_select, shade = False, alpha = 0.8)
    s.set(xlim=(0, data[data_select].max()))
    s.add_legend()
    s.set_axis_labels(data_select, 'proportion')
    s.fig.suptitle(data_select)
    plt.show()


plot_distribution_num('tenure')
plot_distribution_num('MonthlyCharges')
plot_distribution_num('TotalCharges')

[Figures: KDE distributions of tenure, MonthlyCharges and TotalCharges, split by churn]

2.3. Numeric features : Correlation

df_quant = data.select_dtypes(exclude=[object])
df_quant.head()
corr_quant = df_quant.corr()

fig, ax = plt.subplots(figsize=(15, 10))
ax = sns.heatmap(corr_quant, annot=True, cmap = 'viridis', linewidths = .1, linecolor = 'grey', fmt=".2f")
ax.invert_yaxis()
ax.set_title("Correlation")
plt.show()

[Figure: correlation heatmap of the numeric features]

2.4. Object features : Seaborn

def plot_distribution_cat(feature1, feature2, df):
    plt.figure(figsize=(18,5))
    plt.subplot(121)
    s = sns.countplot(x = feature1, hue='Churn', data = df,
                      palette = {0 : 'lightblue', 1 :'gold'}, alpha = 0.8,
                      linewidth = 0.4, edgecolor='grey')
    s.set_title(feature1)
    for p in s.patches:
        s.annotate('{:.0f}'.format(p.get_height()), (p.get_x()+0.15, p.get_height()+30))

    plt.subplot(122)
    s = sns.countplot(x = feature2, hue='Churn', data = df,
                      palette = {0 : 'lightblue', 1 :'gold'}, alpha = 0.8,
                      linewidth = 0.4, edgecolor='grey')
    s.set_title(feature2)
    for p in s.patches:
        s.annotate('{:.0f}'.format(p.get_height()), (p.get_x()+0.15, p.get_height()+30))
    plt.show()
plot_distribution_cat('SeniorCitizen', 'gender', data)
plot_distribution_cat('Partner', 'Dependents', data)
plot_distribution_cat('MultipleLines', 'InternetService', data)
plot_distribution_cat('OnlineSecurity', 'TechSupport', data)
plot_distribution_cat('DeviceProtection', 'StreamingTV', data)
plot_distribution_cat('StreamingMovies', 'PaperlessBilling', data)
plot_distribution_cat('PaymentMethod', 'Contract', data)

[Figures: paired count plots, split by churn, for the seven feature pairs above]

3. Feature engineering and selection

3.1. New features

# Engaged: 1 for customers on a one- or two-year contract, 0 for month-to-month
data.loc[:,'Engaged'] = 1
data.loc[(data['Contract']=='Month-to-month'),'Engaged'] = 0

# YandNotE: non-senior customers who are not engaged
data.loc[:,'YandNotE'] = 0
data.loc[(data['SeniorCitizen']==0) & (data['Engaged']==0),'YandNotE'] = 1

# ElectCheck: non-engaged customers paying by electronic check
data.loc[:,'ElectCheck'] = 0
data.loc[(data['PaymentMethod']=='Electronic check') & (data['Engaged']==0),'ElectCheck'] = 1

# fiberopt: customer has fiber-optic internet
data.loc[:,'fiberopt'] = 1
data.loc[(data['InternetService']!='Fiber optic'),'fiberopt'] = 0

# StreamNoInt: the streaming columns report "No internet service"
data.loc[:,'StreamNoInt'] = 1
data.loc[(data['StreamingTV']!='No internet service'),'StreamNoInt'] = 0

# NoProt: customer has none of online backup, device protection or tech support
data.loc[:,'NoProt'] = 1
data.loc[(data['OnlineBackup']!='No') | (data['DeviceProtection']!='No') | (data['TechSupport']!='No'),'NoProt'] = 0

# TotalServices: number of service columns equal to 'Yes'
# (note: InternetService takes DSL/Fiber optic/No, so it never contributes here)
data['TotalServices'] = (data[['PhoneService', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies']] == 'Yes').sum(axis=1)
# Bin tenure into 3 equal-width intervals and replace each value with its bin
data['tenure'] = pd.cut(data['tenure'], 3)
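Since tenure spans 0 to 72 months (see describe() above), pd.cut with 3 equal-width bins should yield intervals of roughly (-0.072, 24], (24, 48] and (48, 72] (pd.cut widens the lowest edge by 0.1% of the range); a quick check:

data['tenure'].value_counts()
# expected bins (approximate edges):
# (-0.072, 24.0]    ...
# (24.0, 48.0]      ...
# (48.0, 72.0]      ...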

3.2. Drop some features

data.columns
Index(['customerID', 'gender', 'SeniorCitizen', 'Partner', 'Dependents',
       'tenure', 'PhoneService', 'MultipleLines', 'InternetService',
       'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport',
       'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling',
       'PaymentMethod', 'MonthlyCharges', 'TotalCharges', 'Churn', 'Engaged',
       'YandNotE', 'ElectCheck', 'fiberopt', 'StreamNoInt', 'NoProt',
       'TotalServices'],
      dtype='object')
data = data.drop(columns = [
    'Contract',
    'DeviceProtection',
    'Partner'
])

3.3. Features encoding and scaling

# Customer ID column
Id_col = ['customerID']
# Label column
target_col = ["Churn"]
# Categorical columns: fewer than 10 distinct values
cat_cols = data.nunique()[data.nunique() < 10].keys().tolist()
cat_cols = [x for x in cat_cols if x not in target_col]
# Numeric columns
num_cols = [x for x in data.columns if x not in cat_cols + target_col + Id_col]
# Binary columns
bin_cols = data.nunique()[data.nunique() == 2].keys().tolist()
# Multi-valued columns
multi_cols = [i for i in cat_cols if i not in bin_cols]

# Label-encode the binary columns
le = LabelEncoder()
for i in bin_cols:
    data[i] = le.fit_transform(data[i])

# One-hot encode the multi-valued columns
data = pd.get_dummies(data = data, columns = multi_cols)

# Standardize the numeric columns
std = StandardScaler()
scaled = std.fit_transform(data[num_cols])
scaled = pd.DataFrame(scaled, columns=num_cols)

# Replace the raw numeric columns with their standardized versions
df_data_og = data.copy()
data = data.drop(columns = num_cols, axis = 1)
data = data.merge(scaled, left_index=True, right_index=True, how = "left")
data = data.drop(['customerID'], axis = 1)

3.4. Correlation Heatmap

def correlation_plot(data):
    data = data.corr()
    fig, ax = plt.subplots(figsize=(15, 10))
    ax = sns.heatmap(data, annot=False, cmap = 'viridis', linewidths = .1, linecolor = 'grey', fmt=".2f")
    ax.invert_yaxis()
    ax.set_title("Correlation")
    plt.show()
correlation_plot(data)

[Figure: correlation heatmap of all encoded features]

3.5. Remove collinear features

# Threshold above which one of a pair of correlated features is dropped
threshold = 0.9

# Absolute value of the correlation matrix
corr_matrix = data.corr().abs()
display(pd.concat([corr_matrix.head(), corr_matrix.tail()]))

# Keep the upper triangle of the correlation matrix: np.triu builds an
# all-ones matrix and zeroes the lower triangle (k=1 also drops the diagonal);
# casting 0/1 to bool gives a mask that selects only the upper-triangle cells.
# (A toy demo of this masking appears after the outputs below.)
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
display(pd.concat([upper.head(), upper.tail()]))

# Collect the features whose correlation exceeds the threshold and drop them
to_drop = [column for column in upper.columns if any(upper[column] > threshold)]

print('There are %d columns to remove :' % (len(to_drop)))

data = data.drop(columns = to_drop)

to_drop

gender SeniorCitizen Dependents PhoneService PaperlessBilling Churn Engaged YandNotE ElectCheck fiberopt ... TotalServices_0 TotalServices_1 TotalServices_2 TotalServices_3 TotalServices_4 TotalServices_5 TotalServices_6 TotalServices_7 MonthlyCharges TotalCharges
gender 1.000000 0.001874 0.010517 0.006488 0.011754 0.008612 0.003386 0.003776 0.006969 0.011286 ... 0.012379 0.006565 0.001029 0.001161 0.005514 0.002264 0.017296 0.013176 0.014569 0.000080
SeniorCitizen 0.001874 1.000000 0.211185 0.008576 0.156530 0.150889 0.138360 0.386482 0.186468 0.255338 ... 0.003738 0.122491 0.041459 0.046895 0.050380 0.029801 0.002747 0.014319 0.220173 0.103006
Dependents 0.010517 0.211185 1.000000 0.001762 0.111377 0.164221 0.231720 0.111982 0.169907 0.165818 ... 0.023303 0.059161 0.084852 0.047922 0.013134 0.029117 0.029168 0.048430 0.113890 0.062078
PhoneService 0.006488 0.008576 0.001762 1.000000 0.016505 0.011942 0.000742 0.000083 0.004519 0.289999 ... 0.327354 0.107222 0.065522 0.069257 0.009174 0.013544 0.047230 0.063979 0.247398 0.113214
PaperlessBilling 0.011754 0.156530 0.111377 0.016505 1.000000 0.191825 0.169096 0.070549 0.197873 0.326853 ... 0.006482 0.251039 0.068116 0.068790 0.082407 0.070216 0.058396 0.011691 0.352150 0.158574
TotalServices_5 0.002264 0.029801 0.029117 0.013544 0.070216 0.037420 0.164026 0.141655 0.075395 0.115043 ... 0.039097 0.250156 0.148032 0.151906 0.153700 1.000000 0.103519 0.071270 0.299248 0.315356
TotalServices_6 0.017296 0.002747 0.029168 0.047230 0.058396 0.089768 0.217169 0.176288 0.116575 0.064497 ... 0.030421 0.194641 0.115181 0.118195 0.119591 0.103519 1.000000 0.055454 0.289082 0.391785
TotalServices_7 0.013176 0.014319 0.048430 0.063979 0.011691 0.091806 0.202449 0.162530 0.109766 0.041263 ... 0.020944 0.134005 0.079299 0.081374 0.082335 0.071270 0.055454 1.000000 0.246353 0.378530
MonthlyCharges 0.014569 0.220173 0.113890 0.247398 0.352150 0.193356 0.060165 0.048075 0.202893 0.787066 ... 0.141994 0.724004 0.010335 0.116311 0.249383 0.299248 0.289082 0.246353 1.000000 0.651174
TotalCharges 0.000080 0.103006 0.062078 0.113214 0.158574 0.198324 0.444255 0.413707 0.212875 0.361655 ... 0.097310 0.492495 0.196423 0.059875 0.151065 0.315356 0.391785 0.378530 0.651174 1.000000

10 rows × 50 columns

gender SeniorCitizen Dependents PhoneService PaperlessBilling Churn Engaged YandNotE ElectCheck fiberopt ... TotalServices_0 TotalServices_1 TotalServices_2 TotalServices_3 TotalServices_4 TotalServices_5 TotalServices_6 TotalServices_7 MonthlyCharges TotalCharges
gender NaN 0.001874 0.010517 0.006488 0.011754 0.008612 0.003386 0.003776 0.006969 0.011286 ... 0.012379 0.006565 0.001029 0.001161 0.005514 0.002264 0.017296 0.013176 0.014569 0.000080
SeniorCitizen NaN NaN 0.211185 0.008576 0.156530 0.150889 0.138360 0.386482 0.186468 0.255338 ... 0.003738 0.122491 0.041459 0.046895 0.050380 0.029801 0.002747 0.014319 0.220173 0.103006
Dependents NaN NaN NaN 0.001762 0.111377 0.164221 0.231720 0.111982 0.169907 0.165818 ... 0.023303 0.059161 0.084852 0.047922 0.013134 0.029117 0.029168 0.048430 0.113890 0.062078
PhoneService NaN NaN NaN NaN 0.016505 0.011942 0.000742 0.000083 0.004519 0.289999 ... 0.327354 0.107222 0.065522 0.069257 0.009174 0.013544 0.047230 0.063979 0.247398 0.113214
PaperlessBilling NaN NaN NaN NaN NaN 0.191825 0.169096 0.070549 0.197873 0.326853 ... 0.006482 0.251039 0.068116 0.068790 0.082407 0.070216 0.058396 0.011691 0.352150 0.158574
TotalServices_5 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN 0.103519 0.071270 0.299248 0.315356
TotalServices_6 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN 0.055454 0.289082 0.391785
TotalServices_7 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN 0.246353 0.378530
MonthlyCharges NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN 0.651174
TotalCharges NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

10 rows × 50 columns

There are 8 columns to remove :


['MultipleLines_No phone service',
 'InternetService_Fiber optic',
 'InternetService_No',
 'OnlineSecurity_No internet service',
 'OnlineBackup_No internet service',
 'TechSupport_No internet service',
 'StreamingTV_No internet service',
 'StreamingMovies_No internet service']
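To make the np.triu masking trick concrete, here is a tiny standalone sketch on a toy 3×3 matrix (illustration only, not the real correlation matrix):

import numpy as np
import pandas as pd

toy = pd.DataFrame(np.ones((3, 3)), columns=list('abc'), index=list('abc'))
mask = np.triu(np.ones(toy.shape), k=1).astype(bool)  # True strictly above the diagonal
print(toy.where(mask))
#      a    b    c
# a  NaN  1.0  1.0
# b  NaN  NaN  1.0
# c  NaN  NaN  NaN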
correlation_plot(data)

[Figure: correlation heatmap after removing the collinear features]

4. Prepare dataset and stylized report

4.1. Define (X, y)

y = np.array(data.Churn.tolist())
data = data.drop('Churn', axis=1)
features = data.columns
X = np.array(data.values)

4.2. Train_test_split

# Split the data into train and test sets
random_state = 51
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = random_state)

4.3. Stylized report with sns (confusion matrix, roc, precision-recall, etc…)

To measure the model's performance we need several elements (this part is essential):

  • Confusion matrix : also called an error matrix, it visualizes the classifier's performance:
* true positive (TP) : correct prediction, the sample is predicted positive and is actually positive

* true negative (TN) : correct prediction, the sample is predicted negative and is actually negative

* false positive (FP) : wrong prediction, the sample is predicted positive but is actually negative

* false negative (FN) : wrong prediction, the sample is predicted negative but is actually positive
  • Metrics (a worked sketch follows this list) :
* Accuracy : (TP + TN) / (TP + TN + FP + FN)

* Precision : TP / (TP + FP)

* Recall : TP / (TP + FN)

* F1 score : 2 x ((Precision x Recall) / (Precision + Recall))
  • ROC curve : created by plotting the true positive rate (TPR) against the false positive rate (FPR) at various threshold settings.
  • Precision-Recall curve : shows the trade-off between precision and recall at different thresholds.
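As a minimal worked sketch of the metric formulas above (the four counts are made up for illustration, not taken from the model below):

# Hypothetical confusion-matrix cells, for illustration only
TP, TN, FP, FN = 60, 100, 20, 30

accuracy  = (TP + TN) / (TP + TN + FP + FN)                   # 160/210 ≈ 0.762
precision = TP / (TP + FP)                                    # 60/80  = 0.750
recall    = TP / (TP + FN)                                    # 60/90  ≈ 0.667
f1        = 2 * (precision * recall) / (precision + recall)   # ≈ 0.706
print(accuracy, precision, recall, f1)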
def model_performance(model):
    # Note: y_pred, y_score, y_test and features are read from the notebook's global scope

    # Confusion matrix
    sns.set()
    C2 = confusion_matrix(y_test, y_pred, labels=[0, 1])
    sns.heatmap(C2, annot=True)
    plt.show()

    # ROC curve
    fpr, tpr, threshold = roc_curve(y_test, y_score)
    roc_auc = auc(fpr, tpr)

    plt.title('Receiver Operating Characteristic:' + str(round(roc_auc_score(y_test, y_score), 3)))
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
    plt.legend(loc='lower right')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

    # Precision-Recall curve
    precision, recall, thresholds = precision_recall_curve(y_test, y_score)

    plt.plot(recall, precision)
    plt.title('Precision-Recall curve')
    plt.ylabel('precision')
    plt.xlabel('recall')
    plt.show()

    # Top-10 feature importances of the fitted model
    importances = model.feature_importances_
    weights = pd.Series(importances, index=features)
    weights.sort_values()[-10:].plot(kind='barh')

4.4. Define cross validation metrics

# Cross-validation metrics
def cross_val_metrics(model):
    scorers = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
    for sc in scorers:
        scores = cross_val_score(model, X, y, cv = 5, scoring = sc)
        print('[%s] : %0.5f (+/- %0.5f)' % (sc, scores.mean(), scores.std()))

5. Light GBM Model

5.1. LightGBM - Before RandomizedSearchCV

LightGBM is a gradient boosting framework that uses tree-based learning algorithms. It is designed to be distributed and efficient, with the following advantages:

  • Faster training speed and higher efficiency.

  • Lower memory usage.

  • Better accuracy.

  • Support for parallel and GPU learning.

  • Capable of handling large-scale data.

%%time
lgbm_clf = lgbm.LGBMClassifier(n_estimators=1500, random_state = 51)

lgbm_clf.fit(X_train, y_train)

y_pred = lgbm_clf.predict(X_test)
y_score = lgbm_clf.predict_proba(X_test)[:,1]

model_performance(lgbm_clf)

[Figures: confusion matrix, ROC curve and precision-recall curve, before tuning]

Wall time: 2.72 s

[Figure: top-10 feature importances, before tuning]

5.2. LightGBM - RandomizedSearchCV to optimise hyperparameters (1000 fits)

To find the best hyperparameters we will use randomized search.
Random search is a technique that tries random combinations of hyperparameters to find a good solution for the model being built.
RandomizedSearchCV is usually much faster than GridSearchCV, which evaluates every possible combination, and in practice it often reaches a comparable solution; with a random grid you simply specify the number of combinations to try.

  • LightGBM : Hyperparameters :
* learning_rate : the learning rate

* n_estimators : the number of trees

* num_leaves : the number of leaves in a full tree, 31 by default

* min_child_samples : the minimum number of samples in a leaf; can be used to fight overfitting

* min_child_weight : the minimum sum of Hessians in a leaf

* subsample : randomly select this fraction of the data, without resampling

* max_depth : the maximum depth of a tree; this parameter is used to curb overfitting

* colsample_bytree : if below 1.0, LightGBM randomly selects this fraction of the features on each iteration; for example, at 0.8 LightGBM selects 80% of the features before training each tree

* reg_alpha : L1 regularization

* reg_lambda : L2 regularization

* early_stopping_rounds : early stopping; this parameter can speed up the analysis. Training stops if a validation metric has not improved in the last early_stopping_rounds rounds, which avoids unnecessary iterations.
fit_params = {"early_stopping_rounds" : 50,
              "eval_metric" : 'binary',
              "eval_set" : [(X_test, y_test)],
              'eval_names': ['valid'],
              'verbose': 0,
              'categorical_feature': 'auto'}

param_test = {'learning_rate' : [0.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.2, 0.3, 0.4],
              'n_estimators' : [100, 200, 300, 400, 500, 600, 800, 1000, 1500, 2000, 3000, 5000],
              'num_leaves': sp_randint(6, 50),
              'min_child_samples': sp_randint(100, 500),
              'min_child_weight': [1e-5, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4],
              'subsample': sp_uniform(loc=0.2, scale=0.8),
              'max_depth': [-1, 1, 2, 3, 4, 5, 6, 7],
              'colsample_bytree': sp_uniform(loc=0.4, scale=0.6),
              'reg_alpha': [0, 1e-1, 1, 2, 5, 7, 10, 50, 100],
              'reg_lambda': [0, 1e-1, 1, 5, 10, 20, 50, 100]}

# Number of combinations to sample
n_iter = 200

# Initialize LightGBM and launch the search
lgbm_clf = lgbm.LGBMClassifier(random_state=random_state, silent=True, metric='None', n_jobs=4)

grid_search = RandomizedSearchCV(
    estimator=lgbm_clf, param_distributions=param_test,
    n_iter=n_iter,
    scoring='accuracy',
    cv=5,
    refit=True,
    random_state=random_state,
    verbose=True)

grid_search.fit(X_train, y_train, **fit_params)
print('Best params: {} '.format(grid_search.best_params_))

opt_parameters = grid_search.best_params_
Fitting 5 folds for each of 200 candidates, totalling 1000 fits
Best params: {'colsample_bytree': 0.9970322000118677, 'learning_rate': 0.08, 'max_depth': 3, 'min_child_samples': 476, 'min_child_weight': 100.0, 'n_estimators': 500, 'num_leaves': 35, 'reg_alpha': 7, 'reg_lambda': 1, 'subsample': 0.6034489935154275} 

5.3. LightGBM - After RandomizedSearchCV

%%time
lgbm_clf = lgbm.LGBMClassifier(**opt_parameters)

lgbm_clf.fit(X_train, y_train)
y_pred = lgbm_clf.predict(X_test)
y_score = lgbm_clf.predict_proba(X_test)[:,1]

model_performance(lgbm_clf)

[Figures: confusion matrix, ROC curve and precision-recall curve, after tuning]

Wall time: 1.41 s

[Figure: top-10 feature importances, after tuning]

5.4. LightGBM – Cross validation (5 folds)

cross_val_metrics(lgbm_clf)
[accuracy] : 0.80591 (+/- 0.00951)
[precision] : 0.66978 (+/- 0.02319)
[recall] : 0.53023 (+/- 0.01859)
[f1] : 0.59181 (+/- 0.01951)
[roc_auc] : 0.84634 (+/- 0.00961)

The dataset used in this article can be downloaded from:
Link: https://pan.baidu.com/s/1ZVlY4KeAy6cvu8Aotb3CpQ
Extraction code: 1111

