In [ ]:
In [1]:
# === 1. 环境与数据读取 ===
import warnings
warnings.filterwarnings('ignore')
import os
import numpy as np
import pandas as pd
# 可视化
import matplotlib.pyplot as plt
import seaborn as sns
# 显示设置
pd.set_option('display.max_columns', 100)
plt.rcParams['figure.figsize'] = (8, 5)
# 图表与导出目录
FIG_DIR = "figs"
OUT_DIR = "outputs"
os.makedirs(FIG_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
def save_fig(fig, name, dpi=300):
"""保存 Matplotlib 图为 PNG+PDF(论文质量)。"""
png_path = os.path.join(FIG_DIR, f"{name}.png")
pdf_path = os.path.join(FIG_DIR, f"{name}.pdf")
fig.savefig(png_path, dpi=dpi, bbox_inches='tight')
fig.savefig(pdf_path, dpi=dpi, bbox_inches='tight')
print(f"[Saved] {png_path} & {pdf_path}")
# 读入数据(与 notebook 同目录或自行修改路径)
train_path = 'train.csv'
test_path = 'test.csv'
train = pd.read_csv(train_path)
test = pd.read_csv(test_path)
print("Shapes:", train.shape, test.shape)
display(train.head())
Shapes: (891, 12) (418, 12)
| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
In [15]:
# === Academic Figure Style (run once, before any plotting) ===
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import os
# 统一字体(优先 Times New Roman,不在系统则退化到 DejaVu Serif)
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = ['Times New Roman', 'Times', 'DejaVu Serif', 'STSong']
# 统一字号(常见期刊建议范围)
BASE = 11 # 正文基准字号
mpl.rcParams.update({
'axes.titlesize': BASE + 3, # 图标题
'axes.labelsize': BASE + 1, # 坐标轴标签
'xtick.labelsize': BASE, # 坐标轴刻度
'ytick.labelsize': BASE,
'legend.fontsize': BASE,
'figure.titlesize': BASE + 3
})
# 统一线条与网格
mpl.rcParams.update({
'lines.linewidth': 1.6,
'axes.linewidth': 1.0,
'grid.linewidth': 0.6,
'xtick.major.width': 0.8,
'ytick.major.width': 0.8,
'patch.edgecolor': 'black', # 柱形图等边框
'patch.linewidth': 0.6,
'savefig.dpi': 600, # 默认保存分辨率(兜底)
'savefig.bbox': 'tight',
'pdf.fonttype': 42, # 矢量文字更兼容
'ps.fonttype': 42
})
# 统一 seaborn 样式(paper 语境 + 白底网格 + 稳定配色)
sns.set_theme(
context='paper',
style='whitegrid',
palette='deep'
)
# 若你之前已经定义 FIG_DIR/OUT_DIR,这里自动沿用;否则创建默认目录
FIG_DIR = globals().get('FIG_DIR', 'figs')
os.makedirs(FIG_DIR, exist_ok=True)
def save_fig(fig, name, dpi=600):
"""
论文导出:PNG+PDF,默认 600 dpi。
使用:save_fig(plt.gcf(), "figure_name")
"""
png_path = os.path.join(FIG_DIR, f"{name}.png")
pdf_path = os.path.join(FIG_DIR, f"{name}.pdf")
fig.savefig(png_path, dpi=dpi, bbox_inches='tight')
fig.savefig(pdf_path, dpi=dpi, bbox_inches='tight')
print(f"[Saved] {png_path} & {pdf_path}")
print("Academic figure style applied. All figures will use journal-style fonts/sizes/line widths and 600dpi export.")
Academic figure style applied. All figures will use journal-style fonts/sizes/line widths and 600dpi export.
In [16]:
# === 2. 数据概览与缺失热力图 ===
def overview(df, name='df'):
print(f'[{name}] shape = {df.shape}')
display(df.head(3))
miss = df.isna().sum().to_frame('n_missing').query('n_missing>0')
print('Missing values:')
display(miss)
overview(train, 'train')
overview(test, 'test')
# 缺失热力图(train/test)
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
sns.heatmap(train.isna(), cbar=False, ax=axes[0])
axes[0].set_title('Missingness - train')
sns.heatmap(test.isna(), cbar=False, ax=axes[1])
axes[1].set_title('Missingness - test')
plt.tight_layout()
save_fig(plt.gcf(), "EDA_missingness_heatmap")
plt.show()
[train] shape = (891, 13)
| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | is_train | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 1 |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 1 |
Missing values:
| n_missing | |
|---|---|
| Age | 177 |
| Cabin | 687 |
| Embarked | 2 |
[test] shape = (418, 13)
| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | is_train | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 892 | 0 | 3 | Kelly, Mr. James | male | 34.5 | 0 | 0 | 330911 | 7.8292 | NaN | Q | 0 |
| 1 | 893 | 1 | 3 | Wilkes, Mrs. James (Ellen Needs) | female | 47.0 | 1 | 0 | 363272 | 7.0000 | NaN | S | 0 |
| 2 | 894 | 0 | 2 | Myles, Mr. Thomas Francis | male | 62.0 | 0 | 0 | 240276 | 9.6875 | NaN | Q | 0 |
Missing values:
| n_missing | |
|---|---|
| Age | 86 |
| Fare | 1 |
| Cabin | 327 |
[Saved] figs\EDA_missingness_heatmap.png & figs\EDA_missingness_heatmap.pdf
In [17]:
# === 3. 合并数据与缓存标签 ===
train['is_train'] = 1
test['is_train'] = 0
full = pd.concat([train, test], ignore_index=True)
# 缓存标签与ID(不改变后续结果)
y = full.loc[full['is_train']==1, 'Survived'].astype('int64')
pid_train = full.loc[full['is_train']==1, 'PassengerId'].copy()
pid_test = full.loc[full['is_train']==0, 'PassengerId'].copy()
display(full.head())
| PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | is_train | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 1 |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 1 |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 1 |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 1 |
In [18]:
# === 4. 特征工程 ===
full['Name'] = full['Name'].astype(str)
# Title
full['Title'] = full['Name'].str.extract(r',\s*([^\.]+)\.', expand=False).str.strip()
rare_map = {
'Mlle':'Miss','Ms':'Miss','Mme':'Mrs',
'Lady':'Rare','Countess':'Rare','Capt':'Rare','Col':'Rare',
'Don':'Rare','Dr':'Rare','Major':'Rare','Rev':'Rare',
'Sir':'Rare','Jonkheer':'Rare','Dona':'Rare'
}
full['Title'] = full['Title'].replace(rare_map).fillna('Unknown')
# 家庭特征
full['FamilySize'] = (full['SibSp'].fillna(0) + full['Parch'].fillna(0) + 1).astype(int)
full['IsAlone'] = (full['FamilySize'] == 1).astype(int)
# Deck
full['Cabin'] = full['Cabin'].astype(str)
full['Deck'] = full['Cabin'].str[0]
full.loc[full['Deck'].isna() | (full['Deck']=='n') | (full['Deck']==''), 'Deck'] = 'U'
# Ticket 前缀
def extract_ticket_prefix(t):
t = str(t).replace('.', '').replace('/', '').strip().upper()
parts = [p for p in t.split() if not p.isdigit()]
return parts[0] if len(parts)>0 else 'NONE'
full['TicketPrefix'] = full['Ticket'].apply(extract_ticket_prefix)
# 同票号人数
full['TicketGroupSize'] = full.groupby('Ticket')['Ticket'].transform('count').astype(int)
In [19]:
# === 5. 缺失值填充(不改变既有策略) ===
# Embarked:众数
embarked_mode = full['Embarked'].mode(dropna=True)[0]
full['Embarked'] = full['Embarked'].fillna(embarked_mode)
# Fare:(Pclass, Embarked) 组中位数,兜底全局中位数
fare_group_median = full.groupby(['Pclass','Embarked'])['Fare'].transform('median')
full['Fare'] = full['Fare'].fillna(fare_group_median)
full['Fare'] = full['Fare'].fillna(full['Fare'].median())
# Age:(Title, Pclass, Sex) 组中位数,兜底 Pclass/全局中位数
age_group_median = full.groupby(['Title','Pclass','Sex'])['Age'].transform('median')
full['Age'] = full['Age'].fillna(age_group_median)
full['Age'] = full['Age'].fillna(full.groupby('Pclass')['Age'].transform('median'))
full['Age'] = full['Age'].fillna(full['Age'].median())
# Fare 对数
full['Fare_log1p'] = np.log1p(full['Fare'])
# 缺失检查
miss = full.isna().sum()
display(miss[miss>0].to_frame('n_missing'))
| n_missing |
|---|
In [20]:
# === 6. 编码与拆分 ===
cat_cols = ['Sex','Embarked','Title','Deck','TicketPrefix']
for c in cat_cols:
full[c] = full[c].astype('category')
num_cols = ['Pclass','Age','SibSp','Parch','Fare','FamilySize','IsAlone','TicketGroupSize','Fare_log1p']
id_cols = ['PassengerId','is_train','Survived']
full_encoded = pd.get_dummies(full[id_cols + cat_cols + num_cols], columns=cat_cols, drop_first=False)
train_processed = full_encoded[full_encoded['is_train']==1].drop(columns=['is_train'])
test_processed = full_encoded[full_encoded['is_train']==0].drop(columns=['is_train','Survived'])
X = train_processed.drop(columns=['PassengerId','Survived'])
y = y.loc[train_processed.index]
X_test = test_processed.drop(columns=['PassengerId'])
print("Shapes:", X.shape, X_test.shape)
display(X.head())
Shapes: (891, 65) (418, 65)
| Pclass | Age | SibSp | Parch | Fare | FamilySize | IsAlone | TicketGroupSize | Fare_log1p | Sex_female | Sex_male | Embarked_C | Embarked_Q | Embarked_S | Title_Master | Title_Miss | Title_Mr | Title_Mrs | Title_Rare | Title_the Countess | Deck_A | Deck_B | Deck_C | Deck_D | Deck_E | Deck_F | Deck_G | Deck_T | Deck_U | TicketPrefix_A | TicketPrefix_A4 | TicketPrefix_A5 | TicketPrefix_AQ3 | TicketPrefix_AQ4 | TicketPrefix_AS | TicketPrefix_C | TicketPrefix_CA | TicketPrefix_CASOTON | TicketPrefix_FA | TicketPrefix_FC | TicketPrefix_FCC | TicketPrefix_LINE | TicketPrefix_LP | TicketPrefix_NONE | TicketPrefix_PC | TicketPrefix_PP | TicketPrefix_PPP | TicketPrefix_SC | TicketPrefix_SCA3 | TicketPrefix_SCA4 | TicketPrefix_SCAH | TicketPrefix_SCOW | TicketPrefix_SCPARIS | TicketPrefix_SOC | TicketPrefix_SOP | TicketPrefix_SOPP | TicketPrefix_SOTONO2 | TicketPrefix_SOTONOQ | TicketPrefix_SP | TicketPrefix_STONO | TicketPrefix_STONO2 | TicketPrefix_STONOQ | TicketPrefix_SWPP | TicketPrefix_WC | TicketPrefix_WEP | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 3 | 22.0 | 1 | 0 | 7.2500 | 2 | 0 | 1 | 2.110213 | False | True | False | False | True | False | False | True | False | False | False | False | False | False | False | False | False | False | False | True | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
| 1 | 1 | 38.0 | 1 | 0 | 71.2833 | 2 | 0 | 2 | 4.280593 | True | False | True | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
| 2 | 3 | 26.0 | 0 | 0 | 7.9250 | 1 | 1 | 1 | 2.188856 | True | False | False | False | True | False | True | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False |
| 3 | 1 | 35.0 | 1 | 0 | 53.1000 | 2 | 0 | 2 | 3.990834 | True | False | False | False | True | False | False | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
| 4 | 3 | 35.0 | 0 | 0 | 8.0500 | 1 | 1 | 1 | 2.202765 | False | True | False | False | True | False | False | True | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
In [21]:
# === 7. EDA:类别生存率 ===
eda_df = full.loc[full['is_train']==1, ['Survived','Sex','Pclass','Embarked','Title','Age','Fare','FamilySize','IsAlone']].copy()
eda_df['Survived'] = eda_df['Survived'].astype(int)
def rate_plot(group_col, data=eda_df):
rate = data.groupby(group_col)['Survived'].mean().sort_values(ascending=False)
fig, ax = plt.subplots()
sns.barplot(x=rate.index.astype(str), y=rate.values, ax=ax)
ax.set_title(f'Survival Rate by {group_col}')
ax.set_ylabel('Survival Rate'); ax.set_xlabel(group_col)
plt.xticks(rotation=30); ax.set_ylim(0,1)
plt.tight_layout()
save_fig(fig, f"EDA_survival_rate_by_{group_col}")
plt.show()
for col in ['Sex','Pclass','Embarked','Title','IsAlone','FamilySize']:
rate_plot(col)
# === 7. EDA:KDE(Age / Fare) ===
fig, ax = plt.subplots()
sns.kdeplot(data=eda_df, x='Age', hue='Survived', common_norm=False, ax=ax)
ax.set_title('Age Distribution by Survival')
plt.tight_layout(); save_fig(fig, "EDA_kde_age_by_survival"); plt.show()
fig, ax = plt.subplots()
sns.kdeplot(data=eda_df, x='Fare', hue='Survived', common_norm=False, ax=ax)
ax.set_title('Fare Distribution by Survival')
plt.tight_layout(); save_fig(fig, "EDA_kde_fare_by_survival"); plt.show()
[Saved] figs\EDA_survival_rate_by_Sex.png & figs\EDA_survival_rate_by_Sex.pdf
[Saved] figs\EDA_survival_rate_by_Pclass.png & figs\EDA_survival_rate_by_Pclass.pdf
[Saved] figs\EDA_survival_rate_by_Embarked.png & figs\EDA_survival_rate_by_Embarked.pdf
[Saved] figs\EDA_survival_rate_by_Title.png & figs\EDA_survival_rate_by_Title.pdf
[Saved] figs\EDA_survival_rate_by_IsAlone.png & figs\EDA_survival_rate_by_IsAlone.pdf
[Saved] figs\EDA_survival_rate_by_FamilySize.png & figs\EDA_survival_rate_by_FamilySize.pdf
[Saved] figs\EDA_kde_age_by_survival.png & figs\EDA_kde_age_by_survival.pdf
[Saved] figs\EDA_kde_fare_by_survival.png & figs\EDA_kde_fare_by_survival.pdf
In [22]:
# === 8A. 数值特征分布网格 ===
X_num = X.select_dtypes(include=[np.number]).copy()
cols = X_num.columns.tolist()
n_show = min(21, len(cols))
grid_rows, grid_cols = 7, 3
fig = plt.figure(figsize=(15, 18))
for i in range(n_show):
ax = plt.subplot(grid_rows, grid_cols, i+1)
col = cols[i]
sns.histplot(X_num[col].dropna(), kde=True, ax=ax)
ax.set_title(col); ax.set_xlabel('')
plt.tight_layout()
save_fig(fig, "EDA_numeric_hist_grid")
plt.show()
# === 8B. Pearson ===
corr_pearson = X_num.corr(method='pearson')
mask = np.triu(np.ones_like(corr_pearson, dtype=bool))
fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(corr_pearson, mask=mask, cmap='coolwarm', center=0, vmax=1, vmin=-1,
square=True, linewidths=.5, cbar_kws={"shrink": .8}, ax=ax)
ax.set_title('Pearson Correlation (features)')
plt.tight_layout(); save_fig(fig, "EDA_corr_pearson")
plt.show()
# === 8C. Spearman ===
corr_spearman = X_num.corr(method='spearman')
mask = np.triu(np.ones_like(corr_spearman, dtype=bool))
fig, ax = plt.subplots(figsize=(12,10))
sns.heatmap(corr_spearman, mask=mask, cmap='coolwarm', center=0, vmax=1, vmin=-1,
square=True, linewidths=.5, cbar_kws={"shrink": .8}, ax=ax)
ax.set_title('Spearman Correlation (features)')
plt.tight_layout(); save_fig(fig, "EDA_corr_spearman")
plt.show()
# === 8D. 聚类热图(返回 ClusterGrid,单独保存) ===
cg = sns.clustermap(corr_pearson, cmap='coolwarm', center=0, figsize=(12, 12))
plt.suptitle('Clustered Heatmap (Pearson)', y=1.02)
cg.fig.subplots_adjust(top=0.92)
cg.savefig(os.path.join(FIG_DIR, "EDA_corr_cluster_pearson.png"), dpi=300)
cg.savefig(os.path.join(FIG_DIR, "EDA_corr_cluster_pearson.pdf"), dpi=300)
print("[Saved] figs/EDA_corr_cluster_pearson.[png|pdf]")
plt.show()
# === 8E. 强相关对(导出 CSV) ===
def top_corr_pairs(corr_mat, threshold=0.85):
c = corr_mat.copy()
c.values[np.tril_indices_from(c)] = np.nan
pairs = (
c.stack().rename('corr').reset_index()
.rename(columns={'level_0':'feature_1','level_1':'feature_2'})
.assign(abs_corr=lambda df: df['corr'].abs())
.sort_values('abs_corr', ascending=False)
)
return pairs.query('abs_corr >= @threshold')
top_pairs = top_corr_pairs(corr_pearson, threshold=0.85)
display(top_pairs.head(20))
top_pairs_path = os.path.join(OUT_DIR, "EDA_top_corr_pairs_pearson.csv")
top_pairs.to_csv(top_pairs_path, index=False)
print(f"[Saved] {top_pairs_path}")
[Saved] figs\EDA_numeric_hist_grid.png & figs\EDA_numeric_hist_grid.pdf
[Saved] figs\EDA_corr_pearson.png & figs\EDA_corr_pearson.pdf
[Saved] figs\EDA_corr_spearman.png & figs\EDA_corr_spearman.pdf
[Saved] figs/EDA_corr_cluster_pearson.[png|pdf]
| feature_1 | feature_2 | corr | abs_corr | |
|---|---|---|---|---|
| 17 | SibSp | FamilySize | 0.890712 | 0.890712 |
[Saved] outputs\EDA_top_corr_pairs_pearson.csv
In [23]:
# === 9. 切分与基线(Logit / RF) ===
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
X_train, X_valid, y_train, y_valid = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
logit_clf = Pipeline([
('scaler', StandardScaler(with_mean=False)),
('clf', LogisticRegression(max_iter=200))
])
logit_clf.fit(X_train, y_train)
print('LogReg valid acc:', logit_clf.score(X_valid, y_valid))
rf_clf = RandomForestClassifier(n_estimators=300, random_state=42, n_jobs=-1)
rf_clf.fit(X_train, y_train)
print('RF valid acc:', rf_clf.score(X_valid, y_valid))
LogReg valid acc: 0.7988826815642458 RF valid acc: 0.770949720670391
In [24]:
# === 10. 决策树:评估与图导出 ===
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, roc_curve, confusion_matrix
def eval_and_plot_decision_tree(dt_model, X_tr, y_tr, X_te, y_te, title_prefix="Decision Tree", save_prefix="DT"):
dt_model.fit(X_tr, y_tr)
y_pred = dt_model.predict(X_te)
y_prob = dt_model.predict_proba(X_te)[:, 1] if hasattr(dt_model, "predict_proba") else None
# 指标
prec = precision_score(y_te, y_pred)
rec = recall_score(y_te, y_pred)
f1 = f1_score(y_te, y_pred)
auc_prob = roc_auc_score(y_te, y_prob) if y_prob is not None else None
auc_lbl = roc_auc_score(y_te, y_pred)
print(f"\n===== {title_prefix} =====")
print(f"Precision: {prec:.3f} Recall: {rec:.3f} F1: {f1:.3f}")
if auc_prob is not None:
print(f"ROC AUC (prob): {auc_prob:.3f} | ROC AUC (label): {auc_lbl:.3f}")
else:
print(f"ROC AUC (label): {auc_lbl:.3f} (no predict_proba)")
# 混淆矩阵
cm = confusion_matrix(y_te, y_pred)
fig, ax = plt.subplots(figsize=(4,3))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_title(f"{title_prefix} - Confusion Matrix")
ax.set_xlabel("Predicted"); ax.set_ylabel("Actual")
plt.tight_layout(); save_fig(fig, f"{save_prefix}_confusion_matrix"); plt.show()
# ROC
if y_prob is not None:
fpr, tpr, _ = roc_curve(y_te, y_prob)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(fpr, tpr, label=f"AUC={roc_auc_score(y_te, y_prob):.3f}")
ax.plot([0,1],[0,1],'--', linewidth=1)
ax.set_title(f"{title_prefix} - ROC Curve")
ax.set_xlabel("FPR"); ax.set_ylabel("TPR")
ax.legend(); plt.tight_layout(); save_fig(fig, f"{save_prefix}_roc"); plt.show()
# 树可视化(Graphviz 优先)
try:
from sklearn import tree
import graphviz
dot_data = tree.export_graphviz(
dt_model,
feature_names=list(X_tr.columns),
class_names=['Not Survived','Survived'],
filled=True, rounded=True, proportion=False, out_file=None
)
graph = graphviz.Source(dot_data)
graph_path = os.path.join(FIG_DIR, f"{save_prefix}_graph")
graph.render(graph_path, format='pdf', cleanup=True)
graph.render(graph_path, format='png', cleanup=True)
print(f"[Saved] {graph_path}.pdf & {graph_path}.png")
display(graph)
except Exception:
fig = plt.figure(figsize=(16,10))
plot_tree(dt_model, feature_names=list(X_tr.columns), class_names=['Not Survived','Survived'],
filled=True, rounded=True, impurity=True, fontsize=8)
plt.title(f"{title_prefix} - Tree (matplotlib)")
plt.tight_layout(); save_fig(fig, f"{save_prefix}_tree_matplotlib"); plt.show()
# 特征重要性
importances = pd.Series(dt_model.feature_importances_, index=X_tr.columns)
top_imp = importances.sort_values(ascending=False).head(15)
fig, ax = plt.subplots(figsize=(8,5))
sns.barplot(x=top_imp.values, y=top_imp.index, ax=ax)
ax.set_title(f"{title_prefix} - Top 15 Feature Importances")
ax.set_xlabel("Importance"); ax.set_ylabel("Feature")
plt.tight_layout(); save_fig(fig, f"{save_prefix}_feature_importance_top15"); plt.show()
return y_pred, y_prob
# 评估(默认 + 限深)
dt_default = DecisionTreeClassifier(random_state=42)
_ = eval_and_plot_decision_tree(dt_default, X_train, y_train, X_valid, y_valid,
title_prefix="Decision Tree (Default)", save_prefix="DT_default")
dt_depth4 = DecisionTreeClassifier(max_depth=4, random_state=42)
_ = eval_and_plot_decision_tree(dt_depth4, X_train, y_train, X_valid, y_valid,
title_prefix="Decision Tree (max_depth=4)", save_prefix="DT_depth4")
# 测试集提交
dt_final = DecisionTreeClassifier(max_depth=4, random_state=42)
dt_final.fit(X, y)
test_pred_dt = dt_final.predict(X_test)
sub_dt = pd.DataFrame({'PassengerId': test_processed['PassengerId'], 'Survived': test_pred_dt.astype(int)})
sub_dt.to_csv(os.path.join(OUT_DIR, 'submission_decision_tree_depth4.csv'), index=False)
print("[Saved] outputs/submission_decision_tree_depth4.csv")
display(sub_dt.head(10))
===== Decision Tree (Default) ===== Precision: 0.729 Recall: 0.739 F1: 0.734 ROC AUC (prob): 0.778 | ROC AUC (label): 0.783 [Saved] figs\DT_default_confusion_matrix.png & figs\DT_default_confusion_matrix.pdf
[Saved] figs\DT_default_roc.png & figs\DT_default_roc.pdf
[Saved] figs\DT_default_tree_matplotlib.png & figs\DT_default_tree_matplotlib.pdf
[Saved] figs\DT_default_feature_importance_top15.png & figs\DT_default_feature_importance_top15.pdf
===== Decision Tree (max_depth=4) ===== Precision: 0.811 Recall: 0.623 F1: 0.705 ROC AUC (prob): 0.831 | ROC AUC (label): 0.766 [Saved] figs\DT_depth4_confusion_matrix.png & figs\DT_depth4_confusion_matrix.pdf
[Saved] figs\DT_depth4_roc.png & figs\DT_depth4_roc.pdf
[Saved] figs\DT_depth4_tree_matplotlib.png & figs\DT_depth4_tree_matplotlib.pdf
[Saved] figs\DT_depth4_feature_importance_top15.png & figs\DT_depth4_feature_importance_top15.pdf
[Saved] outputs/submission_decision_tree_depth4.csv
| PassengerId | Survived | |
|---|---|---|
| 891 | 892 | 0 |
| 892 | 893 | 1 |
| 893 | 894 | 0 |
| 894 | 895 | 0 |
| 895 | 896 | 1 |
| 896 | 897 | 0 |
| 897 | 898 | 1 |
| 898 | 899 | 0 |
| 899 | 900 | 1 |
| 900 | 901 | 0 |
In [25]:
# === 11. NB & LR:评估与图导出 ===
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.model_selection import learning_curve
def eval_and_plot_classifier(model, name, X_tr, y_tr, X_te, y_te, feature_names=None, show_coeff=False, save_prefix=""):
model.fit(X_tr, y_tr)
y_pred = model.predict(X_te)
y_prob = model.predict_proba(X_te)[:, 1] if hasattr(model, "predict_proba") else None
# 指标
prec = precision_score(y_te, y_pred)
rec = recall_score(y_te, y_pred)
f1 = f1_score(y_te, y_pred)
auc_prob = roc_auc_score(y_te, y_prob) if y_prob is not None else None
auc_lbl = roc_auc_score(y_te, y_pred)
print(f"\n===== {name} =====")
print(f"Precision: {prec:.3f} Recall: {rec:.3f} F1: {f1:.3f}")
if y_prob is not None:
print(f"ROC AUC (prob): {auc_prob:.3f} | ROC AUC (label): {auc_lbl:.3f}")
else:
print(f"ROC AUC (label): {auc_lbl:.3f} (no predict_proba)")
# Confusion Matrix
cm = confusion_matrix(y_te, y_pred)
fig, ax = plt.subplots(figsize=(4,3))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_title(f"{name} - Confusion Matrix")
ax.set_xlabel("Predicted"); ax.set_ylabel("Actual")
plt.tight_layout(); save_fig(fig, f"{save_prefix}_confusion_matrix"); plt.show()
# ROC & PR(若可用)
if y_prob is not None:
fpr, tpr, _ = roc_curve(y_te, y_prob)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(fpr, tpr, label=f"AUC={roc_auc_score(y_te, y_prob):.3f}")
ax.plot([0,1],[0,1],'--', linewidth=1)
ax.set_title(f"{name} - ROC Curve")
ax.set_xlabel("FPR"); ax.set_ylabel("TPR")
ax.legend(); plt.tight_layout(); save_fig(fig, f"{save_prefix}_roc"); plt.show()
precision, recall, _ = precision_recall_curve(y_te, y_prob)
ap = average_precision_score(y_te, y_prob)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(recall, precision, label=f"AP={ap:.3f}")
ax.set_title(f"{name} - Precision-Recall Curve")
ax.set_xlabel("Recall"); ax.set_ylabel("Precision")
ax.legend(); plt.tight_layout(); save_fig(fig, f"{save_prefix}_pr"); plt.show()
# 学习曲线
train_sizes, train_scores, val_scores = learning_curve(
model, X_tr, y_tr, cv=5, scoring='accuracy',
train_sizes=np.linspace(0.2, 1.0, 5), n_jobs=-1, shuffle=True, random_state=42
)
fig, ax = plt.subplots(figsize=(6,4))
ax.plot(train_sizes, train_scores.mean(axis=1), marker='o', label='Training')
ax.plot(train_sizes, val_scores.mean(axis=1), marker='o', label='Validation')
ax.set_title(f"{name} - Learning Curve")
ax.set_xlabel("Training Samples"); ax.set_ylabel("Accuracy")
ax.legend(); plt.tight_layout(); save_fig(fig, f"{save_prefix}_learning_curve"); plt.show()
# 逻辑回归系数(可选)
if show_coeff and feature_names is not None and hasattr(model, "named_steps"):
lr = model.named_steps.get('clf', None)
if lr is not None and hasattr(lr, "coef_"):
coefs = pd.Series(lr.coef_.ravel(), index=feature_names).sort_values()
top_pos = coefs.tail(12)
top_neg = coefs.head(12)
coef_plot = pd.concat([top_neg, top_pos])
fig, ax = plt.subplots(figsize=(8,6))
sns.barplot(x=coef_plot.values, y=coef_plot.index, palette="vlag", ax=ax)
ax.axvline(0, color='k', linewidth=1)
ax.set_title(f"{name} - Top ± Coefficients")
ax.set_xlabel("Coefficient"); ax.set_ylabel("Feature")
plt.tight_layout(); save_fig(fig, f"{save_prefix}_top_coefficients"); plt.show()
return y_pred, y_prob, model
# 高斯朴素贝叶斯
nb = GaussianNB()
_ = eval_and_plot_classifier(nb, "Naive Bayes (GaussianNB)",
X_train, y_train, X_valid, y_valid,
save_prefix="NB")
nb.fit(X, y)
test_pred_nb = nb.predict(X_test)
sub_nb = pd.DataFrame({'PassengerId': test_processed['PassengerId'], 'Survived': test_pred_nb.astype(int)})
sub_nb.to_csv(os.path.join(OUT_DIR, 'submission_naive_bayes.csv'), index=False)
print("[Saved] outputs/submission_naive_bayes.csv")
display(sub_nb.head(10))
# 逻辑回归(标准化 Pipeline)
logit = Pipeline([
('scaler', StandardScaler(with_mean=False)),
('clf', LogisticRegression(max_iter=500, solver='lbfgs', random_state=42))
])
_ = eval_and_plot_classifier(logit, "Logistic Regression",
X_train, y_train, X_valid, y_valid,
feature_names=X.columns, show_coeff=True,
save_prefix="LR")
logit.fit(X, y)
test_pred_lr = logit.predict(X_test)
sub_lr = pd.DataFrame({'PassengerId': test_processed['PassengerId'], 'Survived': test_pred_lr.astype(int)})
sub_lr.to_csv(os.path.join(OUT_DIR, 'submission_logistic_regression.csv'), index=False)
print("[Saved] outputs/submission_logistic_regression.csv")
display(sub_lr.head(10))
===== Naive Bayes (GaussianNB) ===== Precision: 0.392 Recall: 0.942 F1: 0.553 ROC AUC (prob): 0.741 | ROC AUC (label): 0.512 [Saved] figs\NB_confusion_matrix.png & figs\NB_confusion_matrix.pdf
[Saved] figs\NB_roc.png & figs\NB_roc.pdf
[Saved] figs\NB_pr.png & figs\NB_pr.pdf
[Saved] figs\NB_learning_curve.png & figs\NB_learning_curve.pdf
[Saved] outputs/submission_naive_bayes.csv
| PassengerId | Survived | |
|---|---|---|
| 891 | 892 | 1 |
| 892 | 893 | 1 |
| 893 | 894 | 1 |
| 894 | 895 | 1 |
| 895 | 896 | 1 |
| 896 | 897 | 1 |
| 897 | 898 | 1 |
| 898 | 899 | 1 |
| 899 | 900 | 1 |
| 900 | 901 | 0 |
===== Logistic Regression ===== Precision: 0.732 Recall: 0.754 F1: 0.743 ROC AUC (prob): 0.846 | ROC AUC (label): 0.790 [Saved] figs\LR_confusion_matrix.png & figs\LR_confusion_matrix.pdf
[Saved] figs\LR_roc.png & figs\LR_roc.pdf
[Saved] figs\LR_pr.png & figs\LR_pr.pdf
[Saved] figs\LR_learning_curve.png & figs\LR_learning_curve.pdf
[Saved] figs\LR_top_coefficients.png & figs\LR_top_coefficients.pdf
[Saved] outputs/submission_logistic_regression.csv
| PassengerId | Survived | |
|---|---|---|
| 891 | 892 | 0 |
| 892 | 893 | 0 |
| 893 | 894 | 0 |
| 894 | 895 | 0 |
| 895 | 896 | 1 |
| 896 | 897 | 0 |
| 897 | 898 | 1 |
| 898 | 899 | 0 |
| 899 | 900 | 1 |
| 900 | 901 | 0 |
In [26]:
# === 12. 模型对比:表格与图导出 ===
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
models_summary = {
"Decision Tree (max_depth=4)": dt_depth4,
"Naive Bayes": nb,
"Logistic Regression": logit
}
rows = []
for name, mdl in models_summary.items():
y_pred = mdl.predict(X_valid)
y_prob = mdl.predict_proba(X_valid)[:, 1] if hasattr(mdl, "predict_proba") else None
acc = accuracy_score(y_valid, y_pred)
prec = precision_score(y_valid, y_pred)
rec = recall_score(y_valid, y_pred)
f1 = f1_score(y_valid, y_pred)
auc = roc_auc_score(y_valid, y_prob) if y_prob is not None else roc_auc_score(y_valid, y_pred)
rows.append({"Model": name, "Accuracy": acc, "Precision": prec, "Recall": rec, "F1 Score": f1, "ROC AUC": auc})
compare_df = pd.DataFrame(rows).sort_values("ROC AUC", ascending=False).reset_index(drop=True)
display(compare_df.round(4))
# 导出对比表
compare_path = os.path.join(OUT_DIR, "model_compare_validation.csv")
compare_df.to_csv(compare_path, index=False)
print(f"[Saved] {compare_path}")
# 水平条形图
fig, ax = plt.subplots(figsize=(9,5))
metrics = ["Accuracy", "Precision", "Recall", "F1 Score", "ROC AUC"]
for metric in metrics:
ax.barh(compare_df["Model"], compare_df[metric], alpha=0.7, label=metric)
ax.set_title("Model Performance Comparison (Validation Set)")
ax.set_xlabel("Score"); ax.set_xlim(0,1); ax.legend(loc="lower right")
plt.tight_layout(); save_fig(fig, "Model_compare_horizontal_bars"); plt.show()
# 分组条形图
melt_df = compare_df.melt(id_vars="Model", var_name="Metric", value_name="Score")
fig, ax = plt.subplots(figsize=(10,6))
sns.barplot(data=melt_df, x="Metric", y="Score", hue="Model", palette="pastel", ax=ax)
ax.set_title("Comparison of Models Across Evaluation Metrics")
ax.set_ylim(0,1); ax.legend(title="Model", loc="lower right")
plt.tight_layout(); save_fig(fig, "Model_compare_grouped_bars"); plt.show()
# 输出结论(文本)
best_model = compare_df.iloc[0]
print("Best model:", best_model['Model'])
print("Acc={:.4f}, P={:.4f}, R={:.4f}, F1={:.4f}, AUC={:.4f}".format(
best_model['Accuracy'], best_model['Precision'], best_model['Recall'],
best_model['F1 Score'], best_model['ROC AUC']
))
| Model | Accuracy | Precision | Recall | F1 Score | ROC AUC | |
|---|---|---|---|---|---|---|
| 0 | Logistic Regression | 0.8603 | 0.8235 | 0.8116 | 0.8175 | 0.9080 |
| 1 | Decision Tree (max_depth=4) | 0.7989 | 0.8113 | 0.6232 | 0.7049 | 0.8312 |
| 2 | Naive Bayes | 0.4413 | 0.4061 | 0.9710 | 0.5726 | 0.8145 |
[Saved] outputs\model_compare_validation.csv [Saved] figs\Model_compare_horizontal_bars.png & figs\Model_compare_horizontal_bars.pdf
[Saved] figs\Model_compare_grouped_bars.png & figs\Model_compare_grouped_bars.pdf
Best model: Logistic Regression Acc=0.8603, P=0.8235, R=0.8116, F1=0.8175, AUC=0.9080