f7d5c58f40dc2f2a81fc4c843436424a4eed1272
[jalview.git] / src / jalview / analysis / ccAnalysis.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21
22 /*
23 * Copyright 2018-2022 Kathy Su, Kay Diederichs
24
25 * This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
26
27 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
28
29 * You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. 
30 */
31
32 /**
33 * Ported from https://doi.org/10.1107/S2059798317000699 by
34 * @AUTHOR MorellThomas
35 */
36
37 package jalview.analysis;
38
39 import jalview.bin.Console;
40 import jalview.math.MatrixI;
41 import jalview.math.Matrix;
42 import jalview.math.MiscMath;
43
44 import java.lang.Math;
45 import java.lang.System;
46 import java.util.Arrays;
47 import java.util.ArrayList;
48 import java.util.Comparator;
49 import java.util.Map.Entry;
50 import java.util.TreeMap;
51
52 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
53 import org.apache.commons.math3.linear.SingularValueDecomposition;
54
55 /**
56  * A class to model rectangular matrices of double values and operations on them
57  */
58 public class ccAnalysis 
59 {
60   private byte dim = 0;         //dimensions
61
62   private MatrixI scoresOld;    //input scores
63
64   public ccAnalysis(MatrixI scores, byte dim)
65   {
66     // round matrix to .4f to be same as in pasimap
67     for (int i = 0; i < scores.height(); i++)
68     {
69       for (int j = 0; j < scores.width(); j++)
70       {
71         if (!Double.isNaN(scores.getValue(i,j)))
72         {
73           scores.setValue(i, j, (double) Math.round(scores.getValue(i,j) * (int) 10000) / 10000);
74         }
75       }
76     }
77     this.scoresOld = scores;
78     this.dim = dim;
79   }
80
81   /** 
82   * Initialise a distrust-score for each hypothesis (h) of hSigns
83   * distrust = conHypNum - proHypNum
84   *
85   * @param hSigns ~ hypothesis signs (+/-) for each sequence
86   * @param scores ~ input score matrix
87   *
88   * @return distrustScores
89   */
90   private int[] initialiseDistrusts(byte[] hSigns, MatrixI scores)
91   {
92     int[] distrustScores = new int[scores.width()];
93     
94     // loop over symmetric matrix
95     for (int i = 0; i < scores.width(); i++)
96     {
97       byte hASign = hSigns[i];
98       int conHypNum = 0;
99       int proHypNum = 0;
100
101       for (int j = 0; j < scores.width(); j++)
102       {
103         double cell = scores.getRow(i)[j];      // value at [i][j] in scores
104         byte hBSign = hSigns[j];
105         if (!Double.isNaN(cell))
106         {
107           byte cellSign = (byte) Math.signum(cell);     //check if sign of matrix value fits hyptohesis
108           if (cellSign == hASign * hBSign)
109           {
110             proHypNum++;
111           } else {
112             conHypNum++;
113           }
114         }
115       }
116       distrustScores[i] = conHypNum - proHypNum;        //create distrust score for each sequence
117     }
118     return distrustScores;
119   }
120
121   /**
122   * Optemise hypothesis concerning the sign of the hypothetical value for each hSigns by interpreting the pairwise correlation coefficients as scalar products
123   *
124   * @param hSigns ~ hypothesis signs (+/-)
125   * @param distrustScores
126   * @param scores ~ input score matrix
127   *
128   * @return hSigns
129   */
130   private byte[] optimiseHypothesis(byte[] hSigns, int[] distrustScores, MatrixI scores)
131   {
132     // get maximum distrust score
133     int[] maxes = MiscMath.findMax(distrustScores);
134     int maxDistrustIndex = maxes[0];
135     int maxDistrust = maxes[1];
136
137     // if hypothesis is not optimal yet
138     if (maxDistrust > 0)
139     {
140       //toggle sign for hI with maximum distrust
141       hSigns[maxDistrustIndex] *= -1;
142       // update distrust at same position
143       distrustScores[maxDistrustIndex] *= -1;
144
145       // also update distrust scores for all hI that were not changed
146       byte hASign = hSigns[maxDistrustIndex];
147       for (int NOTmaxDistrustIndex = 0; NOTmaxDistrustIndex < distrustScores.length; NOTmaxDistrustIndex++)
148       {
149         if (NOTmaxDistrustIndex != maxDistrustIndex)
150         {
151           byte hBSign = hSigns[NOTmaxDistrustIndex];
152           double cell = scores.getValue(maxDistrustIndex, NOTmaxDistrustIndex);
153
154           // distrust only changed if not NaN
155           if (!Double.isNaN(cell))
156           {
157             byte cellSign = (byte) Math.signum(cell);
158             // if sign of cell matches hypothesis decrease distrust by 2 because 1 more value supporting and 1 less contradicting
159             // else increase by 2
160             if (cellSign == hASign * hBSign)
161             {
162               distrustScores[NOTmaxDistrustIndex] -= 2;
163             } else {
164               distrustScores[NOTmaxDistrustIndex] += 2;
165             }
166           }
167         }
168       }
169       //further optimisation necessary
170       return optimiseHypothesis(hSigns, distrustScores, scores);
171
172     } else {
173       return hSigns;
174     }
175   }
176
177   /** 
178   * takes the a symmetric MatrixI as input scores which may contain Double.NaN 
179   * approximate the missing values using hypothesis optimisation 
180   *
181   * runs analysis
182   *
183   * @param scores ~ score matrix
184   *
185   * @return
186   */
187   public MatrixI run () throws Exception
188   {
189     //initialse eigenMatrix and repMatrix
190     MatrixI eigenMatrix = scoresOld.copy();
191     MatrixI repMatrix = scoresOld.copy();
192     try
193     {
194     /*
195     * Calculate correction factor for 2nd and higher eigenvalue(s).
196     * This correction is NOT needed for the 1st eigenvalue, because the
197     * unknown (=NaN) values of the matrix are approximated by presuming
198     * 1-dimensional vectors as the basis of the matrix interpretation as dot
199     * products.
200     */
201         
202     System.out.println("Input correlation matrix:");
203     eigenMatrix.print(System.out, "%1.4f ");
204
205     int matrixWidth = eigenMatrix.width(); // square matrix, so width == height
206     int matrixElementsTotal = (int) Math.pow(matrixWidth, 2);   //total number of elemts
207
208     float correctionFactor = (float) (matrixElementsTotal - eigenMatrix.countNaN()) / (float) matrixElementsTotal;
209     
210     /*
211     * Calculate hypothetical value (1-dimensional vector) h_i for each
212     * dataset by interpreting the given correlation coefficients as scalar
213     * products.
214     */
215
216     /*
217     * Memory for current hypothesis concerning sign of each h_i.
218     * List of signs for all h_i in the encoding:
219       * *  1: positive
220       * *  0: zero
221       * * -1: negative
222     * Initial hypothesis: all signs are positive.
223     */
224     byte[] hSigns = new byte[matrixWidth];
225     Arrays.fill(hSigns, (byte) 1);
226
227     //Estimate signs for each h_i by refining hypothesis on signs.
228     hSigns = optimiseHypothesis(hSigns, initialiseDistrusts(hSigns, eigenMatrix), eigenMatrix);
229
230
231     //Estimate absolute values for each h_i by determining sqrt of mean of
232     //non-NaN absolute values for every row.
233     double[] hAbs = MiscMath.sqrt(eigenMatrix.absolute().meanRow());
234
235     //Combine estimated signs with absolute values in obtain total value for
236     //each h_i.
237     double[] hValues = MiscMath.elementwiseMultiply(hSigns, hAbs);
238
239     /*Complement symmetric matrix by using the scalar products of estimated
240     *values of h_i to replace NaN-cells.
241     *Matrix positions that have estimated values
242     *(only for diagonal and upper off-diagonal values, due to the symmetry
243     *the positions of the lower-diagonal values can be inferred).
244     List of tuples (row_idx, column_idx).*/
245
246     ArrayList<int[]> estimatedPositions = new ArrayList<int[]>();
247
248     // for off-diagonal cells
249     for (int rowIndex = 0; rowIndex < matrixWidth - 1; rowIndex++)
250     {
251       for (int columnIndex = rowIndex + 1; columnIndex < matrixWidth; columnIndex++)
252       {
253         double cell = eigenMatrix.getValue(rowIndex, columnIndex);
254         if (Double.isNaN(cell))
255         {
256           //calculate scalar product as new cell value
257           cell = hValues[rowIndex] * hValues[columnIndex];
258           //fill in new value in cell and symmetric partner
259           eigenMatrix.setValue(rowIndex, columnIndex, cell);
260           eigenMatrix.setValue(columnIndex, rowIndex, cell);
261           //save positions of estimated values
262           estimatedPositions.add(new int[]{rowIndex, columnIndex});
263         }
264       }
265     }
266
267     // for diagonal cells
268     for (int diagonalIndex = 0; diagonalIndex < matrixWidth; diagonalIndex++)
269       {
270         double cell = Math.pow(hValues[diagonalIndex], 2);
271         eigenMatrix.setValue(diagonalIndex, diagonalIndex, cell);
272         estimatedPositions.add(new int[]{diagonalIndex, diagonalIndex});
273       }
274
275     /*Refine total values of each h_i:
276     *Initialise h_values of the hypothetical non-existant previous iteration
277     *with the correct format but with impossible values.
278      Needed for exit condition of otherwise endless loop.*/
279     System.out.print("initial values: [ ");
280     for (double h : hValues)
281     {
282       System.out.print(String.format("%1.4f, ", h));
283     }
284     System.out.println(" ]");
285
286
287     double[] hValuesOld = new double[matrixWidth];
288
289     int iterationCount = 0;
290
291     // repeat unitl values of h do not significantly change anymore
292     while (true)
293     {
294       for (int hIndex = 0; hIndex < matrixWidth; hIndex++)
295       {
296         double newH = Arrays.stream(MiscMath.elementwiseMultiply(hValues, eigenMatrix.getRow(hIndex))).sum() / Arrays.stream(MiscMath.elementwiseMultiply(hValues, hValues)).sum();
297         hValues[hIndex] = newH;
298       }
299
300       System.out.print(String.format("iteration %d: [ ", iterationCount));
301       for (double h : hValues)
302       {
303         System.out.print(String.format("%1.4f, ", h));
304       }
305       System.out.println(" ]");
306
307       //update values of estimated positions
308       for (int[] pair : estimatedPositions)     // pair ~ row, col
309       {
310         double newVal = hValues[pair[0]] * hValues[pair[1]];
311         eigenMatrix.setValue(pair[0], pair[1], newVal);
312         eigenMatrix.setValue(pair[1], pair[0], newVal);
313       }
314
315       iterationCount++;
316
317       //exit loop as soon as new values are similar to the last iteration
318       if (MiscMath.allClose(hValues, hValuesOld, 0d, 1e-05d, false))
319       {
320         break;
321       }
322
323       //save hValues for comparison in the next iteration
324       System.arraycopy(hValues, 0, hValuesOld, 0, hValues.length);
325     }
326
327     //-----------------------------
328     //Use complemented symmetric matrix to calculate final representative
329     //vectors.
330
331     //Eigendecomposition.
332     eigenMatrix.tred();
333     eigenMatrix.tqli();
334
335     System.out.println("eigenmatrix");
336     eigenMatrix.print(System.out, "%8.2f");
337     System.out.println();
338     System.out.println("uncorrected eigenvalues");
339     eigenMatrix.printD(System.out, "%2.4f ");
340
341     double[] eigenVals = eigenMatrix.getD();
342
343     TreeMap<Double, Integer> eigenPairs = new TreeMap<>(Comparator.reverseOrder());
344     for (int i = 0; i < eigenVals.length; i++)
345     {
346       eigenPairs.put(eigenVals[i], i);
347     }
348
349     // matrix of representative eigenvectors (each row is a vector)
350     double[][] _repMatrix = new double[eigenVals.length][dim];
351     double[][] _oldMatrix = new double[eigenVals.length][dim];
352     double[] correctedEigenValues = new double[dim];    
353
354     int l = 0;
355     for (Entry<Double, Integer> pair : eigenPairs.entrySet())
356     {
357       double eigenValue = pair.getKey();
358       int column = pair.getValue();
359       double[] eigenVector = eigenMatrix.getColumn(column);
360       //for 2nd and higher eigenvalues
361       if (l >= 1)
362       {
363         eigenValue /= correctionFactor;
364       }
365       correctedEigenValues[l] = eigenValue;
366       for (int j = 0; j < eigenVector.length; j++)
367       {
368         _repMatrix[j][l] = (eigenValue < 0) ? 0.0 : - Math.sqrt(eigenValue) * eigenVector[j];
369         double tmpOldScore = scoresOld.getColumn(column)[j];
370         _oldMatrix[j][dim - l - 1] = (Double.isNaN(tmpOldScore)) ? 0.0 : tmpOldScore;
371       }
372       l++;
373       if (l >= dim)
374       {
375         break;
376       }
377     }
378
379     System.out.println("correctedEigenValues");
380     MiscMath.print(correctedEigenValues, "%2.4f ");
381
382     repMatrix = new Matrix(_repMatrix);
383     repMatrix.setD(correctedEigenValues);
384     MatrixI oldMatrix = new Matrix(_oldMatrix);
385
386     MatrixI dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
387     
388     double rmsd = scoresOld.rmsd(dotMatrix);
389
390     System.out.println("iteration, rmsd, maxDiff, rmsdDiff");
391     System.out.println(String.format("0, %8.5f, -, -", rmsd));
392     // Refine representative vectors by minimising sum-of-squared deviates between dotMatrix and original  score matrix
393     for (int iteration = 1; iteration < 21; iteration++)        // arbitrarily set to 20
394     {
395       MatrixI repMatrixOLD = repMatrix.copy();
396       MatrixI dotMatrixOLD = dotMatrix.copy();
397
398       // for all rows/hA in the original matrix
399       for (int hAIndex = 0; hAIndex < oldMatrix.height(); hAIndex++)
400       {
401         double[] row = oldMatrix.getRow(hAIndex);
402         double[] hA = repMatrix.getRow(hAIndex);
403         hAIndex = hAIndex;
404         //find least-squares-solution fo rdifferences between original scores and representative vectors
405         double[] hAlsm = leastSquaresOptimisation(repMatrix, scoresOld, hAIndex);
406         // update repMatrix with new hAlsm
407         for (int j = 0; j < repMatrix.width(); j++)
408         {
409           repMatrix.setValue(hAIndex, j, hAlsm[j]);
410         }
411       }
412       
413       // dot product of representative vecotrs yields a matrix with values approximating the correlation matrix
414       dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
415       // calculate rmsd between approximation and correlation matrix
416       rmsd = scoresOld.rmsd(dotMatrix);
417
418       // calculate maximum change of representative vectors of current iteration
419       MatrixI diff = repMatrix.subtract(repMatrixOLD).absolute();
420       double maxDiff = 0.0;
421       for (int i = 0; i < diff.height(); i++)
422       {
423         for (int j = 0; j < diff.width(); j++)
424         {
425           maxDiff = (diff.getValue(i, j) > maxDiff) ? diff.getValue(i, j) : maxDiff;
426         }
427       }
428
429       // calculate rmsd between current and previous estimation
430       double rmsdDiff = dotMatrix.rmsd(dotMatrixOLD);
431
432       System.out.println(String.format("%d, %8.5f, %8.5f, %8.5f", iteration, rmsd, maxDiff, rmsdDiff));
433
434       if (!(Math.abs(maxDiff) > 1e-06))
435       {
436         repMatrix = repMatrixOLD.copy();
437         break;
438       }
439     }
440     
441
442     } catch (Exception q)
443     {
444       Console.error("Error computing cc_analysis:  " + q.getMessage());
445       q.printStackTrace();
446     }
447     System.out.println("final coordinates:");
448     repMatrix.print(System.out, "%8.2f");
449     return repMatrix;
450   }
451
452   /**
453   * Create equations system using information on originally known
454   * pairwise correlation coefficients (parsed from infile) and the
455   * representative result vectors
456   *
457   * Each equation has the format:
458   * hA * hA - pairwiseCC = 0
459   * with:
460   * hA: unknown variable
461   * hB: known representative vector
462   * pairwiseCC: known pairwise correlation coefficien
463   * 
464   * The resulting equations system is overdetermined, if there are more
465   * equations than unknown elements
466   *
467   * @param x ~ unknown n-dimensional column-vector
468   * (needed for generating equations system, NOT to be specified by user).
469   * @param hAIndex ~ index of currently optimised representative result vector.
470   * @param h ~ matrix with row-wise listing of representative result vectors.
471   * @param originalRow ~ matrix-row of originally parsed pairwise correlation coefficients.
472   *
473   * @return
474   */
475   private double[] originalToEquasionSystem(double[] hA, MatrixI repMatrix, MatrixI scoresOld, int hAIndex)
476   {
477     double[] originalRow = scoresOld.getRow(hAIndex);
478     int nans = MiscMath.countNaN(originalRow);
479     double[] result = new double[originalRow.length - nans];
480
481     //for all pairwiseCC in originalRow
482     int resultIndex = 0;
483     for (int hBIndex = 0; hBIndex < originalRow.length; hBIndex++)
484     {
485       double pairwiseCC = originalRow[hBIndex];
486       // if not NaN -> create new equation and add it to the system
487       if (!Double.isNaN(pairwiseCC))
488       {
489         double[] hB = repMatrix.getRow(hBIndex);
490         result[resultIndex++] = MiscMath.sum(MiscMath.elementwiseMultiply(hA, hB)) - pairwiseCC;
491       } else {
492       }
493     }
494     return result;
495   }
496
497   /**
498   * returns the jacobian matrix
499   * @param repMatrix ~ matrix of representative vectors
500   * @param hAIndex ~ current row index
501   *
502   * @return
503   */
504   private MatrixI approximateDerivative(MatrixI repMatrix, MatrixI scoresOld, int hAIndex)
505   {
506     //hA = x0
507     double[] hA = repMatrix.getRow(hAIndex);
508     double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, hAIndex);
509     double[] signX0 = new double[hA.length];
510     double[] xAbs = new double[hA.length];
511     for (int i = 0; i < hA.length; i++)
512     {
513       signX0[i] = (hA[i] >= 0) ? 1 : -1;
514       xAbs[i] = (Math.abs(hA[i]) >= 1.0) ? Math.abs(hA[i]) : 1.0;
515       }
516     double rstep = Math.pow(Math.ulp(1.0), 0.5);
517
518     double[] h = new double [hA.length];
519     for (int i = 0; i < hA.length; i++)
520     {
521       h[i] = rstep * signX0[i] * xAbs[i];
522     }
523       
524     int m = f0.length;
525     int n = hA.length;
526     double[][] jTransposed = new double[n][m];
527     for (int i = 0; i < h.length; i++)
528     {
529       double[] x = new double[h.length];
530       System.arraycopy(hA, 0, x, 0, h.length);
531       x[i] += h[i];
532       double dx = x[i] - hA[i];
533       double[] df = originalToEquasionSystem(x, repMatrix, scoresOld, hAIndex);
534       for (int j = 0; j < df.length; j++)
535       {
536         df[j] -= f0[j];
537         jTransposed[i][j] = df[j] / dx;
538       }
539     }
540     MatrixI J = new Matrix(jTransposed).transpose();
541     return J;
542   }
543
544   /**
545   * norm of regularized (by alpha) least-squares solution minus Delta
546   * @param alpha
547   * @param suf
548   * @param s
549   * @param Delta
550   *
551   * @return
552   */
553   private double[] phiAndDerivative(double alpha, double[] suf, double[] s, double Delta)
554   {
555     double[] denom = MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha);
556     double pNorm = MiscMath.norm(MiscMath.elementwiseDivide(suf, denom));
557     double phi = pNorm - Delta;
558     // - sum ( suf**2 / denom**3) / pNorm
559     double phiPrime = - MiscMath.sum(MiscMath.elementwiseDivide(MiscMath.elementwiseMultiply(suf, suf), MiscMath.elementwiseMultiply(MiscMath.elementwiseMultiply(denom, denom), denom))) / pNorm;
560     return new double[]{phi, phiPrime};
561   }
562
563   /**
564   * class holding the result of solveLsqTrustRegion
565   */
566   private class TrustRegion
567   {
568     private double[] step;
569     private double alpha;
570     private int iteration;
571
572     public TrustRegion(double[] step, double alpha, int iteration)
573     {
574       this.step = step;
575       this.alpha = alpha;
576       this.iteration = iteration;
577     }
578
579     public double[] getStep()
580     {
581       return this.step;
582     }
583
584     public double getAlpha()
585     {
586       return this.alpha;
587     }
588   
589     public int getIteration()
590     {
591       return this.iteration;
592     }
593   }
594
595   /**
596   * solve a trust-region problem arising in least-squares optimisation
597   * @param n ~ number of variables
598   * @param m ~ number of residuals
599   * @param uf
600   * @param s ~ singular values of J
601   * @param V ~ transpose of VT
602   * @param Delta ~ radius of a trust region
603   * @param alpha ~ initial guess for alpha
604   *
605   * @return
606   */
607   private TrustRegion solveLsqTrustRegion(int n, int m, double[] uf, double[] s, MatrixI V, double Delta, double alpha)
608   {
609     double[] suf = MiscMath.elementwiseMultiply(s, uf);
610
611     //check if J has full rank and tr Gauss-Newton step
612     boolean fullRank = false;
613     if (m >= n)
614     {
615       double threshold = s[0] * Math.ulp(1.0) * m;
616       fullRank = s[s.length - 1] > threshold;
617     }
618     if (fullRank)
619     {
620       double[] p = MiscMath.elementwiseMultiply(V.sumProduct(MiscMath.elementwiseDivide(uf, s)), -1);
621       if (MiscMath.norm(p) <= Delta)
622       {
623         TrustRegion result = new TrustRegion(p, 0.0, 0);
624         return result;
625       }
626     }
627
628     double alphaUpper = MiscMath.norm(suf) / Delta;
629     double alphaLower = 0.0;
630     if (fullRank)
631     {
632       double[] phiAndPrime = phiAndDerivative(0.0, suf, s, Delta);
633       alphaLower = - phiAndPrime[0] / phiAndPrime[1];
634     }
635
636     alpha = (!fullRank && alpha == 0.0) ? alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5)) : alpha;
637
638     int iteration = 0;
639     while (iteration < 10)      // 10 is default max_iter
640     {
641       alpha = (alpha < alphaLower || alpha > alphaUpper) ? alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5)) : alpha;
642       double[] phiAndPrime = phiAndDerivative(alpha, suf, s, Delta);
643       double phi = phiAndPrime[0];
644       double phiPrime = phiAndPrime[1];
645
646       alphaUpper = (phi < 0) ? alpha : alphaUpper;
647       double ratio = phi / phiPrime;
648       alphaLower = Math.max(alphaLower, alpha - ratio);
649       alpha -= (phi + Delta) * ratio / Delta;
650
651       if (Math.abs(phi) < 0.01 * Delta) // default rtol set to 0.01
652       {
653         break;
654       }
655       iteration++;
656     }
657
658     // p = - V.dot( suf / (s**2 + alpha))
659     double[] tmp = MiscMath.elementwiseDivide(suf, MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha));
660     double[] p = MiscMath.elementwiseMultiply(V.sumProduct(tmp), -1);
661
662     // Make the norm of p equal to Delta, p is changed only slightly during this.
663     // It is done to prevent p lie outside of the trust region
664     p = MiscMath.elementwiseMultiply(p, Delta / MiscMath.norm(p));
665
666     TrustRegion result = new TrustRegion(p, alpha, iteration + 1);
667     return result;
668   }
669
670   /**
671   * compute values of a quadratic function arising in least squares
672   * function: 0.5 * s.T * (J.T * J + diag) * s + g.T * s
673   *
674   * @param J ~ jacobian matrix
675   * @param g ~ gradient
676   * @param s ~ steps and rows
677   *
678   * @return
679   */
680   private double evaluateQuadratic(MatrixI J, double[] g, double[] s)
681   {
682
683     double[] Js = J.sumProduct(s);
684     double q = MiscMath.dot(Js, Js);
685     double l = MiscMath.dot(s, g);
686
687     return 0.5 * q + l;
688   }
689
690   /**
691   * update the radius of a trust region based on the cost reduction
692   *
693   * @param Delta
694   * @param actualReduction
695   * @param predictedReduction
696   * @param stepNorm
697   * @param boundHit
698   *
699   * @return
700   */
701   private double[] updateTrustRegionRadius(double Delta, double actualReduction, double predictedReduction, double stepNorm, boolean boundHit)
702   {
703     double ratio = 0;
704     if (predictedReduction > 0)
705     {
706       ratio = actualReduction / predictedReduction;
707     } else if (predictedReduction == 0 && actualReduction == 0) {
708       ratio = 1;
709     } else {
710       ratio = 0;
711     }
712
713     if (ratio < 0.25)
714     {
715       Delta = 0.25 * stepNorm;
716     } else if (ratio > 0.75 && boundHit) {
717       Delta *= 2.0;
718     }
719
720     return new double[]{Delta, ratio};
721   }
722
723   /**
724   * trust region reflective algorithm
725   * @param repMatrix ~ Matrix containing representative vectors
726   * @param scoresOld ~ Matrix containing initial observations
727   * @param index ~ current row index
728   * @param J ~ jacobian matrix
729   *
730   * @return
731   */
732   private double[] trf(MatrixI repMatrix, MatrixI scoresOld, int index, MatrixI J)
733   {
734     //hA = x0
735     double[] hA = repMatrix.getRow(index);
736     double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, index);
737     int nfev = 1;
738     int m = J.height();
739     int n = J.width();
740     double cost = 0.5 * MiscMath.dot(f0, f0);
741     double[] g = J.transpose().sumProduct(f0);
742     double Delta = MiscMath.norm(hA);
743     int maxNfev = hA.length * 100;
744     double alpha = 0.0;         // "Levenberg-Marquardt" parameter
745
746     double gNorm = 0;
747     byte terminationStatus = 0;
748     int iteration = 0;
749
750     while (true)
751     {
752       gNorm = MiscMath.norm(g);
753       if (terminationStatus != 0 || nfev == maxNfev)
754       {
755         break;
756       }
757       SingularValueDecomposition svd = new SingularValueDecomposition(new Array2DRowRealMatrix(J.asArray()));
758       MatrixI U = new Matrix(svd.getU().getData());
759       double[] s = svd.getSingularValues();     
760       MatrixI V = new Matrix(svd.getV().getData()).transpose();
761       double[] uf = U.transpose().sumProduct(f0);
762
763       double actualReduction = -1;
764       double[] xNew = new double[hA.length];
765       double[] fNew = new double[f0.length];
766       double costNew = 0;
767       double stepHnorm = 0;
768       
769       while (actualReduction <= 0 && nfev < maxNfev)
770       {
771         TrustRegion trustRegion = solveLsqTrustRegion(n, m, uf, s, V, Delta, alpha);
772         double[] stepH = trustRegion.getStep(); 
773         alpha = trustRegion.getAlpha();
774         int nIterations = trustRegion.getIteration();
775         double predictedReduction = - (evaluateQuadratic(J, g, stepH)); 
776
777         xNew = MiscMath.elementwiseAdd(hA, stepH);
778         fNew = originalToEquasionSystem(xNew, repMatrix, scoresOld, index);
779         nfev++;
780         
781         stepHnorm = MiscMath.norm(stepH);
782
783         if (MiscMath.countNaN(fNew) > 0)
784         {
785           Delta = 0.25 * stepHnorm;
786           continue;
787         }
788
789         // usual trust-region step quality estimation
790         costNew = 0.5 * MiscMath.dot(fNew, fNew); 
791         actualReduction = cost - costNew;
792
793         double[] updatedTrustRegion = updateTrustRegionRadius(Delta, actualReduction, predictedReduction, stepHnorm, stepHnorm > (0.95 * Delta));
794         double DeltaNew = updatedTrustRegion[0];
795         double ratio = updatedTrustRegion[1];
796
797         // default ftol and xtol = 1e-8
798         boolean ftolSatisfied = actualReduction < (1e-8 * cost) && ratio > 0.25;
799         boolean xtolSatisfied = stepHnorm < (1e-8 * (1e-8 + xNorm));
800         terminationStatus = (ftolSatisfied || xtolSatisfied) ? (byte) 1 : (byte) 0;
801         if (terminationStatus != 0)
802         {
803           break;
804         }
805
806         alpha *= Delta / DeltaNew;
807         Delta = DeltaNew;
808
809       }
810       if (actualReduction > 0)
811       {
812         hA = xNew;
813         f0 = fNew;
814         cost = costNew;
815
816         J = approximateDerivative(repMatrix, scoresOld, index);
817
818         g = J.transpose().sumProduct(f0);
819       } else {
820         stepHnorm = 0;
821         actualReduction = 0;
822       }
823       iteration++;
824     }
825
826     return hA;
827   }
828
829   /**
830   * performs the least squares optimisation
831   * adapted from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.least_squares.html#scipy.optimize.least_squares
832   *
833   * @param repMatrix ~ Matrix containing representative vectors
834   * @param scoresOld ~ Matrix containing initial observations
835   * @param index ~ current row index
836   *
837   * @return
838   */
839   private double[] leastSquaresOptimisation(MatrixI repMatrix, MatrixI scoresOld, int index)
840   {
841     MatrixI J = approximateDerivative(repMatrix, scoresOld, index);
842     double[] result = trf(repMatrix, scoresOld, index, J);
843     return result;
844   }
845
846 }