```/*
* This file is part of Isolator.
*
* Copyright (c) 2011 by Daniel C. Jones <dcjones@cs.washington.edu>
*
*/

#include "motif.hpp"
#include "logger.hpp"
#include "miscmath.h"
#include <cmath>
#include <cstdio>
#include <cstring>

using namespace std;

/* return 4^x */
static unsigned int four_pow(unsigned int x)
{
return 1 << (2 * x);
}

/* Bayesian (Schwarz) Information Criterion */
static double bic(double L, double n_obs, double n_params, double c = 1.0)
{
return L - c * n_params * log(n_obs);

// I'm fudging this a bit. Technically, it should be:
//       2.0 * L - c * n_params * log(n_obs)
}

/* A bunch of primative vector/matrix operations */

static void colcpy(double* dest, const double* src, size_t j, size_t n, size_t m)
{
for(size_t i = 0; i < n; ++i) dest[i] = src[i * m + j];
}

static void vecadd(double* u, double* v, size_t n)
{
for (size_t i = 0; i < n; ++i) u[i] += v[i];
}

static void vecsub(double* u, double* v, size_t n)
{
for (size_t i = 0; i < n; ++i) u[i] -= v[i];
}

static void vecaddcol(double* u, double* V, size_t n, size_t m, size_t j)
{
for (size_t i = 0; i < n; ++i) u[i] += V[i * m + j];
}

static void vecsubcol(double* u, double* V, size_t n, size_t m, size_t j)
{
for (size_t i = 0; i < n; ++i) u[i] -= V[i * m + j];
}

static void matsetcol(double* U, double* v, size_t n, size_t m, size_t j)
{
for (size_t i = 0; i < n; ++i) U[i * m + j] = v[i];
}

/* I've moved all of this code out of the motif class, because I think it makes
* things a bit cleaner. */
class motif_trainer
{
public:

motif_trainer(const deque<twobitseq*>& seqs0,
const deque<twobitseq*>& seqs1,
size_t m,
size_t max_parents,
size_t max_distance,
double complexity_penalty)
: col(0)
, col_max(30)
, m(m)
, max_parents(max_parents)
, max_distance(max_distance)
, complexity_penalty(complexity_penalty)
{
M.m = m;
M.P0 = new kmer_matrix(m, max_parents + 1);
M.P1 = new kmer_matrix(m, max_parents + 1);

M.P0->set_all(0.0);
M.P1->set_all(0.0);

M.parents = new bool [m * m];
memset(M.parents, 0, m * m * sizeof(bool));

sprintf(col_base, "\n%30s", "");

n0 = seqs0.size();
n1 = seqs1.size();
n  = n0 + n1;

seqs.insert(seqs.begin(), seqs1.begin(), seqs1.end());
seqs.insert(seqs.begin(), seqs0.begin(), seqs0.end());

reachability = new bool [m * m];

L0 = new double [n * m];
memset(L0, 0, n * m * sizeof(double));

L1 = new double [n * m];
memset(L1, 0, n * m * sizeof(double));

ms0 = new double [n];
memset(ms0, 0, n * sizeof(double));

ms1 = new double [n];
memset(ms1, 0, n * sizeof(double));

ls0_i = new double [n];
ls0_j = new double [n];
ls1_i = new double [n];
ls1_j = new double [n];

ps0_i = new double [M.P0->ncols()];
ps0_j = new double [M.P0->ncols()];
ps1_i = new double [M.P1->ncols()];
ps1_j = new double [M.P1->ncols()];

prior = (double) n1 / (double) (n0 + n1);
};

~motif_trainer()
{
delete [] reachability;

delete [] L0;
delete [] L1;

delete [] ms0;
delete [] ms1;

delete [] ls0_i;
delete [] ls0_j;
delete [] ls1_i;
delete [] ls1_j;

delete [] ps0_i;
delete [] ps0_j;
delete [] ps1_i;
delete [] ps1_j;
}

void train()
{
logger::suspend();

char msg[512];

/* How good the current solution is. */
double ll = conditional_likelihood();
double ic = bic(ll, n, M.num_params(), complexity_penalty);

/* Keep track of the best move found so far. */
move_type move_best = MOVE_NA;
int i_best = -1;
int j_best = -1;
double ic_best;

int i_search, j_search;
double ic_search;

size_t round_num = 0;

while (true) {
++round_num;
compute_reachability();

logger::print(col_base);
snprintf(msg, sizeof(msg), "round %4lu (ic = %0.4e)",
(unsigned long)round_num, ic);
logger::print(msg);
logger::print(col_base);
col = 0;

ic_best = -HUGE_VAL;

if (ic_search > ic_best) {
ic_best = ic_search;
i_best  = i_search;
j_best  = j_search;
}

search_subtractions(i_search, j_search, ic_search);

if (ic_search > ic_best) {
ic_best = ic_search;
i_best  = i_search;
j_best  = j_search;
move_best = MOVE_SUBTRACTION;
}

search_reversals(i_search, j_search, ic_search);

if (ic_search > ic_best) {
ic_best = ic_search;
i_best  = i_search;
j_best  = j_search;
move_best = MOVE_REVERSAL;
}

logger::print("\n");

if (ic_best <= ic) break;

/* make the best move */

logger::print(col_base);
switch (move_best) {
logger::print(" [+] ");
break;

case MOVE_SUBTRACTION:
logger::print(" [-] ");
del_edge(i_best, j_best);
break;

case MOVE_REVERSAL:
logger::print(" [.] ");
del_edge(i_best, j_best);

vecsubcol(ms0, L0, n, m, i_best);
vecsubcol(ms1, L1, n, m, i_best);

update_likelihood_column(i_best);

break;

default:
}

vecsubcol(ms0, L0, n, m, j_best);
vecsubcol(ms1, L1, n, m, j_best);

update_likelihood_column(j_best);

snprintf(msg, sizeof(msg), "%d->%d\n", i_best, j_best);
logger::print(msg);

ic = ic_best;
}

logger::resume();
}

private:

enum move_type {
MOVE_NA,
MOVE_SUBTRACTION,
MOVE_REVERSAL
};

double evaluate_move(int i, int j, move_type move)
{
double ll;
double ic;

/* keep track of old parameters to avoid retraining */
M.P0->get_row(j, ps0_j);
M.P1->get_row(j, ps1_j);
if (move == MOVE_REVERSAL) {
M.P0->get_row(i, ps0_i);
M.P1->get_row(i, ps1_i);
}

/* keep track of old likelihoods to avoid reevaluating */
colcpy(ls0_j, L0, j, n, m);
colcpy(ls1_j, L1, j, n, m);
if (move == MOVE_REVERSAL) {
colcpy(ls0_i, L0, i, n, m);
colcpy(ls1_i, L1, i, n, m);
}

/* make a move */
switch (move) {
logger::print("+");
break;

case MOVE_SUBTRACTION:
logger::print("-");
del_edge(i, j);
break;

case MOVE_REVERSAL:
logger::print(".");
rev_edge(i, j);
break;

default:
}

if (++col > col_max) {
col = 0;
logger::print(col_base);
}

/* evaluate likelihood */
update_likelihood_column(j);
if (move == MOVE_REVERSAL) {
update_likelihood_column(i);
}

/* update training example likelihoods */
vecsub(ms0, ls0_j, n);
vecsub(ms1, ls1_j, n);

if (move == MOVE_REVERSAL) {
vecsub(ms0, ls0_i, n);
vecsub(ms1, ls1_i, n);
}

/* compute likelihood and ic */
ll = conditional_likelihood();
ic = bic(ll, n, M.num_params(), complexity_penalty);

/* undo the move */
switch (move) {
M.set_edge(i, j, false);
break;

case MOVE_SUBTRACTION:
M.set_edge(i, j, true);
break;

case MOVE_REVERSAL:
M.set_edge(j, i, false);
M.set_edge(i, j, true);
break;

default:
}

/* restore previous parameters */
M.P0->set_row(j, ps0_j);
M.P1->set_row(j, ps1_j);
if (move == MOVE_REVERSAL) {
M.P0->set_row(i, ps0_i);
M.P1->set_row(i, ps1_i);
}

/* restore previous likelihoods */
vecsubcol(ms0, L0, n, m, j);
vecsubcol(ms1, L1, n, m, j);
if (move == MOVE_REVERSAL) {
vecsubcol(ms0, L0, n, m, i);
vecsubcol(ms1, L1, n, m, i);
}

matsetcol(L0, ls0_j, n, m, j);
matsetcol(L1, ls1_j, n, m, j);
if (move == MOVE_REVERSAL) {
matsetcol(L0, ls0_i, n, m, i);
matsetcol(L1, ls1_i, n, m, i);
}

return ic;
}

int& i_best,
int& j_best,
double& ic_best)
{
int i, j;

double ic;

i_best = 0;
j_best = 0;
ic_best = -HUGE_VAL;

/* specifies the range of i's that we search through */
int i_start, i_end;

for (j = m - 1; j >= 0; --j) {

/* determine the range of i's that form possible (i, j) edges */
if (M.has_edge(j, j)) {
if (max_distance == 0 || (size_t) j < max_distance) i_start = 0;
else i_start = j - max_distance;

if (max_distance == 0) i_end = m - 1;
else i_end = min(m, j + max_distance);
}
else i_start = i_end = j;

for (i = i_start; i <= i_end; ++i) {
/* skip existing edges */
if (M.has_edge(i, j)) continue;

/* skip edges that would introduce cycles */
if (reachable(j, i)) continue;

/* skip edges that would exceed the parent limit */
if (M.num_parents(j) >= max_parents) continue;

/* skip edges that are equivalent to one already tried:
* (If i and j both have in-degree 0, and i > j, then we already
* tried the edge (j, i).)
*/
if (i > j && M.num_parents(j) == 1 && M.num_parents(i) == 1) continue;

if (ic >= ic_best) {
i_best = i;
j_best = j;
ic_best = ic;
}
}
}
}

void search_subtractions(
int& i_best,
int& j_best,
double& ic_best)
{
int i, j;
double ic;

i_best = 0;
j_best = 0;
ic_best = -HUGE_VAL;

for (j = 0; (size_t) j < m; ++j) {
for (i = 0; (size_t) i < m; ++i) {

/* skip non-existent edges */
if (!M.has_edge(i, j)) continue;

/* skip nodes with in-edges */
if (i == j && M.num_parents(j) > 1) continue;

ic = evaluate_move(i, j, MOVE_SUBTRACTION);

if (ic >= ic_best) {
i_best = i;
j_best = j;
ic_best = ic;
}
}
}
}

void search_reversals(
int& i_best,
int& j_best,
double& ic_best)
{
int i, j;
double ic;
bool has_ij_path;

i_best = 0;
j_best = 0;
ic_best = -HUGE_VAL;

for (j = 0; (size_t) j < m; ++j) {
for (i = 0; (size_t) i < m; ++i) {

/* skip nodes */
if (i == j) continue;

/* skip non-existent edges */
if (!M.has_edge(i, j)) continue;

/* skip reversals that add parameters */
if (!M.has_edge(i, i) || !M.has_edge(j, j)) continue;

/* skip reversals that would introduce cycles
* (this is determined by the cutting the (i, j) edge,
* recomputing reachability and checking if i, j remains) */
M.set_edge(i, j, false);
compute_reachability();

has_ij_path = reachable(i, j);

M.set_edge(i, j, true);
compute_reachability();

if (has_ij_path) continue;

ic = evaluate_move(i, j, MOVE_REVERSAL);

if (ic >= ic_best) {
i_best = i;
j_best = j;
ic_best = ic;
}
}
}
}

void compute_reachability()
{
/* find the transitive closure of the parents matrix via flyod-warshall */

memcpy(reachability, M.parents, m * m * sizeof(bool));

size_t k, i, j;
for (k = 0; k < m; ++k) {
for (i = 0; i < m; ++i) {
for (j = 0; j < m; ++j) {
reachability[j * m + i] =
reachability[j * m + i] ||
(reachability[k * m + i] && reachability[j * m + k]);
}
}
}
}

double conditional_likelihood()
{
double ll = 0.0;
size_t i;

double p = log(prior);
double q = log(1.0 - prior);

/* background sequence likelihood */
for (i = 0; i < n0; ++i) {
ll += (q + ms0[i]) - logaddexp(q + ms0[i], p + ms1[i]);
}

/* foreground sequnece likelihood */
for (i = n0; i < n; ++i) {
ll += (p + ms1[i]) - logaddexp(q + ms0[i], p + ms1[i]);
}

return ll;
}

void update_likelihood_column(int j)
{
size_t i;
kmer K;

// zero columns
for (i = 0; i < n; ++i) {
L0[i * m + j] = 0.0;
L1[i * m + j] = 0.0;
}

// compute likelihood
deque<twobitseq*>::const_iterator seq;
for (i = 0, seq = seqs.begin(); seq != seqs.end(); ++i, ++seq) {
if ((*seq)->make_kmer(K, 0, M.parents + j * m, m) > 0) {
L0[i * m + j] = (*M.P0)(j, K);
L1[i * m + j] = (*M.P1)(j, K);
}
}
}

void calc_row_params(int j)
{
/* initialize */

M.P0->set_row(j, 0.0);
M.P1->set_row(j, 0.0);

size_t np = M.num_parents(j);

/* If this node is now removed, do nothing further. */
if (np == 0) return;

size_t K_max = four_pow(np);
kmer K;

for (K = 0; K < K_max; ++K) {
(*M.P0)(j, K) = pseudocount;
(*M.P1)(j, K) = pseudocount;
}

/* tabulate */

size_t c;
deque<twobitseq*>::const_iterator seq;
for (c = 0, seq = seqs.begin(); seq != seqs.end(); ++c, ++seq) {

if ((*seq)->make_kmer(K, 0, M.parents + j * m, m) == 0) continue;

/* background sequence */
if (c < n0) (*M.P0)(j, K) += 1;

/* foreground sequence */
else (*M.P1)(j, K) += 1;
}

/* normalize */

size_t np_pred = 0; // predecessor parents
int u;
for (u = 0; u < j; ++u) {
if (M.has_edge(u, j)) ++np_pred;
}

M.P0->make_conditional_distribution(j, np_pred, np);
M.P1->make_conditional_distribution(j, np_pred, np);

M.P0->dist_log_transform_row(j, np);
M.P1->dist_log_transform_row(j, np);
}

{
M.set_edge(i, j, true);
calc_row_params(j);
}

void del_edge(int i, int j)
{
M.set_edge(i, j, false);
calc_row_params(j);
}

void rev_edge(int i, int j)
{
del_edge(i, j);
}

bool reachable(int i, int j)
{
return reachability[j * m + i];
}

public:
motif M;

private:
deque<twobitseq*> seqs;

/* for printing pretty output */
int col;
int col_max;
char col_base[80];

/** a n*n 0-1 matrix marking if node i is reachable from node j */
bool* reachability;

/* number of training examples (foreground, background, and total, resp.) */
size_t n0, n1, n;

/* number of positions */
size_t m;

double prior;

size_t max_parents;
size_t max_distance;
double complexity_penalty;

double* L0;
double* L1;

double* ms0;
double* ms1;

double* ls0_i;
double* ls0_j;
double* ls1_i;
double* ls1_j;

double* ps0_i;
double* ps0_j;
double* ps1_i;
double* ps1_j;

/* Initialize outside class definition */
static const double pseudocount;
};
const double motif_trainer::pseudocount = 1.0;

motif::motif()
: m(0)
, P0(NULL)
, P1(NULL)
, parents(NULL)
{

}

motif::motif(const motif& other)
{
m = other.m;
P0 = new kmer_matrix(*other.P0);
P1 = new kmer_matrix(*other.P0);
parents = new bool [m * m];
memcpy(parents, other.parents, m * m * sizeof(bool));
}

motif::motif(const deque<twobitseq*>& seqs0,
const deque<twobitseq*>& seqs1,
size_t m,
size_t max_parents,
size_t max_distance,
double complexity_penalty)
{
/* We are essentially outsourcing construction here. The motif_trainer
* object does all the construction, because it is complicated, and I want
* to encapsulate it a bit. */

motif_trainer trainer(seqs0, seqs1,
m,
max_parents,
max_distance,
complexity_penalty);

trainer.train();

this->m = trainer.M.m;
P0 = new kmer_matrix(*trainer.M.P0);
P1 = new kmer_matrix(*trainer.M.P1);
parents = new bool [m * m];
memcpy(parents, trainer.M.parents, m * m * sizeof(bool));
}

motif::motif(const YAML::Node& node)
{
/* m */
unsigned int m_;
node["m"] >> m_;
m = (size_t) m_;

/* parents */
parents = new bool [m * m];
memset(parents, 0, m * m * sizeof(bool));
const YAML::Node& parents_node = node["parents"];

size_t i;
int b;
for (i = 0; i < m * m; ++i) {
parents_node[i] >> b;
parents[i] = (bool) b;
}

/* P0 */
const YAML::Node& P0_node = node["P0"];
P0 = new kmer_matrix(P0_node);

/* P1 */
const YAML::Node& P1_node = node["P1"];
P1 = new kmer_matrix(P1_node);
}

void motif::to_yaml(YAML::Emitter& out) const
{
out << YAML::BeginMap;

/* m */
out << YAML::Key   << "m";
out << YAML::Value << (unsigned int) m;

/* parents */
out << YAML::Key << "parents";
out << YAML::Value;
out << YAML::Flow << YAML::BeginSeq;

size_t i;
for (i = 0; i < m * m; ++i) out << (parents[i] ? 1 : 0);

out << YAML::EndSeq;

/* background parameters */
out << YAML::Key << "P0";
out << YAML::Value;
P0->to_yaml(out);

/* foreground parameters */
out << YAML::Key << "P1";
out << YAML::Value;
P1->to_yaml(out);

out << YAML::EndMap;
}

string motif::model_graph(int offset) const
{
string graph_str;
char strbuf[512];

graph_str += "digraph {\n";
graph_str += "splines=\"true\";\n";
graph_str += "node [shape=\"box\"];\n";

/* print nodes */
size_t i, j;
for (j = 0; j < m; j++) {
snprintf(strbuf, sizeof(strbuf),
"n%d [label=\"%d\",pos=\"%d,0\",style=\"%s\"];\n",
(int) j, (int) j - offset, (int) j * 100,
parents[j * m + j] ? "solid" : "dotted");

graph_str += strbuf;
}

/* print edges */
for (j = 0; j < m; ++j) {
if (!parents[j * m + j]) continue;

for (i = 0; i < m; ++i) {
if (i == j) continue;
if (parents[j * m + i]) {
snprintf(strbuf, sizeof(strbuf),
"n%lu -> n%lu;\n", (unsigned long) i, (unsigned long) j);
graph_str += strbuf;
}
}
}

graph_str += "}\n";
return graph_str;
}

double motif::eval(const twobitseq& seq, size_t offset)
{
double ll0 = 0.0;
double ll1 = 0.0;
kmer K;

size_t n = P0->nrows();
size_t i;
for (i = 0; i < n; ++i) {
if (seq.make_kmer(K, offset, parents + i * m, m) == 0) continue;

ll0 += (*P0)(i, K);
ll1 += (*P1)(i, K);
}

return exp(ll1 - ll0);
}

void motif::set_edge(int i, int j, bool x)
{
parents[j * m + i] = x;
}

bool motif::has_edge(int i, int j) const
{
return parents[j * m + i];
}

size_t motif::num_parents(size_t j) const
{
size_t i;
size_t N = 0;
for (i = 0; i < m; ++i) if (has_edge(i, j)) ++N;

return N;
}

size_t motif::num_params() const
{
size_t N =0;
int i;
for (i = 0; (size_t) i < m; ++i) {
N += four_pow(num_parents(i)) - 1;
}

N *= 2; /* foreground and background parameters */

return N;
}

motif::~motif()
{
delete P0;
delete P1;
}

```