Properly integrated endGaps option for AlignSeq.traceAlignment()
[jalview.git] / src / jalview / analysis / ccAnalysis.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21
22 /*
23 * Copyright 2018-2022 Kathy Su, Kay Diederichs
24
25 * This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
26
27 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
28
29 * You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. 
30 */
31
32 /**
33 * Ported from https://doi.org/10.1107/S2059798317000699 by
34 * @AUTHOR MorellThomas
35 */
36
37 package jalview.analysis;
38
39 import jalview.bin.Console;
40 import jalview.math.MatrixI;
41 import jalview.math.Matrix;
42 import jalview.math.MiscMath;
43
44 import java.lang.Math;
45 import java.lang.System;
46 import java.util.Arrays;
47 import java.util.ArrayList;
48 import java.util.Comparator;
49 import java.util.HashSet;
50 import java.util.Map.Entry;
51 import java.util.TreeMap;
52
53 import org.apache.commons.math3.analysis.MultivariateVectorFunction;
54 import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
55 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
56 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
57 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
58 import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
59 import org.apache.commons.math3.linear.ArrayRealVector;
60 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
61 import org.apache.commons.math3.linear.RealVector;
62 import org.apache.commons.math3.linear.RealMatrix;
63 import org.apache.commons.math3.linear.SingularValueDecomposition;
64
65 /**
66  * A class to model rectangular matrices of double values and operations on them
67  */
68 public class ccAnalysis 
69 {
70   private byte dim = 0;         //dimensions
71
72   private MatrixI scoresOld;    //input scores
73
74   public ccAnalysis(MatrixI scores, byte dim)
75   {
76     //&! round matrix to .4f to be same as in pasimap
77     for (int i = 0; i < scores.height(); i++)
78     {
79       for (int j = 0; j < scores.width(); j++)
80       {
81         if (!Double.isNaN(scores.getValue(i,j)))
82         {
83           scores.setValue(i, j, (double) Math.round(scores.getValue(i,j) * (int) 10000) / 10000);
84         }
85       }
86     }
87     this.scoresOld = scores;
88     this.dim = dim;
89   }
90
91   /** TODO
92   * DOCUMENT ME
93   *
94   * @param hSigns ~ hypothesis signs (+/-) for each sequence
95   * @param scores ~ input score matrix
96   *
97   * @return distrustScores
98   */
99   private int[] initialiseDistrusts(byte[] hSigns, MatrixI scores)
100   {
101     int[] distrustScores = new int[scores.width()];
102     
103     // loop over symmetric matrix
104     for (int i = 0; i < scores.width(); i++)
105     {
106       byte hASign = hSigns[i];
107       int conHypNum = 0;
108       int proHypNum = 0;
109
110       for (int j = 0; j < scores.width(); j++)
111       {
112         double cell = scores.getRow(i)[j];      // value at [i][j] in scores
113         byte hBSign = hSigns[j];
114         if (!Double.isNaN(cell))
115         {
116           byte cellSign = (byte) Math.signum(cell);     //check if sign of matrix value fits hyptohesis
117           if (cellSign == hASign * hBSign)
118           {
119             proHypNum++;
120           } else {
121             conHypNum++;
122           }
123         }
124       }
125       distrustScores[i] = conHypNum - proHypNum;        //create distrust score for each sequence
126     }
127     return distrustScores;
128   }
129
130   /** TODO
131   * DOCUMENT ME
132   *
133   * @param hSigns ~ hypothesis signs (+/-)
134   * @param distrustScores
135   * @param scores ~ input score matrix
136   *
137   * @return hSigns
138   */
139   private byte[] optimiseHypothesis(byte[] hSigns, int[] distrustScores, MatrixI scores)
140   {
141     // get maximum distrust score
142     int[] maxes = MiscMath.findMax(distrustScores);
143     int maxDistrustIndex = maxes[0];
144     int maxDistrust = maxes[1];
145
146     // if hypothesis is not optimal yet
147     if (maxDistrust > 0)
148     {
149       //toggle sign for hI with maximum distrust
150       hSigns[maxDistrustIndex] *= -1;
151       // update distrust at same position
152       distrustScores[maxDistrustIndex] *= -1;
153
154       // also update distrust scores for all hI that were not changed
155       byte hASign = hSigns[maxDistrustIndex];
156       for (int NOTmaxDistrustIndex = 0; NOTmaxDistrustIndex < distrustScores.length; NOTmaxDistrustIndex++)
157       {
158         if (NOTmaxDistrustIndex != maxDistrustIndex)
159         {
160           byte hBSign = hSigns[NOTmaxDistrustIndex];
161           double cell = scores.getRow(maxDistrustIndex)[NOTmaxDistrustIndex];
162
163           // distrust only changed if not NaN
164           if (!Double.isNaN(cell))
165           {
166             byte cellSign = (byte) Math.signum(cell);
167             // if sign of cell matches hypothesis decrease distrust by 2 because 1 more value supporting and 1 less contradicting
168             // else increase by 2
169             if (cellSign == hASign * hBSign)
170             {
171               distrustScores[NOTmaxDistrustIndex] -= 2;
172             } else {
173               distrustScores[NOTmaxDistrustIndex] += 2;
174             }
175           }
176         }
177       }
178       //further optimisation necessary
179       return optimiseHypothesis(hSigns, distrustScores, scores);
180
181     } else {
182       return hSigns;
183     }
184   }
185
186   /** 
187   * takes the a symmetric MatrixI as input scores which may contain Double.NaN 
188   * approximate the missing values using hypothesis optimisation 
189   *
190   * runs analysis
191   *
192   * @param scores ~ score matrix
193   *
194   * @return
195   */
196   public MatrixI run ()
197   {
198     MatrixI eigenMatrix = scoresOld.copy();
199     MatrixI repMatrix = scoresOld.copy();
200     try
201     {
202     /*
203     * Calculate correction factor for 2nd and higher eigenvalue(s).
204     * This correction is NOT needed for the 1st eigenvalue, because the
205     * unknown (=NaN) values of the matrix are approximated by presuming
206     * 1-dimensional vectors as the basis of the matrix interpretation as dot
207     * products.
208     */
209         
210     //&! debug
211     System.out.println("input:");
212     eigenMatrix.print(System.out, "%1.4f ");
213     int matrixWidth = eigenMatrix.width(); // square matrix, so width == height
214     int matrixElementsTotal = (int) Math.pow(matrixWidth, 2);   //total number of elemts
215
216     float correctionFactor = (float) (matrixElementsTotal - eigenMatrix.countNaN()) / (float) matrixElementsTotal;
217     
218     /*
219     * Calculate hypothetical value (1-dimensional vector) h_i for each
220     * dataset by interpreting the given correlation coefficients as scalar
221     * products.
222     */
223
224     /*
225     * Memory for current hypothesis concerning sign of each h_i.
226     * List of signs for all h_i in the encoding:
227       * *  1: positive
228       * *  0: zero
229       * * -1: negative
230     * Initial hypothesis: all signs are positive.
231     */
232     byte[] hSigns = new byte[matrixWidth];
233     Arrays.fill(hSigns, (byte) 1);
234
235     //Estimate signs for each h_i by refining hypothesis on signs.
236     hSigns = optimiseHypothesis(hSigns, initialiseDistrusts(hSigns, eigenMatrix), eigenMatrix);
237
238
239     //Estimate absolute values for each h_i by determining sqrt of mean of
240     //non-NaN absolute values for every row.
241     //<++> hAbs = //np.sqrt(np.nanmean(np.absolute(eigenMatrix), axis=1)) 
242     MiscMath.print(eigenMatrix.absolute().meanRow(), "%1.8f");
243     double[] hAbs = MiscMath.sqrt(eigenMatrix.absolute().meanRow()); //np.sqrt(np.nanmean(np.absolute(eigenMatrix), axis=1))
244
245     //Combine estimated signs with absolute values in obtain total value for
246     //each h_i.
247     double[] hValues = MiscMath.elementwiseMultiply(hSigns, hAbs);
248     //<++>hValues.reshape((1,matrixWidth));     // doesnt it already look like this
249
250     /*Complement symmetric matrix by using the scalar products of estimated
251     *values of h_i to replace NaN-cells.
252     *Matrix positions that have estimated values
253     *(only for diagonal and upper off-diagonal values, due to the symmetry
254     *the positions of the lower-diagonal values can be inferred).
255     List of tuples (row_idx, column_idx).*/
256
257     ArrayList<int[]> estimatedPositions = new ArrayList<int[]>();
258
259     // for off-diagonal cells
260     for (int rowIndex = 0; rowIndex < matrixWidth - 1; rowIndex++)
261     {
262       for (int columnIndex = rowIndex + 1; columnIndex < matrixWidth; columnIndex++)
263       {
264         double cell = eigenMatrix.getValue(rowIndex, columnIndex);
265         if (Double.isNaN(cell))
266         {
267           //calculate scalar product as new cell value
268           cell = hValues[rowIndex] * hValues[columnIndex];      // something is wrong with hAbs andhValues and everything!!!!!!!!!!!!
269           //fill in new value in cell and symmetric partner
270           eigenMatrix.setValue(rowIndex, columnIndex, cell);
271           eigenMatrix.setValue(columnIndex, rowIndex, cell);
272           //save positions of estimated values
273           estimatedPositions.add(new int[]{rowIndex, columnIndex});
274         }
275       }
276     }
277
278     // for diagonal cells
279     for (int diagonalIndex = 0; diagonalIndex < matrixWidth; diagonalIndex++)
280       {
281         double cell = Math.pow(hValues[diagonalIndex], 2);
282         eigenMatrix.setValue(diagonalIndex, diagonalIndex, cell);
283         estimatedPositions.add(new int[]{diagonalIndex, diagonalIndex});
284       }
285
286     /*Refine total values of each h_i:
287     *Initialise h_values of the hypothetical non-existant previous iteration
288     *with the correct format but with impossible values.
289      Needed for exit condition of otherwise endless loop.*/
290     System.out.print("initial values: [ ");
291     for (double h : hValues)
292     {
293       System.out.print(String.format("%1.4f, ", h));
294     }
295     System.out.println(" ]");
296
297
298     double[] hValuesOld = new double[matrixWidth];
299
300     int iterationCount = 0;
301
302     // repeat unitl values of h do not significantly change anymore
303     while (true)
304     {
305       for (int hIndex = 0; hIndex < matrixWidth; hIndex++)
306       {
307         //@python newH = np.sum(hValues * eigenMatrix[hIndex]) / np.sum(hValues ** 2)
308         double newH = Arrays.stream(MiscMath.elementwiseMultiply(hValues, eigenMatrix.getRow(hIndex))).sum() / Arrays.stream(MiscMath.elementwiseMultiply(hValues, hValues)).sum();
309         hValues[hIndex] = newH;
310       }
311
312       System.out.print(String.format("iteration %d: [ ", iterationCount));
313       for (double h : hValues)
314       {
315         System.out.print(String.format("%1.4f, ", h));
316       }
317       System.out.println(" ]");
318
319       //update values of estimated positions
320       for (int[] pair : estimatedPositions)     // pair ~ row, col
321       {
322         double newVal = hValues[pair[0]] * hValues[pair[1]];
323         eigenMatrix.setValue(pair[0], pair[1], newVal);
324         eigenMatrix.setValue(pair[1], pair[0], newVal);
325       }
326
327       iterationCount++;
328
329       //exit loop as soon as new values are similar to the last iteration
330       if (MiscMath.allClose(hValues, hValuesOld, 0d, 1e-05d, false))
331       {
332         break;
333       }
334
335       //save hValues for comparison in the next iteration
336       System.arraycopy(hValues, 0, hValuesOld, 0, hValues.length);
337     }
338
339     //-----------------------------
340     //Use complemented symmetric matrix to calculate final representative
341     //vectors.
342     //&! debug
343     System.out.println("after estimating:");
344     eigenMatrix.print(System.out, "%1.8f ");
345
346     //Eigendecomposition.
347     eigenMatrix.tred();
348     System.out.println("tred");
349     eigenMatrix.print(System.out, "%8.2f");
350
351     eigenMatrix.tqli();
352     System.out.println("eigenvals");
353     eigenMatrix.printD(System.out, "%2.4f ");
354     System.out.println();
355     System.out.println("tqli");
356     eigenMatrix.print(System.out, "%8.2f");
357
358     double[] eigenVals = eigenMatrix.getD();
359
360     /*
361     TreeMap<Double, double[]> eigenPairs = new TreeMap<>(Comparator.reverseOrder());
362     for (int i = 0; i < eigenVals.length; i++)
363     {
364       eigenPairs.put(eigenVals[i], eigenMatrix.getColumn(i));
365     }
366     */
367     TreeMap<Double, Integer> eigenPairs = new TreeMap<>(Comparator.reverseOrder());
368     for (int i = 0; i < eigenVals.length; i++)
369     {
370       eigenPairs.put(eigenVals[i], i);
371     }
372
373     // matrix of representative eigenvectors (each row is a vector)
374     double[][] _repMatrix = new double[eigenVals.length][dim];  //last ones were dim
375     double[][] _oldMatrix = new double[eigenVals.length][dim];
376     double[] correctedEigenValues = new double[dim];    
377
378     //for (Entry<Double, double[]> pair : eigenPairs.entrySet())
379     int l = 0;
380     for (Entry<Double, Integer> pair : eigenPairs.entrySet())
381     {
382       double eigenValue = pair.getKey();
383       int column = pair.getValue();
384       double[] eigenVector = eigenMatrix.getColumn(column);
385       //for 2nd and higher eigenvalues
386       if (l >= 1)
387       {
388         eigenValue /= correctionFactor;
389       }
390       //l++;
391       //correctedEigenValues[dim - l] = eigenValue;
392       correctedEigenValues[l] = eigenValue;
393       for (int j = 0; j < eigenVector.length; j++)
394       {
395         //_repMatrix[j][dim - l] = (eigenValue < 0) ? 0.0 : Math.sqrt(eigenValue) * eigenVector[j];
396         _repMatrix[j][l] = (eigenValue < 0) ? 0.0 : - Math.sqrt(eigenValue) * eigenVector[j];
397         double tmpOldScore = scoresOld.getColumn(column)[j];
398         _oldMatrix[j][dim - l - 1] = (Double.isNaN(tmpOldScore)) ? 0.0 : tmpOldScore;
399       }
400       l++;
401       if (l >= dim)
402       {
403         break;
404       }
405     }
406
407     System.out.println("correctedEigenValues");
408     MiscMath.print(correctedEigenValues, "%2.4f ");
409
410     repMatrix = new Matrix(_repMatrix);
411     repMatrix.setD(correctedEigenValues);
412     MatrixI oldMatrix = new Matrix(_oldMatrix); //TODO do i even need it anymore?
413
414     System.out.println("old matrix");
415     oldMatrix.print(System.out, "%8.2f");
416
417     System.out.println("scoresOld");
418     scoresOld.print(System.out, "%1.4f ");
419
420     System.out.println("rep matrix");
421     repMatrix.print(System.out, "%1.8f ");
422
423     MatrixI dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
424     System.out.println("dot matrix");
425     dotMatrix.print(System.out, "%1.8f ");
426     
427     double rmsd = scoresOld.rmsd(dotMatrix);    //TODO do i need this here?
428     System.out.println(rmsd);   
429
430     System.out.println("iteration, rmsd, maxDiff, rmsdDiff");
431     System.out.println(String.format("0, %8.5f, -, -", rmsd));
432     // Refine representative vectors by minimising sum-of-squared deviates between dotMatrix and original  score matrix
433     for (int iteration = 1; iteration < 21; iteration++)        // arbitrarily set to 20
434     {
435       MatrixI repMatrixOLD = repMatrix.copy();
436       MatrixI dotMatrixOLD = dotMatrix.copy();
437
438       // for all rows/hA in the original matrix
439       for (int hAIndex = 0; hAIndex < oldMatrix.height(); hAIndex++)
440       {
441         double[] row = oldMatrix.getRow(hAIndex);
442         double[] hA = repMatrix.getRow(hAIndex);        // inverted
443         hAIndex = hAIndex;
444         //find least-squares-solution fo rdifferences between original scores and representative vectors
445         //--> originalToEquasionSystem(hA, hAIndex, repMatrix, row) --> double[]
446         System.out.println(String.format("||||||||||||||||||||||||||||||||||||||||||||\nIteration: %d", iteration));
447         //repMatrix =  new Matrix( new double[][]{{ 0.92894902, -0.25013783, -0.0051076 }, { 0.91955135, -0.25024707, -0.00516568}, { 0.90957348, -0.21717002, -0.14259899}, { 0.90298063, -0.21678816, -0.14697814}, { 0.9157065,  -0.04646437,  0.10105454}, { 0.9050301,  -0.04689785,  0.18964432}, { 0.92498545,  0.03881933,  0.14771523}, { 0.88008842,  0.0142395,   0.11356242}, { 0.94276528,  0.25591474, -0.07190911}, { 0.93939976,  0.23375136, -0.07530558}, { 0.93550782,  0.24635804, -0.07317563}, { 0.92497731,  0.21729928, -0.02361162}});
448         //scoresOld = new Matrix( new double[][]{{   Double.NaN, 0.9914, 0.8879, 0.8803, 0.8528, 0.8481, 0.8434, 0.8174, 0.8174, 0.8232, 0.8148, 0.8114}, {0.9914,    Double.NaN, 0.8792, 0.8717, 0.8441, 0.8395, 0.8347, 0.8087, 0.8085, 0.8143, 0.8059, 0.8024}, {0.8879, 0.8792,    Double.NaN, 0.9578, 0.8289, 0.8075, 0.8268, 0.7815, 0.8165, 0.8168, 0.8084, 0.7974}, {0.8803, 0.8717, 0.9578,    Double.NaN, 0.838,  0.7974, 0.8058, 0.7798, 0.8094, 0.8098, 0.8067, 0.7905}, {0.8528, 0.8441, 0.8289, 0.838,     Double.NaN, 0.8954, 0.8412, 0.7879, 0.8418, 0.8384, 0.8389, 0.8556}, {0.8481, 0.8395, 0.8075, 0.7974, 0.8954,    Double.NaN, 0.8699, 0.8106, 0.8267, 0.8234, 0.8222, 0.8241}, {0.8434, 0.8347, 0.8268, 0.8058, 0.8412, 0.8699,    Double.NaN, 0.869,  0.8745, 0.8583, 0.8661, 0.8593}, {0.8174, 0.8087, 0.7815, 0.7798, 0.7879, 0.8106, 0.869,     Double.NaN, 0.8273, 0.8331, 0.8137, 0.8029}, {0.8174, 0.8085, 0.8165, 0.8094, 0.8418, 0.8267, 0.8745, 0.8273,    Double.NaN, 0.967, 0.978,  0.9373}, {0.8232, 0.8143, 0.8168, 0.8098, 0.8384, 0.8234, 0.8583, 0.8331, 0.967,     Double.NaN, 0.9561, 0.9337}, {0.8148, 0.8059, 0.8084, 0.8067, 0.8389, 0.8222, 0.8661, 0.8137, 0.978,  0.9561, Double.NaN, 0.9263}, {0.8114, 0.8024, 0.7974, 0.7905, 0.8556, 0.8241, 0.8593, 0.8029, 0.9373, 0.9337, 0.9263,    Double.NaN}});
449         double[] hAlsm = leastSquaresOptimisation(repMatrix, scoresOld, hAIndex);
450         // update repMatrix with new hAlsm
451         for (int j = 0; j < repMatrix.width(); j++)
452         {
453           repMatrix.setValue(hAIndex, j, hAlsm[j]);
454         }
455         break;
456       }
457       
458       // dot product of representative vecotrs yields a matrix with values approximating the correlation matrix
459       dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
460       // calculate rmsd between approximation and correlation matrix
461       rmsd = scoresOld.rmsd(dotMatrix);
462
463       // calculate maximum change of representative vectors of current iteration
464       //repMatrix.subtract(repMatrixOLD).print(System.out, "%8.2f");
465       MatrixI diff = repMatrix.subtract(repMatrixOLD).absolute();
466       double maxDiff = 0.0;
467       for (int i = 0; i < diff.height(); i++)
468       {
469         for (int j = 0; j < diff.width(); j++)
470         {
471           maxDiff = (diff.getValue(i, j) > maxDiff) ? diff.getValue(i, j) : maxDiff;
472         }
473       }
474       System.out.println(String.format("maxDiff: %f", maxDiff));
475
476       // calculate rmsd between current and previous estimation
477       double rmsdDiff = dotMatrix.rmsd(dotMatrixOLD);
478
479       System.out.println(String.format("%d, %8.5f, %8.5f, %8.5f", iteration, rmsd, maxDiff, rmsdDiff));
480
481       if (!(Math.abs(maxDiff) > 1e-06))
482       {
483         repMatrix = repMatrixOLD.copy();
484         break;
485       }
486       break;
487     }
488     
489
490     } catch (Exception q)
491     {
492       Console.error("Error computing cc_analysis:  " + q.getMessage());
493       q.printStackTrace();
494     }
495     //repMatrix = repMatrix.transpose();
496     System.out.println("final repMatrix");
497     repMatrix.print(System.out, "%8.2f");
498     return repMatrix;
499   }
500
501   /**
502   * Create equations system using information on originally known
503   * pairwise correlation coefficients (parsed from infile) and the
504   * representative result vectors
505   *
506   * Each equation has the format:
507   * hA * hA - pairwiseCC = 0
508   * with:
509   * hA: unknown variable
510   * hB: known representative vector
511   * pairwiseCC: known pairwise correlation coefficien
512   * 
513   * The resulting equations system is overdetermined, if there are more
514   * equations than unknown elements
515   *
516   * @param x ~ unknown n-dimensional column-vector
517   * (needed for generating equations system, NOT to be specified by user).
518   * @param hAIndex ~ index of currently optimised representative result vector.
519   * @param h ~ matrix with row-wise listing of representative result vectors.
520   * @param originalRow ~ matrix-row of originally parsed pairwise correlation coefficients.
521   *
522   * @return
523   */
524   //private double[] originalToEquasionSystem(double[] x, int hAIndex, MatrixI h, MatrixI originalScores)
525   private double[] originalToEquasionSystem(double[] hA, MatrixI repMatrix, MatrixI scoresOld, int hAIndex)
526   {
527     double[] originalRow = scoresOld.getRow(hAIndex);
528     int nans = MiscMath.countNaN(originalRow);
529     double[] result = new double[originalRow.length - nans];
530
531     //for all pairwiseCC in originalRow
532     int resultIndex = 0;
533     for (int hBIndex = 0; hBIndex < originalRow.length; hBIndex++)
534     {
535       double pairwiseCC = originalRow[hBIndex];
536       // if not NaN -> create new equation and add it to the system
537       if (!Double.isNaN(pairwiseCC))
538       {
539         double[] hB = repMatrix.getRow(hBIndex);
540         result[resultIndex++] = MiscMath.sum(MiscMath.elementwiseMultiply(hA, hB)) - pairwiseCC;
541       } else {
542       }
543     }
544     return result;
545   }
546
547   /**
548   * returns the jacobian matrix
549   * @param repMatrix ~ matrix of representative vectors
550   * @param hAIndex ~ current row index
551   *
552   * @return
553   */
554   private MatrixI approximateDerivative(MatrixI repMatrix, MatrixI scoresOld, int hAIndex)
555   {
556     //hA = x0
557     double[] hA = repMatrix.getRow(hAIndex);
558     double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, hAIndex);
559     System.out.println("Approximate derivative with ");
560     System.out.print("hA (x): ");
561     MiscMath.print(hA, "%1.8f");
562     System.out.print("f0: ");
563     MiscMath.print(f0, "%1.8f");
564     double[] signX0 = new double[hA.length];
565     double[] xAbs = new double[hA.length];
566     for (int i = 0; i < hA.length; i++)
567     {
568       signX0[i] = (hA[i] >= 0) ? 1 : -1;
569       xAbs[i] = (Math.abs(hA[i]) >= 1.0) ? Math.abs(hA[i]) : 1.0;
570       }
571     double rstep = Math.pow(Math.ulp(1.0), 0.5);
572
573     double[] h = new double [hA.length];
574     for (int i = 0; i < hA.length; i++)
575     {
576       h[i] = rstep * signX0[i] * xAbs[i];
577     }
578       
579     int m = f0.length;
580     int n = hA.length;
581     double[][] jTransposed = new double[n][m];
582     for (int i = 0; i < h.length; i++)
583     {
584       double[] x = new double[h.length];
585       System.arraycopy(hA, 0, x, 0, h.length);
586       x[i] += h[i];
587       double dx = x[i] - hA[i];
588       double[] df = originalToEquasionSystem(x, repMatrix, scoresOld, hAIndex);
589       for (int j = 0; j < df.length; j++)
590       {
591         df[j] -= f0[j];
592         jTransposed[i][j] = df[j] / dx;
593       }
594     }
595     MatrixI J = new Matrix(jTransposed).transpose();    // inverted
596     return J;
597   }
598
599   /**
600   * norm of regularized (by alpha) least-squares solution minus Delta
601   * @param alpha
602   * @param suf
603   * @param s
604   * @param Delta
605   *
606   * @return
607   */
608   private double[] phiAndDerivative(double alpha, double[] suf, double[] s, double Delta)
609   {
610     double[] denom = MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha);
611     double pNorm = MiscMath.norm(MiscMath.elementwiseDivide(suf, denom));
612     double phi = pNorm - Delta;
613     // - sum ( suf**2 / denom**3) / pNorm
614     double phiPrime = - MiscMath.sum(MiscMath.elementwiseDivide(MiscMath.elementwiseMultiply(suf, suf), MiscMath.elementwiseMultiply(MiscMath.elementwiseMultiply(denom, denom), denom))) / pNorm;
615     return new double[]{phi, phiPrime};
616   }
617
618   /**
619   * class holding the result of solveLsqTrustRegion
620   */
621   private class TrustRegion
622   {
623     private double[] step;
624     private double alpha;
625     private int iteration;
626
627     public TrustRegion(double[] step, double alpha, int iteration)
628     {
629       this.step = step;
630       this.alpha = alpha;
631       this.iteration = iteration;
632     }
633
634     public double[] getStep()
635     {
636       return this.step;
637     }
638
639     public double getAlpha()
640     {
641       return this.alpha;
642     }
643   
644     public int getIteration()
645     {
646       return this.iteration;
647     }
648   }
649
650   /**
651   * solve a trust-region problem arising in least-squares optimisation
652   * @param n ~ number of variables
653   * @param m ~ number of residuals
654   * @param uf ~ <++>
655   * @param s ~ singular values of J
656   * @param V ~ transpose of VT
657   * @param Delta ~ radius of a trust region
658   * @param alpha ~ initial guess for alpha
659   *
660   * @return
661   */
662   private TrustRegion solveLsqTrustRegion(int n, int m, double[] uf, double[] s, MatrixI V, double Delta, double alpha)
663   {
664     double[] suf = MiscMath.elementwiseMultiply(s, uf);
665
666     //check if J has full rank and tr Gauss-Newton step
667     boolean fullRank = false;
668     if (m >= n)
669     {
670       double threshold = s[0] * Math.ulp(1.0) * m;
671       fullRank = s[s.length - 1] > threshold;
672     }
673     if (fullRank)
674     {
675       double[] p = MiscMath.elementwiseMultiply(V.sumProduct(MiscMath.elementwiseDivide(uf, s)), -1);   // inverted and roughly fine
676       if (MiscMath.norm(p) <= Delta)
677       {
678         TrustRegion result = new TrustRegion(p, 0.0, 0);
679         return result;
680       }
681     }
682
683     double alphaUpper = MiscMath.norm(suf) / Delta;
684     double alphaLower = 0.0;
685     if (fullRank)
686     {
687       double[] phiAndPrime = phiAndDerivative(0.0, suf, s, Delta);
688       alphaLower = - phiAndPrime[0] / phiAndPrime[1];
689     }
690
691     alpha = (!fullRank && alpha == 0.0) ? alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5)) : alpha;
692
693     int iteration = 0;
694     while (iteration < 10)      // 10 is default max_iter
695     {
696       alpha = (alpha < alphaLower || alpha > alphaUpper) ? alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5)) : alpha;
697       double[] phiAndPrime = phiAndDerivative(alpha, suf, s, Delta);
698       double phi = phiAndPrime[0];
699       double phiPrime = phiAndPrime[1];
700
701       alphaUpper = (phi < 0) ? alpha : alphaUpper;
702       double ratio = phi / phiPrime;
703       alphaLower = Math.max(alphaLower, alpha - ratio);
704       alpha -= (phi + Delta) * ratio / Delta;
705
706       if (Math.abs(phi) < 0.01 * Delta) // default rtol set to 0.01
707       {
708         break;
709       }
710       iteration++;
711     }
712
713     // p = - V.dot( suf / (s**2 + alpha))
714     double[] tmp = MiscMath.elementwiseDivide(suf, MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha));
715     double[] p = MiscMath.elementwiseMultiply(V.sumProduct(tmp), -1);
716
717     // Make the norm of p equal to Delta, p is changed only slightly during this.
718     // It is done to prevent p lie outside of the trust region
719     p = MiscMath.elementwiseMultiply(p, Delta / MiscMath.norm(p));
720
721     TrustRegion result = new TrustRegion(p, alpha, iteration + 1);
722     return result;
723   }
724
725   /**
726   * compute values of a quadratic function arising in least squares
727   * function: 0.5 * s.T * (J.T * J + diag) * s + g.T * s
728   *
729   * @param J ~ jacobian matrix
730   * @param g ~ gradient
731   * @param s ~ steps and rows
732   *
733   * @return
734   */
735   private double evaluateQuadratic(MatrixI J, double[] g, double[] s)
736   {
737
738     //TODO s (-> stepH) is slightly different
739     double[] Js = J.sumProduct(s);      //TODO completely wromg
740     double q = MiscMath.dot(Js, Js);
741     double l = MiscMath.dot(s, g);
742
743     /*
744     System.out.println("doing evaluateQuadratic");
745     System.out.println("inputs");
746     System.out.print("J");
747     J.print(System.out, "%f ");
748     System.out.print("g");
749     MiscMath.print(g, "%f");
750     System.out.print("s");
751     MiscMath.print(s, "%f");
752     System.out.print("\nJs");
753     MiscMath.print(Js, "%f");
754     System.out.println(String.format("0.5 * %f + %f", q, l));
755     */
756
757     return 0.5 * q + l;
758   }
759
760   /**
761   * update the radius of a trust region based on the cost reduction
762   *
763   * @param Delta
764   * @param actualReduction
765   * @param predictedReduction
766   * @param stepNorm
767   * @param boundHit
768   *
769   * @return
770   */
771   private double[] updateTrustRegionRadius(double Delta, double actualReduction, double predictedReduction, double stepNorm, boolean boundHit)
772   {
773     double ratio = 0;
774     if (predictedReduction > 0)
775     {
776       ratio = actualReduction / predictedReduction;
777     } else if (predictedReduction == 0 && actualReduction == 0) {
778       ratio = 1;
779     } else {
780       ratio = 0;
781     }
782
783     if (ratio < 0.25)
784     {
785       Delta = 0.25 * stepNorm;
786     } else if (ratio > 0.75 && boundHit) {
787       Delta *= 2.0;
788     }
789
790     return new double[]{Delta, ratio};
791   }
792
793   /**
794   * check the termination condition for nonlinear least squares
795   * TODO can be removed and added just as one line in trf (: terminationStatus = (ftolSatisfied condition || xtolSatisfied condition) ? 1 : 0;)  doesnt matter as long as distinguished between 0 and rest
796   *
797   * @param actualReduction
798   * @param cost
799   * @param stepNorm
800   * @param xNorm
801   * @param ratio
802   *
803   * @return
804   */
805   private byte checkTermination(double actualReduction, double cost, double stepNorm, double xNorm, double ratio)
806   {
807     // default ftol and xtol = 1e-8
808     boolean ftolSatisfied = actualReduction < (1e-8 * cost) && ratio > 0.25;
809     boolean xtolSatisfied = stepNorm < (1e-8 * (1e-8 + xNorm));
810
811     if (ftolSatisfied && xtolSatisfied)
812     {
813       return (byte) 4;
814     } else if (ftolSatisfied) {
815       return (byte) 2;
816     } else if (xtolSatisfied) {
817       return (byte) 3;
818     } else {
819       return (byte) 0;
820     }
821   }
822
823   /**
824   * TODO DOCUMENT ME!
825   * @param repMatrix ~ Matrix containing representative vectors
826   * @param scoresOld ~ Matrix containing initial observations
827   * @param index ~ current row index
828   * @param J ~ jacobian matrix
829   *
830   * @return
831   */
832   private double[] trf(MatrixI repMatrix, MatrixI scoresOld, int index, MatrixI J)
833   {
834     System.out.println("-----------------\nStart of trf");
835     //hA = x0
836     double[] hA = repMatrix.getRow(index);      //inverted
837     double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, index);
838     int nfev = 1;       // ??
839     int njev = 1;       // ??
840     int m = J.height();
841     int n = J.width();
842     double cost = 0.5 * MiscMath.dot(f0, f0);
843     double[] g = J.transpose().sumProduct(f0);  // inverted
844     double[] scale = new double[hA.length];
845     Arrays.fill(scale, 1);              // ??
846     double Delta = MiscMath.norm(hA);
847     int maxNfev = hA.length * 100;      // ??
848     double alpha = 0.0;         // ?? "Levenberg-Marquardt" parameter
849
850     System.out.println("Checking initial values:");
851     System.out.print("hA (x): ");
852     MiscMath.print(hA, "%1.8f");
853     System.out.print("f0: ");
854     MiscMath.print(f0, "%1.8f");
855     //System.out.println(String.format("nfev: %d, njev: %d, maxNfev: %d", nfev, njev, maxNfev));
856     //System.out.println(String.format("m: %d, n: %d", m, n));
857     System.out.println(String.format("cost: %1.8f, Delta: %1.8f, alpha: %1.8f", cost, Delta, alpha));
858     System.out.print("g: ");
859     MiscMath.print(g, "%1.8f");
860
861     double gNorm = 0;
862     byte terminationStatus = 0;
863     int iteration = 0;
864
865     System.out.println("outer while loop starts");
866     while (true)
867     {
868       System.out.println(String.format("iteration: %d", iteration));
869
870       gNorm = MiscMath.norm(g);
871       if (terminationStatus != 0 || nfev == maxNfev)
872       {
873         System.out.println(String.format("outer loop broken with terminationStatus: %d and nfev: %d", terminationStatus, nfev));
874         break;
875       }
876       // d = scale
877       SingularValueDecomposition svd = new SingularValueDecomposition(new Array2DRowRealMatrix(J.asArray()));
878       // svd not 100% correct -> origin of problems
879       MatrixI U = new Matrix(svd.getU().getData());     // inverted
880       double[] s = svd.getSingularValues();             // ??
881       MatrixI V = new Matrix(svd.getV().getData()).transpose(); //TODO inverted origin of probelm
882       double[] uf = U.transpose().sumProduct(f0);
883
884       System.out.println("After SVD");
885       System.out.println(String.format("gNorm: %1.8f", gNorm));
886       System.out.print("U: ");
887       U.print(System.out, "%1.8f ");
888       System.out.print("s: ");
889       MiscMath.print(s, "%1.8f");
890       System.out.print("V: ");
891       V.print(System.out, "%1.8f ");
892       System.out.print("uf: ");
893       MiscMath.print(uf, "%1.8f");
894
895       double actualReduction = -1;
896       double[] xNew = new double[hA.length];
897       double[] fNew = new double[f0.length];
898       double costNew = 0;
899       double stepHnorm = 0;
900       
901       System.out.println("Inner while loop starts");
902
903       while (actualReduction <= 0 && nfev < maxNfev)
904       {
905         TrustRegion trustRegion = solveLsqTrustRegion(n, m, uf, s, V, Delta, alpha);
906         double[] stepH = trustRegion.getStep(); 
907         alpha = trustRegion.getAlpha();
908         int nIterations = trustRegion.getIteration();
909         double predictedReduction = - (evaluateQuadratic(J, g, stepH)); 
910
911         xNew = MiscMath.elementwiseAdd(hA, stepH);
912         fNew = originalToEquasionSystem(xNew, repMatrix, scoresOld, index);
913         nfev++;
914         
915         stepHnorm = MiscMath.norm(stepH);
916
917         System.out.println("After TrustRegion");
918         System.out.print("stepH: ");
919         MiscMath.print(stepH, "%1.8f ");
920         System.out.println(String.format("alpha: %1.8f, nIterations: %d, predictedReduction: %1.8f, nfev: %d, stepHnorm: %1.8f", alpha, nIterations, predictedReduction, nfev, stepHnorm));
921         System.out.print("xNew: ");
922         MiscMath.print(xNew, "%1.8f ");
923         System.out.print("fNew: ");
924         MiscMath.print(fNew, "%1.8f ");
925
926         if (MiscMath.countNaN(fNew) > 0)
927         {
928           Delta = 0.25 * stepHnorm;
929           System.out.println(String.format("Loop continued with %d NaNs and Delta: %1.8f", MiscMath.countNaN(fNew), Delta));
930           continue;
931         }
932
933         // usual trust-region step quality estimation
934         costNew = 0.5 * MiscMath.dot(fNew, fNew); 
935         actualReduction = cost - costNew;
936
937         double[] updatedTrustRegion = updateTrustRegionRadius(Delta, actualReduction, predictedReduction, stepHnorm, stepHnorm > (0.95 * Delta));
938         double DeltaNew = updatedTrustRegion[0];
939         double ratio = updatedTrustRegion[1];
940
941         terminationStatus = checkTermination(actualReduction, cost, stepHnorm, MiscMath.norm(hA), ratio);
942         if (terminationStatus != 0)
943         {
944           break;
945         }
946
947         alpha *= Delta / DeltaNew;
948         Delta = DeltaNew;
949
950         System.out.println(String.format("actualReduction: %1.8f, alpha: %1.8f, Delta: %1.8f, termination_status: %d, cost_new: %1.8f", actualReduction, alpha, Delta, terminationStatus, costNew));
951         //break;
952       }
953       System.out.println(String.format("actualReduction before check: %1.8f", actualReduction));
954       if (actualReduction > 0)
955       {
956         hA = xNew;
957         f0 = fNew;
958         cost = costNew;
959
960         J = approximateDerivative(repMatrix, scoresOld, index);
961         System.out.println("J in the end");
962         J.print(System.out, "%1.8f ");
963         njev++;
964
965         g = J.transpose().sumProduct(f0);
966       } else {
967         stepHnorm = 0;
968         actualReduction = 0;
969       }
970       iteration++;
971     }
972
973     System.out.println("into OptimizeResult");
974     System.out.println("x (hA)");
975     MiscMath.print(hA, "%1.8f");
976     System.out.println(String.format("cost: %1.8f", cost));
977     System.out.println("f0");
978     MiscMath.print(f0, "%1.8f");
979     System.out.println("J");
980     J.print(System.out, "%1.8f ");
981     System.out.println("g");
982     MiscMath.print(g, "%1.8f");
983     System.out.println(String.format("gNorm: %1.8f", gNorm));
984     System.out.println(String.format("nfev: %d", nfev));
985     System.out.println(String.format("njev: %d", njev));
986     System.out.println(String.format("terminationStatus: %d", terminationStatus));
987     // OptimizeResult(x, cost, f0, J, g, gNorm, 0 in shape x, nfev, njev, terminationStatus)
988     return hA;
989   }
990
991   /**
992   * TODO DOCUMENT ME!
993   * @param repMatrix ~ Matrix containing representative vectors
994   * @param scoresOld ~ Matrix containing initial observations
995   * @param index ~ current row index
996   *
997   * @return
998   */
999   private double[] leastSquaresOptimisation(MatrixI repMatrix, MatrixI scoresOld, int index)
1000   {
1001     System.out.println("lsq starts!!!");
1002     MatrixI J = approximateDerivative(repMatrix, scoresOld, index);
1003     System.out.println("J");
1004     J.print(System.out, "%1.8f ");
1005     double[] result = trf(repMatrix, scoresOld, index, J);
1006     return result;
1007   }
1008
1009 }