基于神经网络的Cox比例风险模型(DeepSurv)在生存分析中的应用
基于神经网络的Cox比例风险模型(DeepSurv)在生存分析中的应用
简 介
背景
医疗从业者使用生存模型来探索和理解患者协变量(如临床和遗传特征)与各种治疗方案有效性之间的关系。标准的生存模型,如线性Cox比例风险模型,需要广泛的特征工程或先前的医学知识来模拟个体水平上的治疗相互作用。虽然非线性生存方法,如神经网络和生存森林,可以固有地模拟这些高级交互术语,但它们尚未被证明是有效的治疗推荐系统。
方法
我们引入 DeepSurv,一种 Cox 比例风险深度神经网络和最先进的生存方法,用于模拟患者协变量与治疗效果之间的相互作用,从而提供个性化治疗建议。
结果
在模拟和真实的生存数据上进行了大量的实验训练 DeepSurv。证明 DeepSurv 的表现与其他最先进的生存模型一样好,甚至更好,并验证 DeepSurv 成功地模拟了患者协变量与其失败风险之间日益复杂的关系。然后,展示了 DeepSurv 如何模拟患者特征与不同治疗方案有效性之间的关系,以展示如何使用 DeepSurv 提供个性化治疗建议。最后,在真实的临床研究中训练 DeepSurv,以证明它的个性化治疗建议将如何增加一组患者的生存时间。
结论
DeepSurv 的预测和建模能力将使医学研究人员能够使用深度神经网络作为探索、理解和预测患者特征对失败风险影响的工具。
软件包安装
survivalmodels包使用reticulate从Python实现模型。为了使用这些模型,必须按照reticulate::py_install安装所需的Python包。Survivalmodels包含一个辅助函数,用于安装所需的pycox函数(如果还需要,则使用pytorch)。在运行此包中的任何模型之前,如果您尚未安装pycox,请运行。
install_pycox(pip = TRUE, install_torch = FALSE)
然后再次安装survivalmodels:
remotes::install_github("RaphaelS1/survivalmodels")
install.packages("survivalmodels")
数据读取
library(survival)
library(sampling)
dim(lung)
## [1] 228 10
lung = na.omit(lung)
table(lung$status)
##
## 1 2
## 47 120
set.seed(123)
# 每层抽取70%的数据
train_id <- strata(lung, "status", size = rev(round(table(lung$status) * 0.7)))$ID_unit
# 训练数据
trainData <- lung[train_id, ]
# 测试数据
testData <- lung[-train_id, ]
实例操作
构建模型
library(survivalmodels)
set.seed(123)
fit <- deepsurv(Surv(time, status) ~ ., data = trainData, frac = 0.3, activation = "relu",
num_nodes = c(4L, 8L, 4L, 2L), dropout = 0.1, early_stopping = TRUE, epochs = 100L,
batch_size = 32L)
验证
预测predict()参数中type有三种选择:"survival", "risk", "all", 可以获得预测生存矩阵和相对风险:
pred <- predict(fit, testData, type = "survival")
str(pred)
## num [1:50, 1:76] 0.988 0.988 0.988 0.988 0.988 ...
## - attr(*, "dimnames")=List of 2
## ..$ : chr [1:50] "0 " "1 " "2 " "3 " ...
## ..$ : chr [1:76] "11" "15" "26" "53" ...
一致性
survivalmodels包含了一致性分析的函数cindex(),跟survival包里面的survival::concordance()使用非常相似。
p <- predict(fit, type = "risk", newdata = testData)
cindex(risk = p, truth = testData[, "time"])
## [1] 0.4877451
生存分析
根据风险值我们可以将患者分为高低风险组,然后绘制生存曲线:
library(survminer)
testData$risk = p
group = ifelse(testData$risk > mean(testData$risk), "High", "Low")
f <- survfit(Surv(testData$time, testData$status) ~ group)
f
## Call: survfit(formula = Surv(testData$time, testData$status) ~ group)
##
## n events median 0.95LCL 0.95UCL
## group=High 31 21 390 239 687
## group=Low 19 15 267 167 814
ggsurvplot(f, data = testData, surv.median.line = "hv", legend.title = "Risk Group",
legend.labs = c("Low Risk", "High Risk"), pval = TRUE, ggtheme = theme_bw())
绘制ROC曲线
由于我们所作的模型根时间密切相关因此我们选择timeROC,可以快速的技术出来不同时期的ROC,进一步作图:
library(timeROC)
res <- timeROC(T = testData$time, delta = testData$status, marker = testData$risk,
cause = 1, weighting = "marginal", times = c(1 * 365, 2 * 365), ROC = TRUE, iid = TRUE)
res$AUC_1
## t=365 t=730
## 0.5592593 0.7400000
confint(res, level = 0.95)$CI_AUC
## NULL
plot(res, time = 1 * 365, col = "red", title = FALSE, lwd = 2)
plot(res, time = 2 * 365, add = TRUE, col = "blue", lwd = 2)
legend("bottomright", c(paste("1 Years ", round(res$AUC_1[1], 2)), paste("2 Years ",
round(res$AUC_1[2], 2))), col = c("red", "blue", "green"), lty = 1, lwd = 2)
不同时间节点的AUC曲线及其置信区间
再分析不同时间节点的AUC曲线及其置信区间,由于数据量非常小,此图并不明显。
plotAUCcurve(res, conf.int = TRUE, col = "red")
Reference
Katzman, J. L., Shaham, U., Cloninger, A., Bates, J., Jiang, T., & Kluger, Y. (2018). DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology, 18(1), 24. https://doi.org/10.1186/s12874-018-0482-1