Browse code

Added faster perplexity for celda_G

Joshua D. Campbell authored on 14/03/2019 22:22:52
Showing 1 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,63 @@
1
+#include <R.h>
2
+#include <Rinternals.h>
3
+#include <R_ext/RS.h>
4
+
5
+SEXP _perplexityG(SEXP R_x, SEXP R_phi, SEXP R_psi, SEXP R_group)
6
+{
7
+  int i, j;
8
+  int nr = nrows(R_x);
9
+  int nc = ncols(R_x);
10
+  int nl = nlevels(R_group);
11
+  
12
+  // If the grouping variable is not a factor, throw an error
13
+  if (!isFactor(R_group)) {
14
+    error("The grouping argument must be a factor");
15
+  }
16
+  // If the length of the grouping variable and matrix do not match, throw an error
17
+  if (LENGTH(R_group) != nr) {
18
+    error("The length of the grouping argument must match the number of rows in the matrix.");
19
+  }
20
+  if (ncols(R_phi) != nc) {
21
+    error("The R_phi and R_x must have the same number of colums.");
22
+  }  
23
+  if (nrows(R_phi) != nl) {
24
+    error("R_phi must have the same number of rows as the number of levels in R_group.");
25
+  }  
26
+  if (nrows(R_psi) != nr) {
27
+    error("The R_psi and R_x must have the same number of rows.");
28
+  }  
29
+  if (ncols(R_psi) != nl) {
30
+    error("R_phi must have the same number of columns as the number of levels in R_group.");
31
+  }  
32
+  
33
+  // Create pointers
34
+  int *group = INTEGER(R_group);
35
+  double *phi = REAL(R_phi);
36
+  double *psi = REAL(R_psi);
37
+  int *x = INTEGER(R_x);  
38
+  
39
+  // Make sure values are not NA and within the range of the number of rows
40
+  for (i = 0; i < nr; i++) {
41
+    if(group[i] == NA_INTEGER || group[i] < 0 || group[i] > nr) {
42
+      error("Labels in group and pgroup must not be NA and must less than or equal to the number of rows in the matrix.");
43
+    }
44
+  }  
45
+  
46
+  // Allocate a variable for the return matrix
47
+  
48
+  
49
+  double ans = 0;
50
+  // Multiply the probabilties, log transform, and multiply against the counts to derive log(p(x))
51
+  for (j = 0; j < nc; j++) {
52
+    for (i = 0; i < nr; i++) {
53
+      ans += x[j * nr + i] * log(phi[j * nl + (group[i]-1)] * psi[nr * (group[i]-1) + i]);
54
+    }
55
+  }
56
+  
57
+  SEXP R_ans = PROTECT(allocVector(REALSXP, 1));
58
+  REAL(R_ans)[0] = ans;
59
+  
60
+  UNPROTECT(1);
61
+  return(R_ans);
62
+}  
63
+