XGBoost Survival Fix
... | ... |
@@ -19,16 +19,18 @@ extremeGradientBoostingTrainInterface <- function(measurementsTrain, outcomeTrai |
19 | 19 |
if(max(event) == 2) event <- event - 1 |
20 | 20 |
outcomeTrain <- time * ifelse(event == 1, 1, -1) # Negative for censoring. |
21 | 21 |
objective <- "survival:cox" |
22 |
+ trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds, |
|
23 |
+ colsample_bynode = mTryProportion, verbose = 0, ...) |
|
22 | 24 |
} else { # Classification task. |
23 | 25 |
isClassification <- TRUE |
24 | 26 |
classes <- levels(outcomeTrain) |
25 | 27 |
numClasses <- length(classes) |
26 | 28 |
objective <- "multi:softprob" |
27 | 29 |
outcomeTrain <- as.numeric(outcomeTrain) - 1 # Classes are represented as 0, 1, 2, ... |
28 |
- } |
|
29 |
- |
|
30 |
- trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds, |
|
30 |
+ trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds, |
|
31 | 31 |
num_class = numClasses, colsample_bynode = mTryProportion, verbose = 0, ...) |
32 |
+ } |
|
33 |
+ |
|
32 | 34 |
if(isClassification) |
33 | 35 |
{ |
34 | 36 |
attr(trained, "classes") <- classes # Useful for factor predictions in predict method. |