Merge branch 'releases/Release_2_11_4_Branch'
[jalview.git] / src / jalview / analysis / ccAnalysis.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21
22 /*
23 * Copyright 2018-2022 Kathy Su, Kay Diederichs
24
25 * This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
26
27 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
28
29 * You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. 
30 */
31
32 /**
33 * Ported from https://doi.org/10.1107/S2059798317000699 by
34 * @AUTHOR MorellThomas
35 */
36
37 package jalview.analysis;
38
39 import jalview.bin.Console;
40 import jalview.math.MatrixI;
41 import jalview.math.Matrix;
42 import jalview.math.MiscMath;
43
44 import java.lang.Math;
45 import java.lang.System;
46 import java.util.Arrays;
47 import java.util.ArrayList;
48 import java.util.Comparator;
49 import java.util.Map.Entry;
50 import java.util.TreeMap;
51
52 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
53 import org.apache.commons.math3.linear.SingularValueDecomposition;
54
55 /**
56  * A class to model rectangular matrices of double values and operations on them
57  */
58 public class ccAnalysis
59 {
60   private byte dim = 0; // dimensions
61
62   private MatrixI scoresOld; // input scores
63
64   public ccAnalysis(MatrixI scores, byte dim)
65   {
66     // round matrix to .4f to be same as in pasimap
67     for (int i = 0; i < scores.height(); i++)
68     {
69       for (int j = 0; j < scores.width(); j++)
70       {
71         if (!Double.isNaN(scores.getValue(i, j)))
72         {
73           scores.setValue(i, j,
74                   (double) Math.round(scores.getValue(i, j) * (int) 10000)
75                           / 10000);
76         }
77       }
78     }
79     this.scoresOld = scores;
80     this.dim = dim;
81   }
82
83   /**
84    * Initialise a distrust-score for each hypothesis (h) of hSigns distrust =
85    * conHypNum - proHypNum
86    *
87    * @param hSigns
88    *          ~ hypothesis signs (+/-) for each sequence
89    * @param scores
90    *          ~ input score matrix
91    *
92    * @return distrustScores
93    */
94   private int[] initialiseDistrusts(byte[] hSigns, MatrixI scores)
95   {
96     int[] distrustScores = new int[scores.width()];
97
98     // loop over symmetric matrix
99     for (int i = 0; i < scores.width(); i++)
100     {
101       byte hASign = hSigns[i];
102       int conHypNum = 0;
103       int proHypNum = 0;
104
105       for (int j = 0; j < scores.width(); j++)
106       {
107         double cell = scores.getRow(i)[j]; // value at [i][j] in scores
108         byte hBSign = hSigns[j];
109         if (!Double.isNaN(cell))
110         {
111           byte cellSign = (byte) Math.signum(cell); // check if sign of matrix
112                                                     // value fits hyptohesis
113           if (cellSign == hASign * hBSign)
114           {
115             proHypNum++;
116           }
117           else
118           {
119             conHypNum++;
120           }
121         }
122       }
123       distrustScores[i] = conHypNum - proHypNum; // create distrust score for
124                                                  // each sequence
125     }
126     return distrustScores;
127   }
128
129   /**
130    * Optimise hypothesis concerning the sign of the hypothetical value for each
131    * hSigns by interpreting the pairwise correlation coefficients as scalar
132    * products
133    *
134    * @param hSigns
135    *          ~ hypothesis signs (+/-)
136    * @param distrustScores
137    * @param scores
138    *          ~ input score matrix
139    *
140    * @return hSigns
141    */
142   private byte[] optimiseHypothesis(byte[] hSigns, int[] distrustScores,
143           MatrixI scores)
144   {
145     // get maximum distrust score
146     int[] maxes = MiscMath.findMax(distrustScores);
147     int maxDistrustIndex = maxes[0];
148     int maxDistrust = maxes[1];
149
150     // if hypothesis is not optimal yet
151     if (maxDistrust > 0)
152     {
153       // toggle sign for hI with maximum distrust
154       hSigns[maxDistrustIndex] *= -1;
155       // update distrust at same position
156       distrustScores[maxDistrustIndex] *= -1;
157
158       // also update distrust scores for all hI that were not changed
159       byte hASign = hSigns[maxDistrustIndex];
160       for (int NOTmaxDistrustIndex = 0; NOTmaxDistrustIndex < distrustScores.length; NOTmaxDistrustIndex++)
161       {
162         if (NOTmaxDistrustIndex != maxDistrustIndex)
163         {
164           byte hBSign = hSigns[NOTmaxDistrustIndex];
165           double cell = scores.getValue(maxDistrustIndex,
166                   NOTmaxDistrustIndex);
167
168           // distrust only changed if not NaN
169           if (!Double.isNaN(cell))
170           {
171             byte cellSign = (byte) Math.signum(cell);
172             // if sign of cell matches hypothesis decrease distrust by 2 because
173             // 1 more value supporting and 1 less contradicting
174             // else increase by 2
175             if (cellSign == hASign * hBSign)
176             {
177               distrustScores[NOTmaxDistrustIndex] -= 2;
178             }
179             else
180             {
181               distrustScores[NOTmaxDistrustIndex] += 2;
182             }
183           }
184         }
185       }
186       // further optimisation necessary
187       return optimiseHypothesis(hSigns, distrustScores, scores);
188
189     }
190     else
191     {
192       return hSigns;
193     }
194   }
195
196   /**
197    * takes the a symmetric MatrixI as input scores which may contain Double.NaN
198    * approximate the missing values using hypothesis optimisation
199    *
200    * runs analysis
201    *
202    * @param scores
203    *          ~ score matrix
204    *
205    * @return
206    */
207   public MatrixI run() throws Exception
208   {
209     // initialse eigenMatrix and repMatrix
210     MatrixI eigenMatrix = scoresOld.copy();
211     MatrixI repMatrix = scoresOld.copy();
212     try
213     {
214       /*
215       * Calculate correction factor for 2nd and higher eigenvalue(s).
216       * This correction is NOT needed for the 1st eigenvalue, because the
217       * unknown (=NaN) values of the matrix are approximated by presuming
218       * 1-dimensional vectors as the basis of the matrix interpretation as dot
219       * products.
220       */
221
222       System.out.println("Input correlation matrix:");
223       eigenMatrix.print(System.out, "%1.4f ");
224
225       int matrixWidth = eigenMatrix.width(); // square matrix, so width ==
226                                              // height
227       int matrixElementsTotal = (int) Math.pow(matrixWidth, 2); // total number
228                                                                 // of elemts
229
230       float correctionFactor = (float) (matrixElementsTotal
231               - eigenMatrix.countNaN()) / (float) matrixElementsTotal;
232
233       /*
234       * Calculate hypothetical value (1-dimensional vector) h_i for each
235       * dataset by interpreting the given correlation coefficients as scalar
236       * products.
237       */
238
239       /*
240       * Memory for current hypothesis concerning sign of each h_i.
241       * List of signs for all h_i in the encoding:
242       * *  1: positive
243       * *  0: zero
244       * * -1: negative
245       * Initial hypothesis: all signs are positive.
246       */
247       byte[] hSigns = new byte[matrixWidth];
248       Arrays.fill(hSigns, (byte) 1);
249
250       // Estimate signs for each h_i by refining hypothesis on signs.
251       hSigns = optimiseHypothesis(hSigns,
252               initialiseDistrusts(hSigns, eigenMatrix), eigenMatrix);
253
254       // Estimate absolute values for each h_i by determining sqrt of mean of
255       // non-NaN absolute values for every row.
256       double[] hAbs = MiscMath.sqrt(eigenMatrix.absolute().meanRow());
257
258       // Combine estimated signs with absolute values in obtain total value for
259       // each h_i.
260       double[] hValues = MiscMath.elementwiseMultiply(hSigns, hAbs);
261
262       /*Complement symmetric matrix by using the scalar products of estimated
263       *values of h_i to replace NaN-cells.
264       *Matrix positions that have estimated values
265       *(only for diagonal and upper off-diagonal values, due to the symmetry
266       *the positions of the lower-diagonal values can be inferred).
267       List of tuples (row_idx, column_idx).*/
268
269       ArrayList<int[]> estimatedPositions = new ArrayList<int[]>();
270
271       // for off-diagonal cells
272       for (int rowIndex = 0; rowIndex < matrixWidth - 1; rowIndex++)
273       {
274         for (int columnIndex = rowIndex
275                 + 1; columnIndex < matrixWidth; columnIndex++)
276         {
277           double cell = eigenMatrix.getValue(rowIndex, columnIndex);
278           if (Double.isNaN(cell))
279           {
280             // calculate scalar product as new cell value
281             cell = hValues[rowIndex] * hValues[columnIndex];
282             // fill in new value in cell and symmetric partner
283             eigenMatrix.setValue(rowIndex, columnIndex, cell);
284             eigenMatrix.setValue(columnIndex, rowIndex, cell);
285             // save positions of estimated values
286             estimatedPositions.add(new int[] { rowIndex, columnIndex });
287           }
288         }
289       }
290
291       // for diagonal cells
292       for (int diagonalIndex = 0; diagonalIndex < matrixWidth; diagonalIndex++)
293       {
294         double cell = Math.pow(hValues[diagonalIndex], 2);
295         eigenMatrix.setValue(diagonalIndex, diagonalIndex, cell);
296         estimatedPositions.add(new int[] { diagonalIndex, diagonalIndex });
297       }
298
299       /*Refine total values of each h_i:
300       *Initialise h_values of the hypothetical non-existant previous iteration
301       *with the correct format but with impossible values.
302        Needed for exit condition of otherwise endless loop.*/
303       System.out.print("initial values: [ ");
304       for (double h : hValues)
305       {
306         System.out.print(String.format("%1.4f, ", h));
307       }
308       System.out.println(" ]");
309
310       double[] hValuesOld = new double[matrixWidth];
311
312       int iterationCount = 0;
313
314       // FIXME JAL-4443 - spliterators could be coded out or patched with j2s
315       // annotation
316       // repeat unitl values of h do not significantly change anymore
317       while (true)
318       {
319         for (int hIndex = 0; hIndex < matrixWidth; hIndex++)
320         {
321           double newH = Arrays
322                   .stream(MiscMath.elementwiseMultiply(hValues,
323                           eigenMatrix.getRow(hIndex)))
324                   .sum()
325                   / Arrays.stream(
326                           MiscMath.elementwiseMultiply(hValues, hValues))
327                           .sum();
328           hValues[hIndex] = newH;
329         }
330
331         System.out.print(String.format("iteration %d: [ ", iterationCount));
332         for (double h : hValues)
333         {
334           System.out.print(String.format("%1.4f, ", h));
335         }
336         System.out.println(" ]");
337
338         // update values of estimated positions
339         for (int[] pair : estimatedPositions) // pair ~ row, col
340         {
341           double newVal = hValues[pair[0]] * hValues[pair[1]];
342           eigenMatrix.setValue(pair[0], pair[1], newVal);
343           eigenMatrix.setValue(pair[1], pair[0], newVal);
344         }
345
346         iterationCount++;
347
348         // exit loop as soon as new values are similar to the last iteration
349         if (MiscMath.allClose(hValues, hValuesOld, 0d, 1e-05d, false))
350         {
351           break;
352         }
353
354         // save hValues for comparison in the next iteration
355         System.arraycopy(hValues, 0, hValuesOld, 0, hValues.length);
356       }
357
358       // -----------------------------
359       // Use complemented symmetric matrix to calculate final representative
360       // vectors.
361
362       // Eigendecomposition.
363       eigenMatrix.tred();
364       eigenMatrix.tqli();
365
366       System.out.println("eigenmatrix");
367       eigenMatrix.print(System.out, "%8.2f");
368       System.out.println();
369       System.out.println("uncorrected eigenvalues");
370       eigenMatrix.printD(System.out, "%2.4f ");
371       System.out.println();
372
373       double[] eigenVals = eigenMatrix.getD();
374
375       TreeMap<Double, Integer> eigenPairs = new TreeMap<>(
376               Comparator.reverseOrder());
377       for (int i = 0; i < eigenVals.length; i++)
378       {
379         eigenPairs.put(eigenVals[i], i);
380       }
381
382       // matrix of representative eigenvectors (each row is a vector)
383       double[][] _repMatrix = new double[eigenVals.length][dim];
384       double[][] _oldMatrix = new double[eigenVals.length][dim];
385       double[] correctedEigenValues = new double[dim];
386
387       int l = 0;
388       for (Entry<Double, Integer> pair : eigenPairs.entrySet())
389       {
390         double eigenValue = pair.getKey();
391         int column = pair.getValue();
392         double[] eigenVector = eigenMatrix.getColumn(column);
393         // for 2nd and higher eigenvalues
394         if (l >= 1)
395         {
396           eigenValue /= correctionFactor;
397         }
398         correctedEigenValues[l] = eigenValue;
399         for (int j = 0; j < eigenVector.length; j++)
400         {
401           _repMatrix[j][l] = (eigenValue < 0) ? 0.0
402                   : -Math.sqrt(eigenValue) * eigenVector[j];
403           double tmpOldScore = scoresOld.getColumn(column)[j];
404           _oldMatrix[j][dim - l - 1] = (Double.isNaN(tmpOldScore)) ? 0.0
405                   : tmpOldScore;
406         }
407         l++;
408         if (l >= dim)
409         {
410           break;
411         }
412       }
413
414       System.out.println("correctedEigenValues");
415       MiscMath.print(correctedEigenValues, "%2.4f ");
416
417       repMatrix = new Matrix(_repMatrix);
418       repMatrix.setD(correctedEigenValues);
419       MatrixI oldMatrix = new Matrix(_oldMatrix);
420
421       MatrixI dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
422
423       double rmsd = scoresOld.rmsd(dotMatrix);
424
425       System.out.println("iteration, rmsd, maxDiff, rmsdDiff");
426       System.out.println(String.format("0, %8.5f, -, -", rmsd));
427       // Refine representative vectors by minimising sum-of-squared deviates
428       // between dotMatrix and original score matrix
429       for (int iteration = 1; iteration < 21; iteration++) // arbitrarily set to
430                                                            // 20
431       {
432         MatrixI repMatrixOLD = repMatrix.copy();
433         MatrixI dotMatrixOLD = dotMatrix.copy();
434
435         // for all rows/hA in the original matrix
436         for (int hAIndex = 0; hAIndex < oldMatrix.height(); hAIndex++)
437         {
438           double[] row = oldMatrix.getRow(hAIndex);
439           double[] hA = repMatrix.getRow(hAIndex);
440           hAIndex = hAIndex;
441           // find least-squares-solution fo rdifferences between original scores
442           // and representative vectors
443           double[] hAlsm = leastSquaresOptimisation(repMatrix, scoresOld,
444                   hAIndex);
445           // update repMatrix with new hAlsm
446           for (int j = 0; j < repMatrix.width(); j++)
447           {
448             repMatrix.setValue(hAIndex, j, hAlsm[j]);
449           }
450         }
451
452         // dot product of representative vecotrs yields a matrix with values
453         // approximating the correlation matrix
454         dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
455         // calculate rmsd between approximation and correlation matrix
456         rmsd = scoresOld.rmsd(dotMatrix);
457
458         // calculate maximum change of representative vectors of current
459         // iteration
460         MatrixI diff = repMatrix.subtract(repMatrixOLD).absolute();
461         double maxDiff = 0.0;
462         for (int i = 0; i < diff.height(); i++)
463         {
464           for (int j = 0; j < diff.width(); j++)
465           {
466             maxDiff = (diff.getValue(i, j) > maxDiff) ? diff.getValue(i, j)
467                     : maxDiff;
468           }
469         }
470
471         // calculate rmsd between current and previous estimation
472         double rmsdDiff = dotMatrix.rmsd(dotMatrixOLD);
473
474         System.out.println(String.format("%d, %8.5f, %8.5f, %8.5f",
475                 iteration, rmsd, maxDiff, rmsdDiff));
476
477         if (!(Math.abs(maxDiff) > 1e-06))
478         {
479           repMatrix = repMatrixOLD.copy();
480           break;
481         }
482       }
483
484     } catch (Exception q)
485     {
486       Console.error("Error computing cc_analysis:  " + q.getMessage());
487       q.printStackTrace();
488     }
489     System.out.println("final coordinates:");
490     repMatrix.print(System.out, "%1.8f ");
491     return repMatrix;
492   }
493
494   /**
495    * Create equations system using information on originally known pairwise
496    * correlation coefficients (parsed from infile) and the representative result
497    * vectors
498    *
499    * Each equation has the format:
500    * 
501    * hA * hA - pairwiseCC = 0
502    * 
503    * with:
504    * 
505    * hA: unknown variable
506    * 
507    * hB: known representative vector
508    * 
509    * pairwiseCC: known pairwise correlation coefficien
510    * 
511    * The resulting equations system is overdetermined, if there are more
512    * equations than unknown elements
513    *
514    * x is the user input. Remaining parameters are needed for generating
515    * equations system, NOT to be specified by user).
516    * 
517    * @param x
518    *          ~ unknown n-dimensional column-vector
519    * @param hAIndex
520    *          ~ index of currently optimised representative result vector.
521    * @param h
522    *          ~ matrix with row-wise listing of representative result vectors.
523    * @param originalRow
524    *          ~ matrix-row of originally parsed pairwise correlation
525    *          coefficients.
526    *
527    * @return
528    */
529   private double[] originalToEquasionSystem(double[] hA, MatrixI repMatrix,
530           MatrixI scoresOld, int hAIndex)
531   {
532     double[] originalRow = scoresOld.getRow(hAIndex);
533     int nans = MiscMath.countNaN(originalRow);
534     double[] result = new double[originalRow.length - nans];
535
536     // for all pairwiseCC in originalRow
537     int resultIndex = 0;
538     for (int hBIndex = 0; hBIndex < originalRow.length; hBIndex++)
539     {
540       double pairwiseCC = originalRow[hBIndex];
541       // if not NaN -> create new equation and add it to the system
542       if (!Double.isNaN(pairwiseCC))
543       {
544         double[] hB = repMatrix.getRow(hBIndex);
545         result[resultIndex++] = MiscMath
546                 .sum(MiscMath.elementwiseMultiply(hA, hB)) - pairwiseCC;
547       }
548       else
549       {
550       }
551     }
552     return result;
553   }
554
555   /**
556    * returns the jacobian matrix
557    * 
558    * @param repMatrix
559    *          ~ matrix of representative vectors
560    * @param hAIndex
561    *          ~ current row index
562    *
563    * @return
564    */
565   private MatrixI approximateDerivative(MatrixI repMatrix,
566           MatrixI scoresOld, int hAIndex)
567   {
568     // hA = x0
569     double[] hA = repMatrix.getRow(hAIndex);
570     double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld,
571             hAIndex);
572     double[] signX0 = new double[hA.length];
573     double[] xAbs = new double[hA.length];
574     for (int i = 0; i < hA.length; i++)
575     {
576       signX0[i] = (hA[i] >= 0) ? 1 : -1;
577       xAbs[i] = (Math.abs(hA[i]) >= 1.0) ? Math.abs(hA[i]) : 1.0;
578     }
579     double rstep = Math.pow(Math.ulp(1.0), 0.5);
580
581     double[] h = new double[hA.length];
582     for (int i = 0; i < hA.length; i++)
583     {
584       h[i] = rstep * signX0[i] * xAbs[i];
585     }
586
587     int m = f0.length;
588     int n = hA.length;
589     double[][] jTransposed = new double[n][m];
590     for (int i = 0; i < h.length; i++)
591     {
592       double[] x = new double[h.length];
593       System.arraycopy(hA, 0, x, 0, h.length);
594       x[i] += h[i];
595       double dx = x[i] - hA[i];
596       double[] df = originalToEquasionSystem(x, repMatrix, scoresOld,
597               hAIndex);
598       for (int j = 0; j < df.length; j++)
599       {
600         df[j] -= f0[j];
601         jTransposed[i][j] = df[j] / dx;
602       }
603     }
604     MatrixI J = new Matrix(jTransposed).transpose();
605     return J;
606   }
607
608   /**
609    * norm of regularized (by alpha) least-squares solution minus Delta
610    * 
611    * @param alpha
612    * @param suf
613    * @param s
614    * @param Delta
615    *
616    * @return
617    */
618   private double[] phiAndDerivative(double alpha, double[] suf, double[] s,
619           double Delta)
620   {
621     double[] denom = MiscMath
622             .elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha);
623     double pNorm = MiscMath.norm(MiscMath.elementwiseDivide(suf, denom));
624     double phi = pNorm - Delta;
625     // - sum ( suf**2 / denom**3) / pNorm
626     double phiPrime = -MiscMath.sum(MiscMath.elementwiseDivide(
627             MiscMath.elementwiseMultiply(suf, suf),
628             MiscMath.elementwiseMultiply(
629                     MiscMath.elementwiseMultiply(denom, denom), denom)))
630             / pNorm;
631     return new double[] { phi, phiPrime };
632   }
633
634   /**
635    * class holding the result of solveLsqTrustRegion
636    */
637   private class TrustRegion
638   {
639     private double[] step;
640
641     private double alpha;
642
643     private int iteration;
644
645     public TrustRegion(double[] step, double alpha, int iteration)
646     {
647       this.step = step;
648       this.alpha = alpha;
649       this.iteration = iteration;
650     }
651
652     public double[] getStep()
653     {
654       return this.step;
655     }
656
657     public double getAlpha()
658     {
659       return this.alpha;
660     }
661
662     public int getIteration()
663     {
664       return this.iteration;
665     }
666   }
667
668   /**
669    * solve a trust-region problem arising in least-squares optimisation
670    * 
671    * @param n
672    *          ~ number of variables
673    * @param m
674    *          ~ number of residuals
675    * @param uf
676    * @param s
677    *          ~ singular values of J
678    * @param V
679    *          ~ transpose of VT
680    * @param Delta
681    *          ~ radius of a trust region
682    * @param alpha
683    *          ~ initial guess for alpha
684    *
685    * @return
686    */
687   private TrustRegion solveLsqTrustRegion(int n, int m, double[] uf,
688           double[] s, MatrixI V, double Delta, double alpha)
689   {
690     double[] suf = MiscMath.elementwiseMultiply(s, uf);
691
692     // check if J has full rank and tr Gauss-Newton step
693     boolean fullRank = false;
694     if (m >= n)
695     {
696       double threshold = s[0] * Math.ulp(1.0) * m;
697       fullRank = s[s.length - 1] > threshold;
698     }
699     if (fullRank)
700     {
701       double[] p = MiscMath.elementwiseMultiply(
702               V.sumProduct(MiscMath.elementwiseDivide(uf, s)), -1);
703       if (MiscMath.norm(p) <= Delta)
704       {
705         TrustRegion result = new TrustRegion(p, 0.0, 0);
706         return result;
707       }
708     }
709
710     double alphaUpper = MiscMath.norm(suf) / Delta;
711     double alphaLower = 0.0;
712     if (fullRank)
713     {
714       double[] phiAndPrime = phiAndDerivative(0.0, suf, s, Delta);
715       alphaLower = -phiAndPrime[0] / phiAndPrime[1];
716     }
717
718     alpha = (!fullRank && alpha == 0.0)
719             ? alpha = Math.max(0.001 * alphaUpper,
720                     Math.pow(alphaLower * alphaUpper, 0.5))
721             : alpha;
722
723     int iteration = 0;
724     while (iteration < 10) // 10 is default max_iter
725     {
726       alpha = (alpha < alphaLower || alpha > alphaUpper)
727               ? alpha = Math.max(0.001 * alphaUpper,
728                       Math.pow(alphaLower * alphaUpper, 0.5))
729               : alpha;
730       double[] phiAndPrime = phiAndDerivative(alpha, suf, s, Delta);
731       double phi = phiAndPrime[0];
732       double phiPrime = phiAndPrime[1];
733
734       alphaUpper = (phi < 0) ? alpha : alphaUpper;
735       double ratio = phi / phiPrime;
736       alphaLower = Math.max(alphaLower, alpha - ratio);
737       alpha -= (phi + Delta) * ratio / Delta;
738
739       if (Math.abs(phi) < 0.01 * Delta) // default rtol set to 0.01
740       {
741         break;
742       }
743       iteration++;
744     }
745
746     // p = - V.dot( suf / (s**2 + alpha))
747     double[] tmp = MiscMath.elementwiseDivide(suf, MiscMath
748             .elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha));
749     double[] p = MiscMath.elementwiseMultiply(V.sumProduct(tmp), -1);
750
751     // Make the norm of p equal to Delta, p is changed only slightly during
752     // this.
753     // It is done to prevent p lie outside of the trust region
754     p = MiscMath.elementwiseMultiply(p, Delta / MiscMath.norm(p));
755
756     TrustRegion result = new TrustRegion(p, alpha, iteration + 1);
757     return result;
758   }
759
760   /**
761    * compute values of a quadratic function arising in least squares function:
762    * 0.5 * s.T * (J.T * J + diag) * s + g.T * s
763    *
764    * @param J
765    *          ~ jacobian matrix
766    * @param g
767    *          ~ gradient
768    * @param s
769    *          ~ steps and rows
770    *
771    * @return
772    */
773   private double evaluateQuadratic(MatrixI J, double[] g, double[] s)
774   {
775
776     double[] Js = J.sumProduct(s);
777     double q = MiscMath.dot(Js, Js);
778     double l = MiscMath.dot(s, g);
779
780     return 0.5 * q + l;
781   }
782
783   /**
784    * update the radius of a trust region based on the cost reduction
785    *
786    * @param Delta
787    * @param actualReduction
788    * @param predictedReduction
789    * @param stepNorm
790    * @param boundHit
791    *
792    * @return
793    */
794   private double[] updateTrustRegionRadius(double Delta,
795           double actualReduction, double predictedReduction,
796           double stepNorm, boolean boundHit)
797   {
798     double ratio = 0;
799     if (predictedReduction > 0)
800     {
801       ratio = actualReduction / predictedReduction;
802     }
803     else if (predictedReduction == 0 && actualReduction == 0)
804     {
805       ratio = 1;
806     }
807     else
808     {
809       ratio = 0;
810     }
811
812     if (ratio < 0.25)
813     {
814       Delta = 0.25 * stepNorm;
815     }
816     else if (ratio > 0.75 && boundHit)
817     {
818       Delta *= 2.0;
819     }
820
821     return new double[] { Delta, ratio };
822   }
823
824   /**
825    * trust region reflective algorithm
826    * 
827    * @param repMatrix
828    *          ~ Matrix containing representative vectors
829    * @param scoresOld
830    *          ~ Matrix containing initial observations
831    * @param index
832    *          ~ current row index
833    * @param J
834    *          ~ jacobian matrix
835    *
836    * @return
837    */
838   private double[] trf(MatrixI repMatrix, MatrixI scoresOld, int index,
839           MatrixI J)
840   {
841     // hA = x0
842     double[] hA = repMatrix.getRow(index);
843     double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, index);
844     int nfev = 1;
845     int m = J.height();
846     int n = J.width();
847     double cost = 0.5 * MiscMath.dot(f0, f0);
848     double[] g = J.transpose().sumProduct(f0);
849     double Delta = MiscMath.norm(hA);
850     int maxNfev = hA.length * 100;
851     double alpha = 0.0; // "Levenberg-Marquardt" parameter
852
853     double gNorm = 0;
854     boolean terminationStatus = false;
855     int iteration = 0;
856
857     while (true)
858     {
859       gNorm = MiscMath.norm(g);
860       if (terminationStatus || nfev == maxNfev)
861       {
862         break;
863       }
864       SingularValueDecomposition svd = new SingularValueDecomposition(
865               new Array2DRowRealMatrix(J.asArray()));
866       MatrixI U = new Matrix(svd.getU().getData());
867       double[] s = svd.getSingularValues();
868       MatrixI V = new Matrix(svd.getV().getData()).transpose();
869       double[] uf = U.transpose().sumProduct(f0);
870
871       double actualReduction = -1;
872       double[] xNew = new double[hA.length];
873       double[] fNew = new double[f0.length];
874       double costNew = 0;
875       double stepHnorm = 0;
876
877       while (actualReduction <= 0 && nfev < maxNfev)
878       {
879         TrustRegion trustRegion = solveLsqTrustRegion(n, m, uf, s, V, Delta,
880                 alpha);
881         double[] stepH = trustRegion.getStep();
882         alpha = trustRegion.getAlpha();
883         int nIterations = trustRegion.getIteration();
884         double predictedReduction = -(evaluateQuadratic(J, g, stepH));
885
886         xNew = MiscMath.elementwiseAdd(hA, stepH);
887         fNew = originalToEquasionSystem(xNew, repMatrix, scoresOld, index);
888         nfev++;
889
890         stepHnorm = MiscMath.norm(stepH);
891
892         if (MiscMath.countNaN(fNew) > 0)
893         {
894           Delta = 0.25 * stepHnorm;
895           continue;
896         }
897
898         // usual trust-region step quality estimation
899         costNew = 0.5 * MiscMath.dot(fNew, fNew);
900         actualReduction = cost - costNew;
901
902         double[] updatedTrustRegion = updateTrustRegionRadius(Delta,
903                 actualReduction, predictedReduction, stepHnorm,
904                 stepHnorm > (0.95 * Delta));
905         double DeltaNew = updatedTrustRegion[0];
906         double ratio = updatedTrustRegion[1];
907
908         // default ftol and xtol = 1e-8
909         boolean ftolSatisfied = actualReduction < (1e-8 * cost)
910                 && ratio > 0.25;
911         boolean xtolSatisfied = stepHnorm < (1e-8
912                 * (1e-8 + MiscMath.norm(hA)));
913         terminationStatus = ftolSatisfied || xtolSatisfied;
914         if (terminationStatus)
915         {
916           break;
917         }
918
919         alpha *= Delta / DeltaNew;
920         Delta = DeltaNew;
921
922       }
923       if (actualReduction > 0)
924       {
925         hA = xNew;
926         f0 = fNew;
927         cost = costNew;
928
929         J = approximateDerivative(repMatrix, scoresOld, index);
930
931         g = J.transpose().sumProduct(f0);
932       }
933       else
934       {
935         stepHnorm = 0;
936         actualReduction = 0;
937       }
938       iteration++;
939     }
940
941     return hA;
942   }
943
944   /**
945    * performs the least squares optimisation adapted from
946    * https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.least_squares.html#scipy.optimize.least_squares
947    *
948    * @param repMatrix
949    *          ~ Matrix containing representative vectors
950    * @param scoresOld
951    *          ~ Matrix containing initial observations
952    * @param index
953    *          ~ current row index
954    *
955    * @return
956    */
957   private double[] leastSquaresOptimisation(MatrixI repMatrix,
958           MatrixI scoresOld, int index)
959   {
960     MatrixI J = approximateDerivative(repMatrix, scoresOld, index);
961     double[] result = trf(repMatrix, scoresOld, index, J);
962     return result;
963   }
964
965 }