Browse code

Merge pull request #50 from SydneyBioX/master

XGBoost Survival Fix

Dario Strbenac authored on 24/10/2022 22:36:18 • GitHub committed on 24/10/2022 22:36:18
Showing 1 changed files

... ...
@@ -19,16 +19,18 @@ extremeGradientBoostingTrainInterface <- function(measurementsTrain, outcomeTrai
19 19
     if(max(event) == 2) event <- event - 1
20 20
     outcomeTrain <- time * ifelse(event == 1, 1, -1) # Negative for censoring.
21 21
     objective <- "survival:cox"
22
+    trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds,
23
+                                colsample_bynode = mTryProportion, verbose = 0, ...)
22 24
   } else { # Classification task.
23 25
     isClassification <- TRUE
24 26
     classes <- levels(outcomeTrain)
25 27
     numClasses <- length(classes)
26 28
     objective <- "multi:softprob"
27 29
     outcomeTrain <- as.numeric(outcomeTrain) - 1 # Classes are represented as 0, 1, 2, ...
28
-  }
29
-  
30
-  trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds,
30
+    trained <- xgboost::xgboost(measurementsTrain, outcomeTrain, objective = objective, nrounds = nrounds,
31 31
                               num_class = numClasses, colsample_bynode = mTryProportion, verbose = 0, ...)
32
+  }
33
+
32 34
   if(isClassification)
33 35
   {
34 36
     attr(trained, "classes") <- classes # Useful for factor predictions in predict method.