The Titanic dataset comes from the Kaggle getting-started competition Titanic: Machine Learning from Disaster. It records the survival of passengers of different ages, sexes and social classes when the Titanic struck an iceberg on her maiden voyage and sank in the North Atlantic. The data consists of two files: a training set (train.csv) and a test set (test.csv). After suitable exploratory analysis and preprocessing, the data can be used to predict passenger survival. The fields are:

- PassengerId: passenger ID
- Survived: survival (0 = No, 1 = Yes)
- Pclass: ticket class (1 = 1st, 2 = 2nd, 3 = 3rd)
- Name: passenger name
- Sex: sex
- Age: age in years
- SibSp: number of siblings/spouses aboard
- Parch: number of parents/children aboard
- Ticket: ticket number
- Fare: passenger fare
- Cabin: cabin number
- Embarked: port of embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
import numpy as np
import pandas as pd
train = pd.read_csv('./titanic/train.csv')
test = pd.read_csv('./titanic/test.csv')
print('Training set:', train.shape, 'Test set:', test.shape)
Training set: (891, 12) Test set: (418, 11)
Merge the two sets so that data cleaning can be done once on the combined data.
data = pd.concat([train, test], ignore_index=True)  # DataFrame.append was removed in pandas 2.0; use pd.concat
print('Merged dataset:', data.shape)
Merged dataset: (1309, 12)
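With the two sets concatenated, the training rows can later be recovered by the fact that only they carry a Survived label. A minimal sketch on a toy frame (the values are illustrative, not from the real files):

```python
import pandas as pd

# Toy stand-in for the merged frame: Survived is NaN on rows that came from test.csv.
data = pd.DataFrame({
    'PassengerId': [1, 2, 3, 4],
    'Survived': [0.0, 1.0, None, None],
})

# Rows with a label belong to the training portion; the rest are the test portion.
train_part = data[data['Survived'].notna()]
test_part = data[data['Survived'].isna()]
print(train_part.shape, test_part.shape)  # (2, 2) (2, 2)
```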
Check that the data loaded correctly.
# First 5 rows
data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1.0 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0.0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
# Last 5 rows
data.tail()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1304 | 1305 | NaN | 3 | Spector, Mr. Woolf | male | NaN | 0 | 0 | A.5. 3236 | 8.0500 | NaN | S |
| 1305 | 1306 | NaN | 1 | Oliva y Ocana, Dona. Fermina | female | 39.0 | 0 | 0 | PC 17758 | 108.9000 | C105 | C |
| 1306 | 1307 | NaN | 3 | Saether, Mr. Simon Sivertsen | male | 38.5 | 0 | 0 | SOTON/O.Q. 3101262 | 7.2500 | NaN | S |
| 1307 | 1308 | NaN | 3 | Ware, Mr. Frederick | male | NaN | 0 | 0 | 359309 | 8.0500 | NaN | S |
| 1308 | 1309 | NaN | 3 | Peter, Master. Michael J | male | NaN | 1 | 1 | 2668 | 22.3583 | NaN | C |
# Check the data dimensions
data.shape
(1309, 12)
# Overview of the dataset
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  1309 non-null   int64
 1   Survived     891 non-null    float64
 2   Pclass       1309 non-null   int64
 3   Name         1309 non-null   object
 4   Sex          1309 non-null   object
 5   Age          1046 non-null   float64
 6   SibSp        1309 non-null   int64
 7   Parch        1309 non-null   int64
 8   Ticket       1309 non-null   object
 9   Fare         1308 non-null   float64
 10  Cabin        295 non-null    object
 11  Embarked     1307 non-null   object
dtypes: float64(3), int64(4), object(5)
memory usage: 122.8+ KB
# Summary statistics
data.describe()
| | PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|---|
| count | 1309.000000 | 891.000000 | 1309.000000 | 1046.000000 | 1309.000000 | 1309.000000 | 1308.000000 |
| mean | 655.000000 | 0.383838 | 2.294882 | 29.881138 | 0.498854 | 0.385027 | 33.295479 |
| std | 378.020061 | 0.486592 | 0.837836 | 14.413493 | 1.041658 | 0.865560 | 51.758668 |
| min | 1.000000 | 0.000000 | 1.000000 | 0.170000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 328.000000 | 0.000000 | 2.000000 | 21.000000 | 0.000000 | 0.000000 | 7.895800 |
| 50% | 655.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
| 75% | 982.000000 | 1.000000 | 3.000000 | 39.000000 | 1.000000 | 0.000000 | 31.275000 |
| max | 1309.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 9.000000 | 512.329200 |
# Count missing values per column
data.isnull().sum()
PassengerId       0
Survived        418
Pclass            0
Name              0
Sex               0
Age             263
SibSp             0
Parch             0
Ticket            0
Fare              1
Cabin          1014
Embarked          2
dtype: int64
Observations: 1) the missing Survived values all come from the test set; 2) Age and Cabin have substantial gaps: Age can reasonably be imputed, while Cabin is missing so often that dropping the column is worth considering.
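One hedge before dropping Cabin entirely: whether a cabin is recorded at all may itself be informative, so a common trick is to keep a binary presence flag first. A minimal sketch on a toy column (HasCabin is a name introduced here for illustration, not a field of the original data):

```python
import pandas as pd

df = pd.DataFrame({'Cabin': ['C85', None, 'C123', None]})
# Record whether a cabin was known before discarding the sparse column.
df['HasCabin'] = df['Cabin'].notna().astype(int)
df = df.drop(columns=['Cabin'])
print(df['HasCabin'].tolist())  # [1, 0, 1, 0]
```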
# Box plot of Age by Survived, to inspect the distribution and spot outliers
import seaborn as sns
sns.boxplot(x='Survived', y='Age', data=data)
<matplotlib.axes._subplots.AxesSubplot at 0x1a189b8668>
Observations: 1) the average passenger age is around 30; 2) the age distribution has outlying points (very old passengers); 3) on average, survivors are somewhat younger.
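The whiskers of a seaborn box plot follow Tukey's 1.5×IQR rule, so the same outlying points can be listed programmatically. A sketch on made-up ages:

```python
import pandas as pd

ages = pd.Series([20, 22, 24, 25, 26, 27, 28, 29, 30, 80])
q1, q3 = ages.quantile(0.25), ages.quantile(0.75)
upper = q3 + 1.5 * (q3 - q1)  # upper whisker bound
outliers = ages[ages > upper]
print(outliers.tolist())  # [80]
```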
# Fill with the mean
# Age
print('Mean age:', data['Age'].mean())
data['Age'] = data['Age'].fillna(data['Age'].mean())
# Fare
print('Mean fare:', data['Fare'].mean())
data['Fare'] = data['Fare'].fillna(data['Fare'].mean())
Mean age: 29.881137667304014
Mean fare: 33.2954792813456
# Fill with the mode
# Embarked
print(data['Embarked'].value_counts())
data['Embarked'] = data['Embarked'].fillna('S')  # 'S' is the most frequent port
S    914
C    270
Q    123
Name: Embarked, dtype: int64
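Hard-coding 'S' works here, but `mode()` derives the most frequent value directly, which keeps the fill correct if the data ever changes. A sketch on a toy column:

```python
import pandas as pd

emb = pd.Series(['S', 'C', None, 'S', 'Q', None])
# mode() returns the most frequent value(s); [0] takes the first (here 'S').
filled = emb.fillna(emb.mode()[0])
print(filled.isna().sum())  # 0
```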
from sklearn.impute import KNNImputer
# As an alternative to mean filling, demonstrate KNN imputation on the numeric
# columns of the raw training set (converted to a NumPy array).
n_train = train[['Age', 'SibSp', 'Parch', 'Fare']].to_numpy()
imputer = KNNImputer(n_neighbors=2, weights='uniform')
print('Missing values before:', np.isnan(n_train).sum())
n_train_impute = imputer.fit_transform(n_train)
print('Missing values after:', np.isnan(n_train_impute).sum())
print('Column means after imputation:', np.mean(n_train_impute, axis=0))
Missing values before: 177
Missing values after: 0
Column means after imputation: [30.41007856  0.52300786  0.38159371 32.20420797]
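Since fit_transform returns a bare NumPy array, a common follow-up is to wrap the result back into a DataFrame so the column names survive. A sketch on a toy frame (not the real train.csv):

```python
import numpy as np
import pandas as pd
from sklearn.impute import KNNImputer

df = pd.DataFrame({'Age':  [22.0, np.nan, 26.0, 35.0],
                   'Fare': [7.25, 71.28, 7.92, 53.10]})
imputer = KNNImputer(n_neighbors=2, weights='uniform')
# Re-wrap the array so the imputed frame keeps the original columns and index.
imputed = pd.DataFrame(imputer.fit_transform(df), columns=df.columns, index=df.index)
print(imputed.isna().sum().sum())  # 0
```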
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1309 entries, 0 to 1308
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  1309 non-null   int64
 1   Survived     891 non-null    float64
 2   Pclass       1309 non-null   int64
 3   Name         1309 non-null   object
 4   Sex          1309 non-null   object
 5   Age          1309 non-null   float64
 6   SibSp        1309 non-null   int64
 7   Parch        1309 non-null   int64
 8   Ticket       1309 non-null   object
 9   Fare         1309 non-null   float64
 10  Cabin        295 non-null    object
 11  Embarked     1309 non-null   object
dtypes: float64(3), int64(4), object(5)
memory usage: 122.8+ KB
# Sex
data['Sex'].head()
0      male
1    female
2    female
3    female
4      male
Name: Sex, dtype: object
# Sex: encode female as 0 and male as 1
Sex_map = {
'female': 0,
'male': 1
}
data['Sex'] = data['Sex'].map(Sex_map)
data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.0 | 3 | Braund, Mr. Owen Harris | 1 | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | 0 | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1.0 | 3 | Heikkinen, Miss. Laina | 0 | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0.0 | 3 | Allen, Mr. William Henry | 1 | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
data = data.join(pd.get_dummies(data['Embarked'], prefix = 'Embarked'))
data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Embarked_C | Embarked_Q | Embarked_S |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.0 | 3 | Braund, Mr. Owen Harris | 1 | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 0 | 0 | 1 |
| 1 | 2 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | 0 | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 | 0 | 0 |
| 2 | 3 | 1.0 | 3 | Heikkinen, Miss. Laina | 0 | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 0 | 0 | 1 |
| 3 | 4 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 0 | 0 | 1 |
| 4 | 5 | 0.0 | 3 | Allen, Mr. William Henry | 1 | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 0 | 0 | 1 |
data = data.join(pd.get_dummies(data['Pclass'], prefix = 'Pclass'))
data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Embarked_C | Embarked_Q | Embarked_S | Pclass_1 | Pclass_2 | Pclass_3 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.0 | 3 | Braund, Mr. Owen Harris | 1 | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 0 | 0 | 1 | 0 | 0 | 1 |
| 1 | 2 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | 0 | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 | 0 | 0 | 1 | 0 | 0 |
| 2 | 3 | 1.0 | 3 | Heikkinen, Miss. Laina | 0 | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 0 | 0 | 1 | 0 | 0 | 1 |
| 3 | 4 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 0 | 0 | 1 | 1 | 0 | 0 |
| 4 | 5 | 0.0 | 3 | Allen, Mr. William Henry | 1 | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 0 | 0 | 1 | 0 | 0 | 1 |
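The two get_dummies/join steps above can also be collapsed into one call, since pd.get_dummies accepts a columns argument and replaces those columns in place. A sketch on a toy frame:

```python
import pandas as pd

df = pd.DataFrame({'Embarked': ['S', 'C', 'Q'], 'Pclass': [3, 1, 2]})
# One call encodes both columns (prefixing with the column name) and drops the originals.
dummies = pd.get_dummies(df, columns=['Embarked', 'Pclass'])
print(sorted(dummies.columns))
```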
def get_title(name):
    # 'Braund, Mr. Owen Harris' -> 'Mr': take the text between the comma and the period
    return name.split(',')[1].split('.')[0].strip()
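The same title extraction can be written as a single vectorized regex with Series.str.extract, avoiding the per-row Python function. A sketch on two sample names:

```python
import pandas as pd

names = pd.Series(['Braund, Mr. Owen Harris', 'Heikkinen, Miss. Laina'])
# Capture the text between the comma and the first period, then trim spaces.
titles = names.str.extract(r',\s*([^.]+)\.', expand=False).str.strip()
print(titles.tolist())  # ['Mr', 'Miss']
```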
data['Title'] = data['Name'].map(get_title)
data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Embarked_C | Embarked_Q | Embarked_S | Pclass_1 | Pclass_2 | Pclass_3 | Title |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.0 | 3 | Braund, Mr. Owen Harris | 1 | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 0 | 0 | 1 | 0 | 0 | 1 | Mr |
| 1 | 2 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | 0 | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 | 0 | 0 | 1 | 0 | 0 | Mrs |
| 2 | 3 | 1.0 | 3 | Heikkinen, Miss. Laina | 0 | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 0 | 0 | 1 | 0 | 0 | 1 | Miss |
| 3 | 4 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 0 | 0 | 1 | 1 | 0 | 0 | Mrs |
| 4 | 5 | 0.0 | 3 | Allen, Mr. William Henry | 1 | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 0 | 0 | 1 | 0 | 0 | 1 | Mr |
data['Title'].value_counts()
Mr              757
Miss            260
Mrs             197
Master           61
Dr                8
Rev               8
Col               4
Ms                2
Major             2
Mlle              2
Jonkheer          1
Sir               1
Dona              1
Mme               1
the Countess      1
Capt              1
Don               1
Lady              1
Name: Title, dtype: int64
Some of the titles are rare, so they are consolidated into six categories: Officer, Royalty, Mrs, Miss, Mr and Master.
Title_map = {
'Mr': 'Mr',
'Miss': 'Miss',
'Mrs': 'Mrs',
'Master': 'Master',
'Rev': 'Officer',
'Dr': 'Officer',
'Col': 'Officer',
'Ms': 'Mrs',
'Mlle': 'Miss',
'Major': 'Officer',
'Dona': 'Royalty',
'Sir': 'Royalty',
'Capt': 'Officer',
'the Countess': 'Royalty',
'Don': 'Royalty',
'Lady': 'Royalty',
'Mme': 'Mrs',
'Jonkheer': 'Royalty'
}
data['Title'] = data['Title'].map(Title_map)
data['Title'].value_counts()
data = data.join(pd.get_dummies(data['Title'], prefix = 'Title'))
data['Family'] = data['SibSp'] + data['Parch'] + 1
data['FamilySingle'] = data['Family'].map(lambda a:1 if a == 1 else 0)
data['FamilySmall'] = data['Family'].map(lambda a:1 if 2 <= a <= 4 else 0)
data['FamilyLarge'] = data['Family'].map(lambda a:1 if 5 <= a else 0)
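The three lambda maps amount to binning Family into sizes of 1, 2 to 4, and 5 or more; pd.cut expresses the same banding in one call. A sketch (the label names are illustrative):

```python
import pandas as pd

family = pd.Series([1, 2, 4, 5, 11])
# Bins are right-inclusive: (0, 1] single, (1, 4] small, (4, inf) large.
band = pd.cut(family, bins=[0, 1, 4, float('inf')], labels=['Single', 'Small', 'Large'])
print(band.tolist())  # ['Single', 'Small', 'Small', 'Large', 'Large']
```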
data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | ... | Title_Master | Title_Miss | Title_Mr | Title_Mrs | Title_Officer | Title_Royalty | Family | FamilySingle | FamilySmall | FamilyLarge |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.0 | 3 | Braund, Mr. Owen Harris | 1 | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | ... | 0 | 0 | 1 | 0 | 0 | 0 | 2 | 0 | 1 | 0 |
| 1 | 2 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | 0 | 38.0 | 1 | 0 | PC 17599 | 71.2833 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 2 | 0 | 1 | 0 |
| 2 | 3 | 1.0 | 3 | Heikkinen, Miss. Laina | 0 | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |
| 3 | 4 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 35.0 | 1 | 0 | 113803 | 53.1000 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 2 | 0 | 1 | 0 |
| 4 | 5 | 0.0 | 3 | Allen, Mr. William Henry | 1 | 35.0 | 0 | 0 | 373450 | 8.0500 | ... | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 | 0 | 0 |
5 rows × 29 columns
Standardization can be done with the sklearn.preprocessing module: StandardScaler performs z-score standardization, while MinMaxScaler rescales each feature to the [0, 1] range.
# Select the variables whose range is not already within [0, 1]
print(data.describe())
data_r = data[['Age','SibSp','Parch','Fare','Family']]
PassengerId Survived Pclass Sex Age \
count 1309.000000 891.000000 1309.000000 1309.000000 1309.000000
mean 655.000000 0.383838 2.294882 0.644003 29.881138
std 378.020061 0.486592 0.837836 0.478997 12.883193
min 1.000000 0.000000 1.000000 0.000000 0.170000
25% 328.000000 0.000000 2.000000 0.000000 22.000000
50% 655.000000 0.000000 3.000000 1.000000 29.881138
75% 982.000000 1.000000 3.000000 1.000000 35.000000
max 1309.000000 1.000000 3.000000 1.000000 80.000000
SibSp Parch Fare Embarked_C Embarked_Q ... \
count 1309.000000 1309.000000 1309.000000 1309.000000 1309.000000 ...
mean 0.498854 0.385027 33.295479 0.206264 0.093965 ...
std 1.041658 0.865560 51.738879 0.404777 0.291891 ...
min 0.000000 0.000000 0.000000 0.000000 0.000000 ...
25% 0.000000 0.000000 7.895800 0.000000 0.000000 ...
50% 0.000000 0.000000 14.454200 0.000000 0.000000 ...
75% 1.000000 0.000000 31.275000 0.000000 0.000000 ...
max 8.000000 9.000000 512.329200 1.000000 1.000000 ...
Title_Master Title_Miss Title_Mr Title_Mrs Title_Officer \
count 1309.000000 1309.000000 1309.000000 1309.000000 1309.000000
mean 0.046600 0.200153 0.578304 0.152788 0.017571
std 0.210862 0.400267 0.494019 0.359921 0.131435
min 0.000000 0.000000 0.000000 0.000000 0.000000
25% 0.000000 0.000000 0.000000 0.000000 0.000000
50% 0.000000 0.000000 1.000000 0.000000 0.000000
75% 0.000000 0.000000 1.000000 0.000000 0.000000
max 1.000000 1.000000 1.000000 1.000000 1.000000
Title_Royalty Family FamilySingle FamilySmall FamilyLarge
count 1309.000000 1309.000000 1309.000000 1309.000000 1309.000000
mean 0.004584 1.883881 0.603514 0.333843 0.062643
std 0.067573 1.583639 0.489354 0.471765 0.242413
min 0.000000 1.000000 0.000000 0.000000 0.000000
25% 0.000000 1.000000 0.000000 0.000000 0.000000
50% 0.000000 1.000000 1.000000 0.000000 0.000000
75% 0.000000 2.000000 1.000000 1.000000 0.000000
max 1.000000 11.000000 1.000000 1.000000 1.000000
[8 rows x 24 columns]
# Z-score standardization
from sklearn import preprocessing
z_scaler = preprocessing.StandardScaler().fit(data_r)
print('mean:', z_scaler.mean_)
print('std:', z_scaler.scale_)
print(z_scaler.transform(data_r))
mean: [29.88113767  0.49885409  0.38502674 33.29547928  1.88388083]
std: [12.8782713   1.04126043  0.86522959 51.71911251  1.58303407]
[[-0.61197171  0.48128777 -0.4449995  -0.50359486  0.07335229]
 [ 0.63043107  0.48128777 -0.4449995   0.73450256  0.07335229]
 [-0.30137101 -0.47908676 -0.4449995  -0.49054359 -0.55834605]
 ...
 [ 0.66925616 -0.47908676 -0.4449995  -0.50359486 -0.55834605]
 [ 0.         -0.47908676 -0.4449995  -0.48812669 -0.55834605]
 [ 0.          0.48128777  0.71076309 -0.21147268  0.70505064]]
# Min-max scaling
m_scaler = preprocessing.MinMaxScaler().fit(data_r)
print(m_scaler.transform(data_r))
[[0.27345609 0.125      0.         0.01415106 0.1       ]
 [0.473882   0.125      0.         0.13913574 0.1       ]
 [0.32356257 0.         0.         0.01546857 0.        ]
 ...
 [0.48014531 0.         0.         0.01415106 0.        ]
 [0.3721801  0.         0.         0.01571255 0.        ]
 [0.3721801  0.125      0.11111111 0.0436405  0.2       ]]
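A quick sanity check on both scalers: after StandardScaler each column should have mean close to 0 and standard deviation close to 1, and after MinMaxScaler each column should span exactly [0, 1]. A sketch on a tiny matrix:

```python
import numpy as np
from sklearn import preprocessing

X = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
z = preprocessing.StandardScaler().fit_transform(X)  # zero mean, unit variance per column
m = preprocessing.MinMaxScaler().fit_transform(X)    # each column rescaled to [0, 1]
print(np.allclose(z.mean(axis=0), 0), m.min(), m.max())  # True 0.0 1.0
```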