```/* \$Id: sparse_solve.c,v 1.9 2011/03/07 16:40:18 arif Exp \$ \$Revision: 1.9 \$ */
/* vim:set shiftwidth=4 ts=8: */

/*************************************************************************
* Copyright (c) 2011 AT&T Intellectual Property
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors: See CVS logs. Details at http://www.graphviz.org/
*************************************************************************/

#include <assert.h>
#include <string.h>
#include "sparse_solve.h"
#include "sfdpinternal.h"
#include "memory.h"
#include "logic.h"
#include "math.h"
#include "arith.h"
#include "types.h"
#include "globals.h"

#define DEBUG_PRINT

/* Private payload stored in o->data by Operator_uniform_stress_matmul(). */
struct uniform_stress_matmul_data {
    real alpha;         /* scale applied to the uniform-stress term in the apply callback */
    SparseMatrix A;     /* matrix multiplied first; stored by pointer, never copied or freed here */
};

/* Destroy an Operator created by Operator_uniform_stress_matmul().
 *
 * Frees the private payload (o->data) and then the Operator itself, matching
 * Operator_matmul_delete / Operator_diag_precon_delete, which also free o.
 * The original version freed only o->data and leaked the Operator struct.
 * The SparseMatrix referenced by the payload is not freed here; the
 * constructor stores it by pointer without taking a copy.
 * Passing NULL is a no-op. */
void Operator_uniform_stress_matmul_delete(Operator o){
    if (!o) return;
    FREE(o->data);
    FREE(o);
}

/* Operator callback for the uniform-stress system.
 *
 * Computes y = A*x + alpha*(m*x - (sum_i x_i)*1), i.e. A*x plus alpha times
 * (m*I - 1*1^T)*x, where m = A->m and A, alpha come from the payload in o->data.
 * y is handed to SparseMatrix_multiply_vector by address, so the buffer that
 * callback leaves in y is the one updated and returned.
 * The sum of x is taken after the matrix-vector product, as in the original. */
real *Operator_uniform_stress_matmul_apply(Operator o, real *x, real *y){
    struct uniform_stress_matmul_data *payload;
    SparseMatrix mat;
    real scale, total;
    int len, k;

    payload = (struct uniform_stress_matmul_data *) o->data;
    mat = payload->A;
    scale = payload->alpha;
    len = mat->m;

    SparseMatrix_multiply_vector(mat, x, &y, FALSE);

    /* uniform-stress term: scale * (len*x[k] - sum(x)) */
    total = 0.;
    for (k = 0; k < len; k++) {
        total += x[k];
    }
    for (k = 0; k < len; k++) {
        y[k] += scale*(len*x[k] - total);
    }

    return y;
}

/* Build an Operator whose apply callback is
 * Operator_uniform_stress_matmul_apply, with a heap-allocated payload holding
 * alpha and a pointer to A (A itself is not copied). */
Operator Operator_uniform_stress_matmul(SparseMatrix A, real alpha){
    Operator op = MALLOC(sizeof(struct Operator_struct));
    struct uniform_stress_matmul_data *payload = MALLOC(sizeof(struct uniform_stress_matmul_data));

    payload->alpha = alpha;
    payload->A = A;

    op->data = payload;
    op->Operator_apply = Operator_uniform_stress_matmul_apply;
    return op;
}

/* Operator callback: y = A*x, where A is the SparseMatrix stored directly in
 * o->data. y is passed by address to SparseMatrix_multiply_vector, and the
 * pointer left in y afterwards is returned. */
real *Operator_matmul_apply(Operator o, real *x, real *y){
    SparseMatrix_multiply_vector((SparseMatrix) o->data, x, &y, FALSE);
    return y;
}

/* Wrap A in an Operator whose apply callback performs y = A*x.
 * A is stored as o->data by pointer; it is not copied. */
Operator Operator_matmul_new(SparseMatrix A){
    Operator op = N_GNEW(1, struct Operator_struct);

    op->Operator_apply = Operator_matmul_apply;
    op->data = (void *) A;
    return op;
}

/* Free an Operator from Operator_matmul_new. Only the Operator struct is
 * released; the SparseMatrix held in o->data is left alone. NULL is ignored. */
void Operator_matmul_delete(Operator o){
    if (!o) return;
    FREE(o);
}

/* Diagonal preconditioner callback: y[k] = x[k] * d[k].
 *
 * Layout of o->data: element 0 holds the vector length (stored as a real),
 * elements 1..len hold the diagonal scale factors d. Returns y. */
real* Operator_diag_precon_apply(Operator o, real *x, real *y){
    real *store = (real *) o->data;
    int len = (int) store[0];
    real *d = store + 1;
    int k;

    for (k = 0; k < len; k++) {
        y[k] = x[k]*d[k];
    }
    return y;
}

/* Build a diagonal preconditioner for the uniform-stress operator A + alpha*(m*I - 1*1^T).
 *
 * For each row, a nonzero diagonal entry a_ii gives the factor
 * 1/((m-1)*alpha + a_ii). If a row has several stored diagonal entries, the
 * last nonzero one wins. Rows without one fall back to 1/(m-1).
 * The result uses the length-prefixed layout read by Operator_diag_precon_apply.
 *
 * NOTE(review): the fallback omits alpha, and m == 1 divides by zero. Both
 * behaviors are kept as-is; confirm whether they are intended. */
Operator Operator_uniform_stress_diag_precon_new(SparseMatrix A, real alpha){
    int m = A->m;
    int *rowptr = A->ia, *colidx = A->ja;
    real *vals = (real *) A->a;
    Operator op;
    real *inv;
    int row, pos;

    assert(A->type == MATRIX_TYPE_REAL);
    assert(vals);

    op = MALLOC(sizeof(struct Operator_struct));
    op->data = MALLOC(sizeof(real)*(m + 1));
    inv = (real *) op->data;

    inv[0] = m;     /* length prefix consumed by Operator_diag_precon_apply */
    inv += 1;

    for (row = 0; row < m; row++) {
        real factor = 1./(m-1);
        for (pos = rowptr[row]; pos < rowptr[row+1]; pos++) {
            if (colidx[pos] == row && ABS(vals[pos]) > 0) {
                factor = 1./((m-1)*alpha+vals[pos]);
            }
        }
        inv[row] = factor;
    }

    op->Operator_apply = Operator_diag_precon_apply;
    return op;
}

/* Build a diagonal preconditioner from the diagonal of the real matrix A.
 *
 * Row i gets the factor 1/a_ii when a nonzero diagonal entry is stored. If a
 * row has several, the last nonzero one wins. Otherwise the factor is 1.0.
 * The factors use the length-prefixed layout read by Operator_diag_precon_apply. */
Operator Operator_diag_precon_new(SparseMatrix A){
    int m = A->m;
    int *rowptr = A->ia, *colidx = A->ja;
    real *vals = (real *) A->a;
    Operator op;
    real *inv;
    int row, pos;

    assert(A->type == MATRIX_TYPE_REAL);
    assert(vals);

    op = N_GNEW(1, struct Operator_struct);
    op->data = N_GNEW((m + 1), real);
    inv = (real *) op->data;

    inv[0] = m;     /* length prefix consumed by Operator_diag_precon_apply */
    inv += 1;

    for (row = 0; row < m; row++) {
        real factor = 1.;
        for (pos = rowptr[row]; pos < rowptr[row+1]; pos++) {
            if (colidx[pos] == row && ABS(vals[pos]) > 0) {
                factor = 1./vals[pos];
            }
        }
        inv[row] = factor;
    }

    op->Operator_apply = Operator_diag_precon_apply;
    return op;
}

/* Free a diagonal-preconditioner Operator: first the diagonal array held in
 * o->data, then the Operator itself.
 *
 * The original checked `if (o)` only after dereferencing o->data, so that
 * check never helped and a NULL argument crashed. The NULL test now comes
 * first, and passing NULL is a no-op. */
void Operator_diag_precon_delete(Operator o){
    if (!o) return;
    FREE(o->data);
    FREE(o);
}

/* Preconditioned conjugate gradient for a single right-hand side.
 *
 * Iteratively solves A*x = rhs. A and the preconditioner are used only
 * through their Operator_apply callbacks.
 *
 * x      on entry, the initial guess. The iterate is written via
 *        vector_saxpy2; presumably in place, since cg() reads the solution
 *        back from the same buffer (confirm vector_saxpy2 semantics).
 * rhs    right-hand side of length n; only read here.
 * tol    relative tolerance. Iteration stops once res <= tol*res0, where
 *        res = sqrt(r.r)/n (assuming vector_product is a dot product).
 * maxit  maximum number of iterations.
 * flag   accepted but never read or written by this function.
 *
 * Returns the last value of res. It comes from the residual recurrence
 * below and is not recomputed from rhs - A*x.
 * NOTE(review): there is no breakdown guard when p.q == 0 or rho_old == 0.
 */
static real conjugate_gradient(Operator A, Operator precon, int n, real *x, real *rhs, real tol, int maxit, int *flag){
/* The initial values of res and rho are overwritten before they are read. */
real *z, *r, *p, *q, res = 10*tol, alpha;
real rho = 1.0e20, rho_old = 1, res0, beta;
real* (*Ax)(Operator o, real *in, real *out) = A->Operator_apply;
real* (*Minvx)(Operator o, real *in, real *out) = precon->Operator_apply;
int iter = 0;

/* work vectors: z = preconditioned residual, r = residual,
 * p = search direction, q = A*p */
z = N_GNEW(n,real);
r = N_GNEW(n,real);
p = N_GNEW(n,real);
q = N_GNEW(n,real);

/* initial residual; presumably r = rhs - A*x (check vector_subtract_to) */
r = Ax(A, x, r);
r = vector_subtract_to(n, rhs, r);

res0 = res = sqrt(vector_product(n, r, r))/n;
/* "&& 0" disables these diagnostics even when Verbose is set */
#ifdef DEBUG_PRINT
if (Verbose && 0){
fprintf(stderr, "   cg iter = %d, residual = %g\n", iter, res);
}
#endif

/* iter is post-incremented in the test, so after the loop
 * iter - 1 equals the number of completed iterations. */
while ((iter++) < maxit && res > tol*res0){
z = Minvx(precon, r, z);
rho = vector_product(n, r, z);

/* first pass: p = z; later passes presumably p = z + beta*p
 * (confirm vector_saxpy argument order) */
if (iter > 1){
beta = rho/rho_old;
p = vector_saxpy(n, z, p, beta);
} else {
MEMCPY(p, z, sizeof(real)*n);
}

q = Ax(A, p, q);

alpha = rho/vector_product(n, p, q);

/* presumably x += alpha*p and r -= alpha*q (confirm vector_saxpy2) */
x = vector_saxpy2(n, x, p, alpha);
r = vector_saxpy2(n, r, q, -alpha);

res = sqrt(vector_product(n, r, r))/n;

#ifdef DEBUG_PRINT
if (Verbose && 0){
fprintf(stderr, "   cg iter = %d, residual = %g\n", iter, res);
}
#endif

rho_old = rho;
}
FREE(z); FREE(r); FREE(p); FREE(q);
#ifdef DEBUG
_statistics[0] += iter - 1;
#endif

#ifdef DEBUG_PRINT
if (Verbose && 0){
fprintf(stderr, "   cg iter = %d, residual = %g\n", iter, res);
}
#endif
return res;
}

/* Solve Ax * X = B for dim independent right-hand sides, one column at a time.
 *
 * x0 and rhs are n-by-dim arrays with entry (i, k) at index i*dim + k.
 * Column k of x0 is the initial guess and column k of rhs is the right-hand
 * side for that column's conjugate_gradient() run.
 *
 * NOTE: the solution is written back into rhs, overwriting it. x0 is only read.
 * flag is passed straight through to conjugate_gradient().
 * Returns the sum of the per-column residuals reported by conjugate_gradient(). */
real cg(Operator Ax, Operator precond, int n, int dim, real *x0, real *rhs, real tol, int maxit, int *flag){
    real *col_x = N_GNEW(n, real);
    real *col_b = N_GNEW(n, real);
    real total = 0;
    int c, row;

    for (c = 0; c < dim; c++) {
        /* gather column c of the initial guess and of the right-hand side */
        for (row = 0; row < n; row++) {
            col_x[row] = x0[row*dim + c];
            col_b[row] = rhs[row*dim + c];
        }

        total += conjugate_gradient(Ax, precond, n, col_x, col_b, tol, maxit, flag);

        /* scatter the solved column back into rhs */
        for (row = 0; row < n; row++) {
            rhs[row*dim + c] = col_x[row];
        }
    }

    FREE(col_x);
    FREE(col_b);
    return total;
}

/* Solve A * X = B for dim right-hand sides stored row-major in rhs.
 *
 * Only SOLVE_METHOD_CG is supported. It runs cg() with a plain matrix-vector
 * Operator and a diagonal preconditioner built from A. Any other method fails
 * assert(0); with NDEBUG it returns 0.
 * *flag is reset to 0 on entry and then passed through to the solver.
 * As with cg(), the solution overwrites rhs. x0 supplies the initial guess. */
real SparseMatrix_solve(SparseMatrix A, int dim, real *x0, real *rhs, real tol, int maxit, int method, int *flag){
    real res = 0;

    *flag = 0;

    if (method == SOLVE_METHOD_CG) {
        Operator matvec = Operator_matmul_new(A);
        Operator diag_precond = Operator_diag_precon_new(A);

        res = cg(matvec, diag_precond, A->m, dim, x0, rhs, tol, maxit, flag);

        Operator_matmul_delete(matvec);
        Operator_diag_precon_delete(diag_precond);
    } else {
        assert(0);
    }

    return res;
}

```