2 * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3 * Copyright (C) $$Year-Rel$$ The Jalview Authors
5 * This file is part of Jalview.
7 * Jalview is free software: you can redistribute it and/or
8 * modify it under the terms of the GNU General Public License
9 * as published by the Free Software Foundation, either version 3
10 * of the License, or (at your option) any later version.
12 * Jalview is distributed in the hope that it will be useful, but
13 * WITHOUT ANY WARRANTY; without even the implied warranty
14 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15 * PURPOSE. See the GNU General Public License for more details.
17 * You should have received a copy of the GNU General Public License
18 * along with Jalview. If not, see <http://www.gnu.org/licenses/>.
19 * The Jalview Authors are detailed in the 'AUTHORS' file.
23 * Copyright 2018-2022 Kathy Su, Kay Diederichs
25 * This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
27 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
29 * You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>.
33 * Ported from https://doi.org/10.1107/S2059798317000699 by
34 * @AUTHOR MorellThomas
37 package jalview.analysis;
39 import jalview.bin.Console;
40 import jalview.math.MatrixI;
41 import jalview.math.Matrix;
42 import jalview.math.MiscMath;
44 import java.lang.Math;
45 import java.lang.System;
46 import java.util.Arrays;
47 import java.util.ArrayList;
48 import java.util.Comparator;
49 import java.util.HashSet;
50 import java.util.Map.Entry;
51 import java.util.TreeMap;
53 import org.apache.commons.math3.analysis.MultivariateVectorFunction;
54 import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
55 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
56 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
57 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
58 import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
59 import org.apache.commons.math3.linear.ArrayRealVector;
60 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
61 import org.apache.commons.math3.linear.RealVector;
62 import org.apache.commons.math3.linear.RealMatrix;
63 import org.apache.commons.math3.linear.SingularValueDecomposition;
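66 * Performs cc_analysis: estimates a representative vector for each sequence such that the dot products of these vectors approximate a symmetric score matrix which may contain unknown (NaN) values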
68 public class ccAnalysis
70 private byte dim = 0; //dimensions
72 private MatrixI scoresOld; //input scores
74 public ccAnalysis(MatrixI scores, byte dim)
76 //&! round matrix to .4f to be same as in pasimap
77 for (int i = 0; i < scores.height(); i++)
79 for (int j = 0; j < scores.width(); j++)
81 if (!Double.isNaN(scores.getValue(i,j)))
83 scores.setValue(i, j, (double) Math.round(scores.getValue(i,j) * (int) 10000) / 10000);
87 this.scoresOld = scores;
94 * @param hSigns ~ hypothesis signs (+/-) for each sequence
95 * @param scores ~ input score matrix
97 * @return distrustScores
99 private int[] initialiseDistrusts(byte[] hSigns, MatrixI scores)
101 int[] distrustScores = new int[scores.width()];
103 // loop over symmetric matrix
104 for (int i = 0; i < scores.width(); i++)
106 byte hASign = hSigns[i];
110 for (int j = 0; j < scores.width(); j++)
112 double cell = scores.getRow(i)[j]; // value at [i][j] in scores
113 byte hBSign = hSigns[j];
114 if (!Double.isNaN(cell))
116 byte cellSign = (byte) Math.signum(cell); //check if sign of matrix value fits hyptohesis
117 if (cellSign == hASign * hBSign)
125 distrustScores[i] = conHypNum - proHypNum; //create distrust score for each sequence
127 return distrustScores;
133 * @param hSigns ~ hypothesis signs (+/-)
134 * @param distrustScores
135 * @param scores ~ input score matrix
139 private byte[] optimiseHypothesis(byte[] hSigns, int[] distrustScores, MatrixI scores)
141 // get maximum distrust score
142 int[] maxes = MiscMath.findMax(distrustScores);
143 int maxDistrustIndex = maxes[0];
144 int maxDistrust = maxes[1];
146 // if hypothesis is not optimal yet
149 //toggle sign for hI with maximum distrust
150 hSigns[maxDistrustIndex] *= -1;
151 // update distrust at same position
152 distrustScores[maxDistrustIndex] *= -1;
154 // also update distrust scores for all hI that were not changed
155 byte hASign = hSigns[maxDistrustIndex];
156 for (int NOTmaxDistrustIndex = 0; NOTmaxDistrustIndex < distrustScores.length; NOTmaxDistrustIndex++)
158 if (NOTmaxDistrustIndex != maxDistrustIndex)
160 byte hBSign = hSigns[NOTmaxDistrustIndex];
161 double cell = scores.getRow(maxDistrustIndex)[NOTmaxDistrustIndex];
163 // distrust only changed if not NaN
164 if (!Double.isNaN(cell))
166 byte cellSign = (byte) Math.signum(cell);
167 // if sign of cell matches hypothesis decrease distrust by 2 because 1 more value supporting and 1 less contradicting
168 // else increase by 2
169 if (cellSign == hASign * hBSign)
171 distrustScores[NOTmaxDistrustIndex] -= 2;
173 distrustScores[NOTmaxDistrustIndex] += 2;
178 //further optimisation necessary
179 return optimiseHypothesis(hSigns, distrustScores, scores);
187 * takes a symmetric MatrixI as input scores which may contain Double.NaN
188 * approximate the missing values using hypothesis optimisation
192 * operates on the (rounded) score matrix given to the constructor; run() itself takes no parameters
 * NOTE(review): the return statement is not visible in this view; presumably the
 * refined matrix of representative vectors (repMatrix) is returned — confirm.
196 public MatrixI run ()
// NOTE(review): this method writes extensive debug output to System.out.
198 MatrixI eigenMatrix = scoresOld.copy();
199 MatrixI repMatrix = scoresOld.copy();
203 * Calculate correction factor for 2nd and higher eigenvalue(s).
204 * This correction is NOT needed for the 1st eigenvalue, because the
205 * unknown (=NaN) values of the matrix are approximated by presuming
206 * 1-dimensional vectors as the basis of the matrix interpretation as dot
211 System.out.println("input:");
212 eigenMatrix.print(System.out, "%1.4f ");
213 int matrixWidth = eigenMatrix.width(); // square matrix, so width == height
214 int matrixElementsTotal = (int) Math.pow(matrixWidth, 2); //total number of elements
// correctionFactor = fraction of matrix cells that are known (non-NaN)
216 float correctionFactor = (float) (matrixElementsTotal - eigenMatrix.countNaN()) / (float) matrixElementsTotal;
219 * Calculate hypothetical value (1-dimensional vector) h_i for each
220 * dataset by interpreting the given correlation coefficients as scalar
225 * Memory for current hypothesis concerning sign of each h_i.
226 * List of signs for all h_i in the encoding:
230 * Initial hypothesis: all signs are positive.
232 byte[] hSigns = new byte[matrixWidth];
233 Arrays.fill(hSigns, (byte) 1);
235 //Estimate signs for each h_i by refining hypothesis on signs.
236 hSigns = optimiseHypothesis(hSigns, initialiseDistrusts(hSigns, eigenMatrix), eigenMatrix);
239 //Estimate absolute values for each h_i by determining sqrt of mean of
240 //non-NaN absolute values for every row.
241 //<++> hAbs = //np.sqrt(np.nanmean(np.absolute(eigenMatrix), axis=1))
242 MiscMath.print(eigenMatrix.absolute().meanRow(), "%1.8f");
// NOTE(review): assumes MatrixI.meanRow() skips NaN cells like np.nanmean — confirm.
243 double[] hAbs = MiscMath.sqrt(eigenMatrix.absolute().meanRow()); //np.sqrt(np.nanmean(np.absolute(eigenMatrix), axis=1))
245 //Combine estimated signs with absolute values to obtain total value for
247 double[] hValues = MiscMath.elementwiseMultiply(hSigns, hAbs);
248 //<++>hValues.reshape((1,matrixWidth)); // doesnt it already look like this
250 /*Complement symmetric matrix by using the scalar products of estimated
251 *values of h_i to replace NaN-cells.
252 *Matrix positions that have estimated values
253 *(only for diagonal and upper off-diagonal values, due to the symmetry
254 *the positions of the lower-diagonal values can be inferred).
255 List of tuples (row_idx, column_idx).*/
257 ArrayList<int[]> estimatedPositions = new ArrayList<int[]>();
259 // for off-diagonal cells
260 for (int rowIndex = 0; rowIndex < matrixWidth - 1; rowIndex++)
262 for (int columnIndex = rowIndex + 1; columnIndex < matrixWidth; columnIndex++)
264 double cell = eigenMatrix.getValue(rowIndex, columnIndex);
265 if (Double.isNaN(cell))
267 //calculate scalar product as new cell value
268 cell = hValues[rowIndex] * hValues[columnIndex]; // TODO(author): the hAbs / hValues estimates look wrong — verify against the Python reference
269 //fill in new value in cell and symmetric partner
270 eigenMatrix.setValue(rowIndex, columnIndex, cell);
271 eigenMatrix.setValue(columnIndex, rowIndex, cell);
272 //save positions of estimated values
273 estimatedPositions.add(new int[]{rowIndex, columnIndex});
278 // for diagonal cells
// every diagonal cell is overwritten with h_i^2 (NaN or not) and recorded as estimated
279 for (int diagonalIndex = 0; diagonalIndex < matrixWidth; diagonalIndex++)
281 double cell = Math.pow(hValues[diagonalIndex], 2);
282 eigenMatrix.setValue(diagonalIndex, diagonalIndex, cell);
283 estimatedPositions.add(new int[]{diagonalIndex, diagonalIndex});
286 /*Refine total values of each h_i:
287 *Initialise h_values of the hypothetical non-existant previous iteration
288 *with the correct format but with impossible values.
289 Needed for exit condition of otherwise endless loop.*/
290 System.out.print("initial values: [ ");
291 for (double h : hValues)
293 System.out.print(String.format("%1.4f, ", h));
295 System.out.println(" ]");
// NOTE(review): as allocated here hValuesOld holds zeros, not "impossible" values as the comment above says
298 double[] hValuesOld = new double[matrixWidth];
300 int iterationCount = 0;
302 // repeat until values of h do not significantly change anymore
// hValues is updated in place, so each newH already uses the updated values of lower indices
305 for (int hIndex = 0; hIndex < matrixWidth; hIndex++)
307 //@python newH = np.sum(hValues * eigenMatrix[hIndex]) / np.sum(hValues ** 2)
308 double newH = Arrays.stream(MiscMath.elementwiseMultiply(hValues, eigenMatrix.getRow(hIndex))).sum() / Arrays.stream(MiscMath.elementwiseMultiply(hValues, hValues)).sum();
309 hValues[hIndex] = newH;
312 System.out.print(String.format("iteration %d: [ ", iterationCount));
313 for (double h : hValues)
315 System.out.print(String.format("%1.4f, ", h));
317 System.out.println(" ]");
319 //update values of estimated positions
320 for (int[] pair : estimatedPositions) // pair ~ row, col
322 double newVal = hValues[pair[0]] * hValues[pair[1]];
323 eigenMatrix.setValue(pair[0], pair[1], newVal);
324 eigenMatrix.setValue(pair[1], pair[0], newVal);
329 //exit loop as soon as new values are similar to the last iteration
// NOTE(review): meaning/order of MiscMath.allClose's tolerance arguments is not visible here; appears to mirror np.allclose — confirm
330 if (MiscMath.allClose(hValues, hValuesOld, 0d, 1e-05d, false))
335 //save hValues for comparison in the next iteration
336 System.arraycopy(hValues, 0, hValuesOld, 0, hValues.length);
339 //-----------------------------
340 //Use complemented symmetric matrix to calculate final representative
343 System.out.println("after estimating:");
344 eigenMatrix.print(System.out, "%1.8f ");
346 //Eigendecomposition.
// NOTE(review): the decomposition calls themselves are not visible here (presumably tred/tqli, per the labels below);
// afterwards getD() is read as the eigenvalues and getColumn(i) as the matching eigenvectors
348 System.out.println("tred");
349 eigenMatrix.print(System.out, "%8.2f");
352 System.out.println("eigenvals");
353 eigenMatrix.printD(System.out, "%2.4f ");
354 System.out.println();
355 System.out.println("tqli");
356 eigenMatrix.print(System.out, "%8.2f");
358 double[] eigenVals = eigenMatrix.getD();
// NOTE(review): eigenPairs is declared twice below; presumably the first version is commented out in lines not shown.
// Keying a TreeMap by eigenvalue keeps only one eigenvector when two eigenvalues are exactly equal.
361 TreeMap<Double, double[]> eigenPairs = new TreeMap<>(Comparator.reverseOrder());
362 for (int i = 0; i < eigenVals.length; i++)
364 eigenPairs.put(eigenVals[i], eigenMatrix.getColumn(i));
367 TreeMap<Double, Integer> eigenPairs = new TreeMap<>(Comparator.reverseOrder());
368 for (int i = 0; i < eigenVals.length; i++)
370 eigenPairs.put(eigenVals[i], i);
373 // matrix of representative eigenvectors (each row is a vector)
374 double[][] _repMatrix = new double[eigenVals.length][dim]; //last ones were dim
375 double[][] _oldMatrix = new double[eigenVals.length][dim];
376 double[] correctedEigenValues = new double[dim];
378 //for (Entry<Double, double[]> pair : eigenPairs.entrySet())
// eigenpairs are visited in descending eigenvalue order (reverseOrder comparator);
// the counter l and the exit after dim pairs are in lines not shown
380 for (Entry<Double, Integer> pair : eigenPairs.entrySet())
382 double eigenValue = pair.getKey();
383 int column = pair.getValue();
384 double[] eigenVector = eigenMatrix.getColumn(column);
385 //for 2nd and higher eigenvalues
388 eigenValue /= correctionFactor;
391 //correctedEigenValues[dim - l] = eigenValue;
392 correctedEigenValues[l] = eigenValue;
393 for (int j = 0; j < eigenVector.length; j++)
395 //_repMatrix[j][dim - l] = (eigenValue < 0) ? 0.0 : Math.sqrt(eigenValue) * eigenVector[j];
// component l of vector j = -sqrt(eigenValue) * eigenVector[j]; 0 for negative eigenvalues
396 _repMatrix[j][l] = (eigenValue < 0) ? 0.0 : - Math.sqrt(eigenValue) * eigenVector[j];
// NOTE(review): the eigenvector index `column` is used as a column of scoresOld, and the value is stored
// at dim - l - 1 (reverse order) while _repMatrix uses l — confirm this is intended.
397 double tmpOldScore = scoresOld.getColumn(column)[j];
398 _oldMatrix[j][dim - l - 1] = (Double.isNaN(tmpOldScore)) ? 0.0 : tmpOldScore;
407 System.out.println("correctedEigenValues");
408 MiscMath.print(correctedEigenValues, "%2.4f ");
410 repMatrix = new Matrix(_repMatrix);
411 repMatrix.setD(correctedEigenValues);
412 MatrixI oldMatrix = new Matrix(_oldMatrix); //TODO do i even need it anymore?
414 System.out.println("old matrix");
415 oldMatrix.print(System.out, "%8.2f");
417 System.out.println("scoresOld");
418 scoresOld.print(System.out, "%1.4f ");
420 System.out.println("rep matrix");
421 repMatrix.print(System.out, "%1.8f ");
423 MatrixI dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
424 System.out.println("dot matrix");
425 dotMatrix.print(System.out, "%1.8f ");
427 double rmsd = scoresOld.rmsd(dotMatrix); //TODO do i need this here?
428 System.out.println(rmsd);
430 System.out.println("iteration, rmsd, maxDiff, rmsdDiff");
431 System.out.println(String.format("0, %8.5f, -, -", rmsd));
432 // Refine representative vectors by minimising sum-of-squared deviates between dotMatrix and original score matrix
433 for (int iteration = 1; iteration < 21; iteration++) // arbitrarily set to 20
435 MatrixI repMatrixOLD = repMatrix.copy();
436 MatrixI dotMatrixOLD = dotMatrix.copy();
438 // for all rows/hA in the original matrix
439 for (int hAIndex = 0; hAIndex < oldMatrix.height(); hAIndex++)
// NOTE(review): row and hA are not used in the visible code
441 double[] row = oldMatrix.getRow(hAIndex);
442 double[] hA = repMatrix.getRow(hAIndex); // inverted
444 //find least-squares-solution fo rdifferences between original scores and representative vectors
445 //--> originalToEquasionSystem(hA, hAIndex, repMatrix, row) --> double[]
446 System.out.println(String.format("||||||||||||||||||||||||||||||||||||||||||||\nIteration: %d", iteration));
447 //repMatrix = new Matrix( new double[][]{{ 0.92894902, -0.25013783, -0.0051076 }, { 0.91955135, -0.25024707, -0.00516568}, { 0.90957348, -0.21717002, -0.14259899}, { 0.90298063, -0.21678816, -0.14697814}, { 0.9157065, -0.04646437, 0.10105454}, { 0.9050301, -0.04689785, 0.18964432}, { 0.92498545, 0.03881933, 0.14771523}, { 0.88008842, 0.0142395, 0.11356242}, { 0.94276528, 0.25591474, -0.07190911}, { 0.93939976, 0.23375136, -0.07530558}, { 0.93550782, 0.24635804, -0.07317563}, { 0.92497731, 0.21729928, -0.02361162}});
448 //scoresOld = new Matrix( new double[][]{{ Double.NaN, 0.9914, 0.8879, 0.8803, 0.8528, 0.8481, 0.8434, 0.8174, 0.8174, 0.8232, 0.8148, 0.8114}, {0.9914, Double.NaN, 0.8792, 0.8717, 0.8441, 0.8395, 0.8347, 0.8087, 0.8085, 0.8143, 0.8059, 0.8024}, {0.8879, 0.8792, Double.NaN, 0.9578, 0.8289, 0.8075, 0.8268, 0.7815, 0.8165, 0.8168, 0.8084, 0.7974}, {0.8803, 0.8717, 0.9578, Double.NaN, 0.838, 0.7974, 0.8058, 0.7798, 0.8094, 0.8098, 0.8067, 0.7905}, {0.8528, 0.8441, 0.8289, 0.838, Double.NaN, 0.8954, 0.8412, 0.7879, 0.8418, 0.8384, 0.8389, 0.8556}, {0.8481, 0.8395, 0.8075, 0.7974, 0.8954, Double.NaN, 0.8699, 0.8106, 0.8267, 0.8234, 0.8222, 0.8241}, {0.8434, 0.8347, 0.8268, 0.8058, 0.8412, 0.8699, Double.NaN, 0.869, 0.8745, 0.8583, 0.8661, 0.8593}, {0.8174, 0.8087, 0.7815, 0.7798, 0.7879, 0.8106, 0.869, Double.NaN, 0.8273, 0.8331, 0.8137, 0.8029}, {0.8174, 0.8085, 0.8165, 0.8094, 0.8418, 0.8267, 0.8745, 0.8273, Double.NaN, 0.967, 0.978, 0.9373}, {0.8232, 0.8143, 0.8168, 0.8098, 0.8384, 0.8234, 0.8583, 0.8331, 0.967, Double.NaN, 0.9561, 0.9337}, {0.8148, 0.8059, 0.8084, 0.8067, 0.8389, 0.8222, 0.8661, 0.8137, 0.978, 0.9561, Double.NaN, 0.9263}, {0.8114, 0.8024, 0.7974, 0.7905, 0.8556, 0.8241, 0.8593, 0.8029, 0.9373, 0.9337, 0.9263, Double.NaN}});
// rows are optimised one after another against the current repMatrix, so later rows
// already see the updated values of earlier rows within the same iteration
449 double[] hAlsm = leastSquaresOptimisation(repMatrix, scoresOld, hAIndex);
450 // update repMatrix with new hAlsm
451 for (int j = 0; j < repMatrix.width(); j++)
453 repMatrix.setValue(hAIndex, j, hAlsm[j]);
458 // dot product of representative vecotrs yields a matrix with values approximating the correlation matrix
459 dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
460 // calculate rmsd between approximation and correlation matrix
461 rmsd = scoresOld.rmsd(dotMatrix);
463 // calculate maximum change of representative vectors of current iteration
464 //repMatrix.subtract(repMatrixOLD).print(System.out, "%8.2f");
465 MatrixI diff = repMatrix.subtract(repMatrixOLD).absolute();
466 double maxDiff = 0.0;
467 for (int i = 0; i < diff.height(); i++)
469 for (int j = 0; j < diff.width(); j++)
471 maxDiff = (diff.getValue(i, j) > maxDiff) ? diff.getValue(i, j) : maxDiff;
474 System.out.println(String.format("maxDiff: %f", maxDiff));
476 // calculate rmsd between current and previous estimation
477 double rmsdDiff = dotMatrix.rmsd(dotMatrixOLD);
479 System.out.println(String.format("%d, %8.5f, %8.5f, %8.5f", iteration, rmsd, maxDiff, rmsdDiff));
// converged (largest change <= 1e-6): restore the previous repMatrix; the loop exit is presumably in lines not shown
481 if (!(Math.abs(maxDiff) > 1e-06))
483 repMatrix = repMatrixOLD.copy();
// NOTE(review): catches every Exception and logs only its message; the stack trace is lost
490 } catch (Exception q)
492 Console.error("Error computing cc_analysis: " + q.getMessage());
495 //repMatrix = repMatrix.transpose();
496 System.out.println("final repMatrix");
497 repMatrix.print(System.out, "%8.2f");
502 * Create equations system using information on originally known
503 * pairwise correlation coefficients (parsed from infile) and the
504 * representative result vectors
506 * Each equation has the format:
507 * hA * hA - pairwiseCC = 0
509 * hA: unknown variable
510 * hB: known representative vector
511 * pairwiseCC: known pairwise correlation coefficien
513 * The resulting equations system is overdetermined, if there are more
514 * equations than unknown elements
516 * @param x ~ unknown n-dimensional column-vector
517 * (needed for generating equations system, NOT to be specified by user).
518 * @param hAIndex ~ index of currently optimised representative result vector.
519 * @param h ~ matrix with row-wise listing of representative result vectors.
520 * @param originalRow ~ matrix-row of originally parsed pairwise correlation coefficients.
524 //private double[] originalToEquasionSystem(double[] x, int hAIndex, MatrixI h, MatrixI originalScores)
525 private double[] originalToEquasionSystem(double[] hA, MatrixI repMatrix, MatrixI scoresOld, int hAIndex)
527 double[] originalRow = scoresOld.getRow(hAIndex);
528 int nans = MiscMath.countNaN(originalRow);
529 double[] result = new double[originalRow.length - nans];
531 //for all pairwiseCC in originalRow
533 for (int hBIndex = 0; hBIndex < originalRow.length; hBIndex++)
535 double pairwiseCC = originalRow[hBIndex];
536 // if not NaN -> create new equation and add it to the system
537 if (!Double.isNaN(pairwiseCC))
539 double[] hB = repMatrix.getRow(hBIndex);
540 result[resultIndex++] = MiscMath.sum(MiscMath.elementwiseMultiply(hA, hB)) - pairwiseCC;
548 * returns the jacobian matrix
549 * @param repMatrix ~ matrix of representative vectors
550 * @param hAIndex ~ current row index
554 private MatrixI approximateDerivative(MatrixI repMatrix, MatrixI scoresOld, int hAIndex)
557 double[] hA = repMatrix.getRow(hAIndex);
558 double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, hAIndex);
559 System.out.println("Approximate derivative with ");
560 System.out.print("hA (x): ");
561 MiscMath.print(hA, "%1.8f");
562 System.out.print("f0: ");
563 MiscMath.print(f0, "%1.8f");
564 double[] signX0 = new double[hA.length];
565 double[] xAbs = new double[hA.length];
566 for (int i = 0; i < hA.length; i++)
568 signX0[i] = (hA[i] >= 0) ? 1 : -1;
569 xAbs[i] = (Math.abs(hA[i]) >= 1.0) ? Math.abs(hA[i]) : 1.0;
571 double rstep = Math.pow(Math.ulp(1.0), 0.5);
573 double[] h = new double [hA.length];
574 for (int i = 0; i < hA.length; i++)
576 h[i] = rstep * signX0[i] * xAbs[i];
581 double[][] jTransposed = new double[n][m];
582 for (int i = 0; i < h.length; i++)
584 double[] x = new double[h.length];
585 System.arraycopy(hA, 0, x, 0, h.length);
587 double dx = x[i] - hA[i];
588 double[] df = originalToEquasionSystem(x, repMatrix, scoresOld, hAIndex);
589 for (int j = 0; j < df.length; j++)
592 jTransposed[i][j] = df[j] / dx;
595 MatrixI J = new Matrix(jTransposed).transpose(); // inverted
600 * norm of regularized (by alpha) least-squares solution minus Delta
608 private double[] phiAndDerivative(double alpha, double[] suf, double[] s, double Delta)
610 double[] denom = MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha);
611 double pNorm = MiscMath.norm(MiscMath.elementwiseDivide(suf, denom));
612 double phi = pNorm - Delta;
613 // - sum ( suf**2 / denom**3) / pNorm
614 double phiPrime = - MiscMath.sum(MiscMath.elementwiseDivide(MiscMath.elementwiseMultiply(suf, suf), MiscMath.elementwiseMultiply(MiscMath.elementwiseMultiply(denom, denom), denom))) / pNorm;
615 return new double[]{phi, phiPrime};
619 * class holding the result of solveLsqTrustRegion
621 private class TrustRegion
623 private double[] step;
624 private double alpha;
625 private int iteration;
627 public TrustRegion(double[] step, double alpha, int iteration)
631 this.iteration = iteration;
634 public double[] getStep()
639 public double getAlpha()
644 public int getIteration()
646 return this.iteration;
651 * solve a trust-region problem arising in least-squares optimisation
652 * @param n ~ number of variables
653 * @param m ~ number of residuals
655 * @param s ~ singular values of J
656 * @param V ~ transpose of VT
657 * @param Delta ~ radius of a trust region
658 * @param alpha ~ initial guess for alpha
662 private TrustRegion solveLsqTrustRegion(int n, int m, double[] uf, double[] s, MatrixI V, double Delta, double alpha)
664 double[] suf = MiscMath.elementwiseMultiply(s, uf);
666 //check if J has full rank and tr Gauss-Newton step
667 boolean fullRank = false;
670 double threshold = s[0] * Math.ulp(1.0) * m;
671 fullRank = s[s.length - 1] > threshold;
675 double[] p = MiscMath.elementwiseMultiply(V.sumProduct(MiscMath.elementwiseDivide(uf, s)), -1); // inverted and roughly fine
676 if (MiscMath.norm(p) <= Delta)
678 TrustRegion result = new TrustRegion(p, 0.0, 0);
683 double alphaUpper = MiscMath.norm(suf) / Delta;
684 double alphaLower = 0.0;
687 double[] phiAndPrime = phiAndDerivative(0.0, suf, s, Delta);
688 alphaLower = - phiAndPrime[0] / phiAndPrime[1];
691 alpha = (!fullRank && alpha == 0.0) ? alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5)) : alpha;
694 while (iteration < 10) // 10 is default max_iter
696 alpha = (alpha < alphaLower || alpha > alphaUpper) ? alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5)) : alpha;
697 double[] phiAndPrime = phiAndDerivative(alpha, suf, s, Delta);
698 double phi = phiAndPrime[0];
699 double phiPrime = phiAndPrime[1];
701 alphaUpper = (phi < 0) ? alpha : alphaUpper;
702 double ratio = phi / phiPrime;
703 alphaLower = Math.max(alphaLower, alpha - ratio);
704 alpha -= (phi + Delta) * ratio / Delta;
706 if (Math.abs(phi) < 0.01 * Delta) // default rtol set to 0.01
713 // p = - V.dot( suf / (s**2 + alpha))
714 double[] tmp = MiscMath.elementwiseDivide(suf, MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha));
715 double[] p = MiscMath.elementwiseMultiply(V.sumProduct(tmp), -1);
717 // Make the norm of p equal to Delta, p is changed only slightly during this.
718 // It is done to prevent p lie outside of the trust region
719 p = MiscMath.elementwiseMultiply(p, Delta / MiscMath.norm(p));
721 TrustRegion result = new TrustRegion(p, alpha, iteration + 1);
726 * compute values of a quadratic function arising in least squares
727 * function: 0.5 * s.T * (J.T * J + diag) * s + g.T * s
729 * @param J ~ jacobian matrix
730 * @param g ~ gradient
731 * @param s ~ steps and rows
735 private double evaluateQuadratic(MatrixI J, double[] g, double[] s)
738 //TODO s (-> stepH) is slightly different
739 double[] Js = J.sumProduct(s); //TODO completely wromg
740 double q = MiscMath.dot(Js, Js);
741 double l = MiscMath.dot(s, g);
744 System.out.println("doing evaluateQuadratic");
745 System.out.println("inputs");
746 System.out.print("J");
747 J.print(System.out, "%f ");
748 System.out.print("g");
749 MiscMath.print(g, "%f");
750 System.out.print("s");
751 MiscMath.print(s, "%f");
752 System.out.print("\nJs");
753 MiscMath.print(Js, "%f");
754 System.out.println(String.format("0.5 * %f + %f", q, l));
761 * update the radius of a trust region based on the cost reduction
764 * @param actualReduction
765 * @param predictedReduction
771 private double[] updateTrustRegionRadius(double Delta, double actualReduction, double predictedReduction, double stepNorm, boolean boundHit)
774 if (predictedReduction > 0)
776 ratio = actualReduction / predictedReduction;
777 } else if (predictedReduction == 0 && actualReduction == 0) {
785 Delta = 0.25 * stepNorm;
786 } else if (ratio > 0.75 && boundHit) {
790 return new double[]{Delta, ratio};
794 * check the termination condition for nonlinear least squares
796 * @param actualReduction
804 private byte checkTermination(double actualReduction, double cost, double stepNorm, double xNorm, double ratio)
806 // default ftol and xtol = 1e-8
807 boolean ftolSatisfied = actualReduction < (1e-8 * cost) && ratio > 0.25;
808 boolean xtolSatisfied = stepNorm < (1e-8 * (1e-8 + xNorm));
810 if (ftolSatisfied && xtolSatisfied)
813 } else if (ftolSatisfied) {
815 } else if (xtolSatisfied) {
 * Minimises the residuals of originalToEquasionSystem for row {@code index}
 * of repMatrix with a trust-region algorithm; the structure (and the
 * OptimizeResult note at the end) suggests a port of SciPy's trf_no_bounds.
 *
824 * @param repMatrix ~ Matrix containing representative vectors
825 * @param scoresOld ~ Matrix containing initial observations
826 * @param index ~ current row index
827 * @param J ~ jacobian matrix
 * NOTE(review): the return statement is not visible here; presumably the
 * optimised vector hA is returned — confirm.
831 private double[] trf(MatrixI repMatrix, MatrixI scoresOld, int index, MatrixI J)
833 System.out.println("-----------------\nStart of trf");
835 double[] hA = repMatrix.getRow(index); //inverted
836 double[] f0 = originalToEquasionSystem(hA, repMatrix, scoresOld, index);
// cost = 0.5 * ||f0||^2; g is the gradient J^T f0 (assuming sumProduct is a matrix-vector product)
841 double cost = 0.5 * MiscMath.dot(f0, f0);
842 double[] g = J.transpose().sumProduct(f0); // inverted
// scale (all ones) is not used in the visible code
843 double[] scale = new double[hA.length];
844 Arrays.fill(scale, 1); // ??
// initial trust-region radius = ||hA||
// NOTE(review): SciPy uses Delta = 1 when ||x0|| == 0; no such fallback is visible here,
// and solveLsqTrustRegion divides by Delta — confirm.
845 double Delta = MiscMath.norm(hA);
// 100 * number of variables, SciPy's default max_nfev for method 'trf'
846 int maxNfev = hA.length * 100; // ??
847 double alpha = 0.0; // ?? "Levenberg-Marquardt" parameter
849 System.out.println("Checking initial values:");
850 System.out.print("hA (x): ");
851 MiscMath.print(hA, "%1.8f");
852 System.out.print("f0: ");
853 MiscMath.print(f0, "%1.8f");
854 //System.out.println(String.format("nfev: %d, njev: %d, maxNfev: %d", nfev, njev, maxNfev));
855 //System.out.println(String.format("m: %d, n: %d", m, n));
856 System.out.println(String.format("cost: %1.8f, Delta: %1.8f, alpha: %1.8f", cost, Delta, alpha));
857 System.out.print("g: ");
858 MiscMath.print(g, "%1.8f");
861 byte terminationStatus = 0;
864 System.out.println("outer while loop starts");
867 System.out.println(String.format("iteration: %d", iteration));
869 gNorm = MiscMath.norm(g);
// stop once a termination status is set or the evaluation budget is used up
// NOTE(review): SciPy also stops when gNorm < gtol; no such check is visible here
870 if (terminationStatus != 0 || nfev == maxNfev)
872 System.out.println(String.format("outer loop broken with terminationStatus: %d and nfev: %d", terminationStatus, nfev));
// NOTE(review): Commons Math getV() returns V (J = U * S * V^T), so V below holds V^T, while
// solveLsqTrustRegion documents its argument as "transpose of VT"; whether this is right depends
// on MatrixI.sumProduct's convention — confirm.
876 SingularValueDecomposition svd = new SingularValueDecomposition(new Array2DRowRealMatrix(J.asArray()));
877 // svd not 100% correct -> origin of problems
878 MatrixI U = new Matrix(svd.getU().getData()); // inverted
879 double[] s = svd.getSingularValues(); // ??
880 MatrixI V = new Matrix(svd.getV().getData()).transpose(); //TODO inverted origin of probelm
881 double[] uf = U.transpose().sumProduct(f0);
883 System.out.println("After SVD");
884 System.out.println(String.format("gNorm: %1.8f", gNorm));
885 System.out.print("U: ");
886 U.print(System.out, "%1.8f ");
887 System.out.print("s: ");
888 MiscMath.print(s, "%1.8f");
889 System.out.print("V: ");
890 V.print(System.out, "%1.8f ");
891 System.out.print("uf: ");
892 MiscMath.print(uf, "%1.8f");
// actualReduction starts at -1 so the inner loop runs at least once; it repeats until a step reduces the cost
894 double actualReduction = -1;
895 double[] xNew = new double[hA.length];
896 double[] fNew = new double[f0.length];
898 double stepHnorm = 0;
900 System.out.println("Inner while loop starts");
902 while (actualReduction <= 0 && nfev < maxNfev)
904 TrustRegion trustRegion = solveLsqTrustRegion(n, m, uf, s, V, Delta, alpha);
905 double[] stepH = trustRegion.getStep();
906 alpha = trustRegion.getAlpha();
907 int nIterations = trustRegion.getIteration();
908 double predictedReduction = - (evaluateQuadratic(J, g, stepH));
910 xNew = MiscMath.elementwiseAdd(hA, stepH);
911 fNew = originalToEquasionSystem(xNew, repMatrix, scoresOld, index);
914 stepHnorm = MiscMath.norm(stepH);
916 System.out.println("After TrustRegion");
917 System.out.print("stepH: ");
918 MiscMath.print(stepH, "%1.8f ");
919 System.out.println(String.format("alpha: %1.8f, nIterations: %d, predictedReduction: %1.8f, nfev: %d, stepHnorm: %1.8f", alpha, nIterations, predictedReduction, nfev, stepHnorm));
920 System.out.print("xNew: ");
921 MiscMath.print(xNew, "%1.8f ");
922 System.out.print("fNew: ");
923 MiscMath.print(fNew, "%1.8f ");
// a step giving NaN residuals is rejected: shrink the radius and (presumably, in lines not shown) retry
925 if (MiscMath.countNaN(fNew) > 0)
927 Delta = 0.25 * stepHnorm;
928 System.out.println(String.format("Loop continued with %d NaNs and Delta: %1.8f", MiscMath.countNaN(fNew), Delta));
932 // usual trust-region step quality estimation
933 costNew = 0.5 * MiscMath.dot(fNew, fNew);
934 actualReduction = cost - costNew;
// the step counts as hitting the boundary when it is longer than 95% of the radius
936 double[] updatedTrustRegion = updateTrustRegionRadius(Delta, actualReduction, predictedReduction, stepHnorm, stepHnorm > (0.95 * Delta));
937 double DeltaNew = updatedTrustRegion[0];
938 double ratio = updatedTrustRegion[1];
940 terminationStatus = checkTermination(actualReduction, cost, stepHnorm, MiscMath.norm(hA), ratio);
941 if (terminationStatus != 0)
// alpha is rescaled by the change of radius
946 alpha *= Delta / DeltaNew;
949 System.out.println(String.format("actualReduction: %1.8f, alpha: %1.8f, Delta: %1.8f, termination_status: %d, cost_new: %1.8f", actualReduction, alpha, Delta, terminationStatus, costNew));
952 System.out.println(String.format("actualReduction before check: %1.8f", actualReduction));
953 if (actualReduction > 0)
// NOTE(review): approximateDerivative differentiates at repMatrix row `index`, not at the accepted
// step xNew, unless that row is updated in lines not shown — confirm.
959 J = approximateDerivative(repMatrix, scoresOld, index);
960 System.out.println("J in the end");
961 J.print(System.out, "%1.8f ");
964 g = J.transpose().sumProduct(f0);
972 System.out.println("into OptimizeResult");
973 System.out.println("x (hA)");
974 MiscMath.print(hA, "%1.8f");
975 System.out.println(String.format("cost: %1.8f", cost));
976 System.out.println("f0");
977 MiscMath.print(f0, "%1.8f");
978 System.out.println("J");
979 J.print(System.out, "%1.8f ");
980 System.out.println("g");
981 MiscMath.print(g, "%1.8f");
982 System.out.println(String.format("gNorm: %1.8f", gNorm));
983 System.out.println(String.format("nfev: %d", nfev));
984 System.out.println(String.format("njev: %d", njev));
985 System.out.println(String.format("terminationStatus: %d", terminationStatus));
986 // OptimizeResult(x, cost, f0, J, g, gNorm, 0 in shape x, nfev, njev, terminationStatus)
 * Optimises representative vector {@code index} against the known scores:
 * approximates the jacobian numerically, then runs the trust-region solver trf.
 *
992 * @param repMatrix ~ Matrix containing representative vectors
993 * @param scoresOld ~ Matrix containing initial observations
994 * @param index ~ current row index
 * NOTE(review): the return statement is not visible here; presumably the result of trf — confirm.
998 private double[] leastSquaresOptimisation(MatrixI repMatrix, MatrixI scoresOld, int index)
1000 System.out.println("lsq starts!!!");
// initial jacobian at the current row of repMatrix
1001 MatrixI J = approximateDerivative(repMatrix, scoresOld, index);
1002 System.out.println("J");
1003 J.print(System.out, "%1.8f ");
1004 double[] result = trf(repMatrix, scoresOld, index, J);