Tutorial: Machine Learning Classification on the Breast Cancer Diagnosis Dataset
1. Tutorial Overview
Tool/Software Introduction
- Core Tools: R language (statistical computing and machine learning platform) and key libraries
- Data processing: tidyverse (includes dplyr for data manipulation, ggplot2 for visualization, etc.)
- Modeling tools: caret (unified modeling workflow), glmnet (regularized logistic regression), ranger (efficient random forest), class (basic KNN algorithm support)
- Visualization tools: corrplot (correlation heatmaps), GGally (multivariate data visualization)
- Development Background: Based on the Wisconsin Breast Cancer Diagnosis dataset (WDBC), this tutorial demonstrates the entire machine learning workflow for a binary classification problem, from data preprocessing and exploratory analysis to modeling and evaluation.
Learning Objectives
- Master a standard preprocessing workflow for medical datasets (including redundant feature removal, highly correlated feature screening, data splitting, and standardization).
- Build three classic classification models, regularized logistic regression (glmnet), random forest (including the efficient ranger implementation), and KNN, and understand when each is appropriate.
- Generate and interpret core visualization results such as feature correlation heatmaps, principal component variance contribution plots, model performance comparison plots, and feature importance ranking plots.
- Identify the key features for breast cancer diagnosis (such as area_worst and perimeter_worst) and their clinical significance.
2. Preliminary Preparation (the materials can be cloned from the tutorial released by openbayes)
Core file description
- Dataset: data.csv (569 samples labeled benign (B) or malignant (M), containing id (sample number), diagnosis (diagnostic result), and 30 morphological features of the cell nuclei).
- Code files: a complete R script covering data loading → preprocessing → EDA → feature engineering → modeling → evaluation (it can be run directly; only the data path needs to be replaced).
- Dependency library list (installation command):
install.packages(c("tidyverse", "caret", "ranger", "corrplot", "GGally", "glmnet"))
Key technical processes
- Data preprocessing: load data → remove redundant columns (such as X and id) → convert the label type (factorization) → check for missing values.
- Feature engineering: compute the feature correlation matrix → screen out highly correlated features (threshold 0.9) → principal component analysis (PCA) for dimensionality reduction and visualization.
- Modeling process: dataset split (80%/20% training/test) → 10-fold cross-validation parameter tuning → side-by-side training of the three models → confusion matrix and performance metric evaluation.
- Evaluation indicators: accuracy, area under the ROC curve (AUC), sensitivity (the malignant detection rate), and specificity (the benign detection rate); a small sketch of how these reduce to confusion-matrix cells follows below.
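For reference, the sensitivity and specificity reported later by caret's confusionMatrix() are simple ratios of confusion-matrix cells. A minimal sketch with a made-up (hypothetical) 2x2 table, not results from this dataset:

# Hypothetical confusion matrix, with malignant (M) as the positive class:
#              actual M   actual B
#   pred M        40          2
#   pred B         3         69
TP <- 40; FP <- 2; FN <- 3; TN <- 69

sensitivity <- TP / (TP + FN)                   # malignant detection rate
specificity <- TN / (TN + FP)                   # benign detection rate
accuracy    <- (TP + TN) / (TP + TN + FP + FN)
c(sensitivity = sensitivity, specificity = specificity, accuracy = accuracy)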
3. Practical operation steps
Environment Setup
- Install R (download from the official website) and RStudio (the recommended IDE, download from the official website).
- Install dependent libraries (see "Dependency library list" above).
Data loading and preprocessing
# Load libraries
library(tidyverse)  # core data-processing library
library(caret)      # modeling and evaluation tools

# Load the data (replace with the actual path)
data <- read.csv("path/to/data.csv")

# Preprocessing: remove redundant columns (X is an index column, id is the sample number; neither is a feature)
data <- data %>% select(-X, -id)

# Label conversion: turn diagnosis (B/M) into a factor (required by the classification models)
data$diagnosis <- as.factor(data$diagnosis)

# Inspect the data structure (confirm feature types and sample size)
str(data)

# Check for missing values (this dataset has none; add an imputation step for real-world data)
colSums(is.na(data))
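This dataset has no missing values, but for real-world data an imputation step could slot in here. A minimal sketch using caret's preProcess with median imputation (knnImpute or bagImpute are alternative methods):

# Fit an imputation model on the feature columns only, then apply it
impute_model <- preProcess(data %>% select(-diagnosis), method = "medianImpute")
data_imputed <- predict(impute_model, data %>% select(-diagnosis))
data_imputed$diagnosis <- data$diagnosis  # re-attach the label column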
Exploratory Data Analysis (EDA)
Core Goals: Understand feature distribution and correlation, and identify multicollinearity problems.
library(corrplot)  # correlation heatmap tool

# Compute the feature correlation matrix (excluding the label column)
data_corr <- cor(data %>% select(-diagnosis))

# Draw a cluster-ordered correlation heatmap (makes groups of highly correlated features easy to spot)
corrplot(data_corr, order = "hclust", tl.cex = 0.6, addrect = 8)
Interpretation of the results:
- Features such as radius, perimeter, and area have pairwise correlations greater than 0.9, indicating serious multicollinearity; feature screening is required to remove the redundancy (the exact pairs can be listed as sketched below).
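To see exactly which pairs drive the multicollinearity, the pairs whose absolute correlation exceeds the 0.9 threshold can be listed directly from the data_corr matrix computed above; a small sketch:

# All feature pairs with |correlation| > 0.9 (upper triangle only, to avoid duplicates)
high_pairs <- which(abs(data_corr) > 0.9 & upper.tri(data_corr), arr.ind = TRUE)
data.frame(
  feature_1   = rownames(data_corr)[high_pairs[, 1]],
  feature_2   = colnames(data_corr)[high_pairs[, 2]],
  correlation = round(data_corr[high_pairs], 3)
)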
Principal Component Analysis (PCA)
Core goals: simplify the data structure through dimensionality reduction, retain the key information, and visualize the result.
library(GGally)  # multivariate visualization tool

# Step 1: remove highly correlated features (threshold 0.9)
features <- data %>% select(-diagnosis)                        # feature columns only (data_corr was computed on these)
high_corr_indices <- findCorrelation(data_corr, cutoff = 0.9)  # indices of highly correlated features
data2 <- features %>% select(-all_of(names(features)[high_corr_indices]))  # drop the redundant features

# Step 2: run PCA (standardization and centering required)
pca_data2 <- prcomp(data2, scale = TRUE, center = TRUE)

# Step 3: visualize the variance contributions (to identify the core principal components)
explained_variance <- pca_data2$sdev^2 / sum(pca_data2$sdev^2)  # variance proportion of each PC
cumulative_variance <- cumsum(explained_variance)               # cumulative variance proportion

variance_data <- data.frame(
  PC = 1:length(explained_variance),
  ExplainedVariance = explained_variance,
  CumulativeVariance = cumulative_variance
)

ggplot(variance_data, aes(x = PC)) +
  geom_bar(aes(y = ExplainedVariance), stat = "identity", fill = "skyblue", alpha = 0.7) +
  geom_line(aes(y = CumulativeVariance), color = "red", size = 1) +
  geom_point(aes(y = CumulativeVariance), color = "red") +
  labs(
    title = "Variance contribution of principal components",
    x = "Principal component",
    y = "Proportion of variance explained"
  ) +
  scale_y_continuous(sec.axis = sec_axis(~., name = "Cumulative proportion of variance explained")) +
  theme_minimal()
Interpretation of the results: The cumulative explained variance of the first three principal components is about 70%-80%, which can be used to simplify the model.
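The variance figures can also be read numerically rather than off the plot; summary() on a prcomp object reports the proportion and cumulative proportion of variance for every component:

summary(pca_data2)                  # standard deviation, proportion and cumulative proportion per PC
round(cumulative_variance[1:5], 3)  # cumulative variance of the first five PCs, from the vector above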
Step 4: Visualization of association between principal components and diagnostic labels
# Extract the scores of the first three principal components and attach the diagnosis label
pca_scores <- as.data.frame(pca_data2$x[, 1:3])  # first three PCs
pca_scores$diagnosis <- data$diagnosis           # add the diagnosis label

# Draw a scatterplot matrix (with correlations and density distributions)
ggpairs(
  pca_scores,
  columns = 1:3,
  mapping = aes(color = diagnosis, fill = diagnosis),
  upper = list(continuous = wrap("cor", size = 3)),  # upper triangle: correlations
  lower = list(continuous = "points"),               # lower triangle: scatterplots
  diag = list(continuous = wrap("densityDiag"))      # diagonal: density distributions
) +
  theme_minimal() +
  scale_color_manual(values = c("B" = "salmon", "M" = "cyan3")) +
  scale_fill_manual(values = c("B" = "salmon", "M" = "cyan3"))
Interpretation of the results: The distribution of malignant (M) and benign (B) samples on the first three principal components is significantly different, indicating that PCA effectively retains the classification information.
Model Training
Dataset split (8:2)
set.seed(123)  # fix the random seed so results are reproducible

# Combine the label with the processed features (convenient for modeling)
data3 <- cbind(diagnosis = data$diagnosis, data2)

# Stratified sampling by label (keeps the class proportions of the training and test sets consistent)
data_sampling_index <- createDataPartition(data3$diagnosis, times = 1, p = 0.8, list = FALSE)
data_training <- data3[data_sampling_index, ]   # training set (80%)
data_testing  <- data3[-data_sampling_index, ]  # test set (20%)

# Define the cross-validation strategy (10-fold CV, with class probabilities and binary classification metrics)
data_control <- trainControl(
  method = "cv",
  number = 10,
  classProbs = TRUE,
  summaryFunction = twoClassSummary  # reports ROC, sensitivity, specificity
)
1. Regularized Logistic Regression (glmnet)
# Train the model (L1/L2 regularization with automatic parameter tuning)
model_glmnet <- train(
  diagnosis ~ .,
  data = data_training,
  method = "glmnet",                  # regularized logistic regression
  metric = "ROC",                     # optimize for ROC
  preProcess = c("scale", "center"),  # feature standardization
  tuneLength = 20,                    # 20 candidate parameter combinations
  trControl = data_control
)
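After training, it can be useful to check which regularization settings caret selected; bestTune stores the chosen alpha (L1/L2 mix) and lambda (penalty strength), and plot() shows the tuning profile:

model_glmnet$bestTune  # best alpha/lambda combination found by cross-validation
plot(model_glmnet)     # ROC across the candidate parameter grid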
2. Random Forest (two implementations)
Method 1: the ranger package (efficient implementation)
library(ranger)  # fast random forest tool

model_rf_ranger <- ranger(
  diagnosis ~ .,
  data = data_training,
  probability = TRUE,       # output class probabilities
  importance = "impurity",  # compute feature importance (impurity-based)
  num.trees = 500           # 500 decision trees
)
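Note that ranger models are used a little differently from caret models at prediction time: because probability = TRUE, predict() returns class probabilities in a $predictions matrix, which you convert into labels yourself. A minimal sketch using the data_testing set created above:

pred_ranger <- predict(model_rf_ranger, data = data_testing)
prob_malignant <- pred_ranger$predictions[, "M"]              # probability of the malignant class
pred_class <- factor(ifelse(prob_malignant > 0.5, "M", "B"),  # 0.5 threshold, adjustable
                     levels = levels(data_testing$diagnosis))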
Method 2: rf via the caret package (for cross-validation)
model_rf_caret <- train(
  diagnosis ~ .,
  data = data_training,
  method = "rf",             # classic random forest
  metric = "ROC",
  trControl = data_control,
  ntree = 500                # 500 decision trees
)
3. KNN (K Nearest Neighbors)
# Train the model (tuning the number of neighbours k)
model_knn <- train(
  diagnosis ~ .,
  data = data_training,
  method = "knn",
  metric = "ROC",
  preProcess = c("scale", "center"),  # KNN is distance-based, so standardization is required
  trControl = data_control,
  tuneLength = 31                     # evaluate 31 candidate values of k and keep the best
)

# Visualize the ROC performance for different values of k (to identify the optimal k)
plot(model_knn, main = "KNN ROC performance for different numbers of neighbours")
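The optimal number of neighbours can also be read directly from the fitted object rather than off the plot:

model_knn$bestTune       # the k value chosen by cross-validation
head(model_knn$results)  # ROC / Sens / Spec for each candidate k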
Model prediction and evaluation
1. Prediction results and confusion matrix
Take logistic regression as an example:
# Predict on the test set
prediction_glmnet <- predict(model_glmnet, data_testing)

# Generate the confusion matrix (to assess classification performance)
cm_glmnet <- confusionMatrix(prediction_glmnet, data_testing$diagnosis, positive = "M")
cm_glmnet  # reports accuracy, sensitivity, specificity, etc.
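AUC is listed among the evaluation indicators but is not part of the confusionMatrix() output; one common way to compute it is with the pROC package (an extra dependency not in the installation list above), using the predicted probability of the malignant class:

library(pROC)  # install.packages("pROC") if needed

prob_glmnet <- predict(model_glmnet, data_testing, type = "prob")[, "M"]  # P(malignant) on the test set
roc_glmnet <- roc(response = data_testing$diagnosis, predictor = prob_glmnet, levels = c("B", "M"))
auc(roc_glmnet)  # area under the ROC curve
plot(roc_glmnet, main = "Logistic regression ROC curve (test set)")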
Confusion Matrix Visualization:
# Convert to a data frame for plotting
# (in caret's confusion matrix, rows are predictions and columns are the actual labels)
cm_table <- as.table(cm_glmnet$table)
cm_df <- as.data.frame(cm_table)
colnames(cm_df) <- c("Predicted", "Actual", "Freq")

ggplot(cm_df, aes(x = Actual, y = Predicted, fill = Freq)) +
  geom_tile(color = "white") +
  scale_fill_gradient(low = "lightblue", high = "blue") +
  geom_text(aes(label = Freq), color = "black", size = 6) +
  labs(title = "Logistic regression confusion matrix", x = "Actual diagnosis", y = "Predicted diagnosis") +
  theme_minimal()
2. Feature Importance Analysis
Random Forest Feature Importance:
# Extract the ten most important features
importance_rf <- model_rf_ranger$variable.importance  # from the ranger model
importance_df <- data.frame(
  Feature = names(importance_rf),
  Importance = importance_rf
) %>%
  arrange(desc(Importance)) %>%
  slice(1:10)  # keep the top 10

# Visualization
ggplot(importance_df, aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_bar(stat = "identity", fill = "skyblue") +
  coord_flip() +  # horizontal bars make the feature names easier to read
  labs(title = "Random forest: top 10 important features", x = "Feature", y = "Importance (impurity decrease)") +
  theme_minimal()
Key findings: area_worst (largest tumor area) and perimeter_worst (largest tumor perimeter) are the core features for distinguishing benign from malignant tumors, which is consistent with clinical understanding.
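As a cross-check on the ranger ranking, caret can also report feature importance for the regularized logistic regression, based on the magnitude of its coefficients; a quick sketch:

importance_glmnet <- varImp(model_glmnet)  # importance scaled to 0-100 by caret
importance_glmnet
plot(importance_glmnet, top = 10, main = "glmnet: top 10 features")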
4. Model Comparison and Results Interpretation
Multi-model performance comparison
# Collect all models
model_list <- list(
  logistic_regression = model_glmnet,
  random_forest = model_rf_caret,
  knn = model_knn
)

# Extract the cross-validation results
results <- resamples(model_list)

# Report the performance metrics (ROC, sensitivity, specificity)
summary(results)
Visual comparison:
# Box plots: ROC distribution of each model
bwplot(results, metric = "ROC", main = "Model ROC comparison (10-fold cross-validation)")

# Dot plot: performance metrics with 95% confidence intervals
dotplot(results, metric = c("ROC", "Sens", "Spec"), main = "Model performance metric comparison")
Interpretation of the results:
- Logistic regression performed best (ROC = 0.993, sensitivity = 0.989) and was suitable as the baseline model.
- Random forests perform close to logistic regression, but at a higher computational cost.
- KNN has a slightly lower specificity (0.888), i.e. a slightly higher misclassification rate for benign samples (a test-set cross-check of these results is sketched below).
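The comparison above is based on cross-validation resamples from the training set; it is worth verifying that the ranking also holds on the held-out test set. A minimal sketch reusing the confusion-matrix workflow shown earlier:

# Test-set accuracy, sensitivity and specificity for each caret model
test_metrics <- sapply(model_list, function(m) {
  pred <- predict(m, data_testing)
  cm <- confusionMatrix(pred, data_testing$diagnosis, positive = "M")
  c(Accuracy    = unname(cm$overall["Accuracy"]),
    Sensitivity = unname(cm$byClass["Sensitivity"]),
    Specificity = unname(cm$byClass["Specificity"]))
})
round(t(test_metrics), 3)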
5. Advanced Operations
Parameter Optimization
- Random forest tuning (optimizing the mtry parameter, i.e. the number of features considered at each split):
model_rf_tuned <- train(
  diagnosis ~ .,
  data = data_training,
  method = "rf",
  metric = "ROC",
  trControl = data_control,
  tuneGrid = expand.grid(mtry = seq(5, 15, 2))  # test mtry = 5, 7, ..., 15
)
- Extended model: add a support vector machine (SVM; caret's svmRadial method relies on the kernlab package, which may need to be installed):
model_svm <- train(
  diagnosis ~ .,
  data = data_training,
  method = "svmRadial",  # radial-kernel SVM
  metric = "ROC",
  trControl = data_control
)
6. Appendix
Common code quick lookup table
Function | Code Sample |
---|---|
Data Reading | read.csv("data.csv") |
Correlation heatmap | corrplot(cor(data), order = "hclust") |
10-fold cross validation setting | trainControl(method = "cv", number = 10) |
Confusion matrix calculation | confusionMatrix(pred, actual) |
PCA Dimensionality Reduction | prcomp(data, scale = TRUE) |
Troubleshooting Common Problems
- Error: could not find function "corrplot" → solution: install the corrplot package (install.packages("corrplot")).
- Feature dimension error → check whether the select(-id, -diagnosis) step (which excludes the non-feature columns) is missing.
- Model training is slow → reduce ntree (the number of random forest trees) or tuneLength (the number of parameter candidates), or parallelize training as sketched below.
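If reducing ntree or tuneLength is not enough, caret can also distribute the cross-validation folds across CPU cores through the doParallel backend (another extra dependency). A minimal sketch, assuming a multi-core machine:

library(doParallel)  # install.packages("doParallel") if needed

cl <- makePSOCKcluster(4)  # start 4 worker processes
registerDoParallel(cl)

# Any subsequent caret::train() call with the same trainControl will now
# run its resampling iterations on the registered workers.
model_rf_parallel <- train(
  diagnosis ~ .,
  data = data_training,
  method = "rf",
  metric = "ROC",
  trControl = data_control
)

stopCluster(cl)  # release the workers when done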
Through this tutorial you can master the machine learning workflow for medical binary classification problems, focusing on the core logic of feature screening, model tuning, and result visualization, and use it as a reference for modeling other disease diagnoses.