b910a7edb4229221c06ea32549489d604b3fcb63
[jalview.git] / src / jalview / math / Matrix.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.math;
22
23 import jalview.util.Format;
24 import jalview.util.MessageManager;
25
26 import java.io.PrintStream;
27
28 /**
29  * A class to model rectangular matrices of double values and operations on them
30  */
31 public class Matrix implements MatrixI
32 {
33   /*
34    * maximum number of iterations for tqli
35    */
36   private static final int MAX_ITER = 45;
37   // fudge - add 15 iterations, just in case
38
39   /*
40    * the number of rows
41    */
42   final protected int rows;
43
44   /*
45    * the number of columns
46    */
47   final protected int cols;
48
49   /*
50    * the cell values in row-major order
51    */
52   private double[][] value;
53
54   protected double[] d; // Diagonal
55
56   protected double[] e; // off diagonal
57
58   /**
59    * Constructor given number of rows and columns
60    * 
61    * @param colCount
62    * @param rowCount
63    */
64   protected Matrix(int rowCount, int colCount)
65   {
66     rows = rowCount;
67     cols = colCount;
68   }
69
70   /**
71    * Creates a new Matrix object containing a copy of the supplied array values.
72    * For example
73    * 
74    * <pre>
75    *   new Matrix(new double[][] {{2, 3, 4}, {5, 6, 7})
76    * constructs
77    *   (2 3 4)
78    *   (5 6 7)
79    * </pre>
80    * 
81    * Note that ragged arrays (with not all rows, or columns, of the same
82    * length), are not supported by this class. They can be constructed, but
83    * results of operations on them are undefined and may throw exceptions.
84    * 
85    * @param values
86    *          the matrix values in row-major order
87    */
88   public Matrix(double[][] values)
89   {
90     this.rows = values.length;
91     this.cols = this.rows == 0 ? 0 : values[0].length;
92
93     /*
94      * make a copy of the values array, for immutability
95      */
96     this.value = new double[rows][];
97     int i = 0;
98     for (double[] row : values)
99     {
100       if (row != null)
101       {
102         value[i] = new double[row.length];
103         System.arraycopy(row, 0, value[i], 0, row.length);
104       }
105       i++;
106     }
107   }
108
109   @Override
110   public MatrixI transpose()
111   {
112     double[][] out = new double[cols][rows];
113
114     for (int i = 0; i < cols; i++)
115     {
116       for (int j = 0; j < rows; j++)
117       {
118         out[i][j] = value[j][i];
119       }
120     }
121
122     return new Matrix(out);
123   }
124
125   /**
126    * DOCUMENT ME!
127    * 
128    * @param ps
129    *          DOCUMENT ME!
130    * @param format
131    */
132   @Override
133   public void print(PrintStream ps, String format)
134   {
135     for (int i = 0; i < rows; i++)
136     {
137       for (int j = 0; j < cols; j++)
138       {
139         Format.print(ps, format, getValue(i, j));
140       }
141
142       ps.println();
143     }
144   }
145
146   @Override
147   public MatrixI preMultiply(MatrixI in)
148   {
149     if (in.width() != rows)
150     {
151       throw new IllegalArgumentException("Can't pre-multiply " + this.rows
152               + " rows by " + in.width() + " columns");
153     }
154     double[][] tmp = new double[in.height()][this.cols];
155
156     for (int i = 0; i < in.height(); i++)
157     {
158       for (int j = 0; j < this.cols; j++)
159       {
160         /*
161          * result[i][j] is the vector product of 
162          * in.row[i] and this.column[j]
163          */
164         for (int k = 0; k < in.width(); k++)
165         {
166           tmp[i][j] += (in.getValue(i, k) * this.value[k][j]);
167         }
168       }
169     }
170
171     return new Matrix(tmp);
172   }
173
174   /**
175    * 
176    * @param in
177    * 
178    * @return
179    */
180   public double[] vectorPostMultiply(double[] in)
181   {
182     double[] out = new double[in.length];
183
184     for (int i = 0; i < in.length; i++)
185     {
186       out[i] = 0.0;
187
188       for (int k = 0; k < in.length; k++)
189       {
190         out[i] += (value[i][k] * in[k]);
191       }
192     }
193
194     return out;
195   }
196
197   @Override
198   public MatrixI postMultiply(MatrixI in)
199   {
200     if (in.height() != this.cols)
201     {
202       throw new IllegalArgumentException("Can't post-multiply " + this.cols
203               + " columns by " + in.height() + " rows");
204     }
205     return in.preMultiply(this);
206   }
207
208   @Override
209   public MatrixI copy()
210   {
211     double[][] newmat = new double[rows][cols];
212
213     for (int i = 0; i < rows; i++)
214     {
215       System.arraycopy(value[i], 0, newmat[i], 0, value[i].length);
216     }
217
218     return new Matrix(newmat);
219   }
220
221   /**
222    * DOCUMENT ME!
223    */
224   @Override
225   public void tred()
226   {
227     int n = rows;
228     int k;
229     int j;
230     int i;
231
232     double scale;
233     double hh;
234     double h;
235     double g;
236     double f;
237
238     this.d = new double[rows];
239     this.e = new double[rows];
240
241     for (i = n; i >= 2; i--)
242     {
243       final int l = i - 1;
244       h = 0.0;
245       scale = 0.0;
246
247       if (l > 1)
248       {
249         for (k = 1; k <= l; k++)
250         {
251           double v = Math.abs(getValue(i - 1, k - 1));
252           scale += v;
253         }
254
255         if (scale == 0.0)
256         {
257           e[i - 1] = getValue(i - 1, l - 1);
258         }
259         else
260         {
261           for (k = 1; k <= l; k++)
262           {
263             double v = divideValue(i - 1, k - 1, scale);
264             h += v * v;
265           }
266
267           f = getValue(i - 1, l - 1);
268
269           if (f > 0)
270           {
271             g = -1.0 * Math.sqrt(h);
272           }
273           else
274           {
275             g = Math.sqrt(h);
276           }
277
278           e[i - 1] = scale * g;
279           h -= (f * g);
280           setValue(i - 1, l - 1, f - g);
281           f = 0.0;
282
283           for (j = 1; j <= l; j++)
284           {
285             double val = getValue(i - 1, j - 1) / h;
286             setValue(j - 1, i - 1, val);
287             g = 0.0;
288
289             for (k = 1; k <= j; k++)
290             {
291               g += (getValue(j - 1, k - 1) * getValue(i - 1, k - 1));
292             }
293
294             for (k = j + 1; k <= l; k++)
295             {
296               g += (getValue(k - 1, j - 1) * getValue(i - 1, k - 1));
297             }
298
299             e[j - 1] = g / h;
300             f += (e[j - 1] * getValue(i - 1, j - 1));
301           }
302
303           hh = f / (h + h);
304
305           for (j = 1; j <= l; j++)
306           {
307             f = getValue(i - 1, j - 1);
308             g = e[j - 1] - (hh * f);
309             e[j - 1] = g;
310
311             for (k = 1; k <= j; k++)
312             {
313               double val = (f * e[k - 1]) + (g * getValue(i - 1, k - 1));
314               addValue(j - 1, k - 1, -val);
315             }
316           }
317         }
318       }
319       else
320       {
321         e[i - 1] = getValue(i - 1, l - 1);
322       }
323
324       d[i - 1] = h;
325     }
326
327     d[0] = 0.0;
328     e[0] = 0.0;
329
330     for (i = 1; i <= n; i++)
331     {
332       final int l = i - 1;
333
334       if (d[i - 1] != 0.0)
335       {
336         for (j = 1; j <= l; j++)
337         {
338           g = 0.0;
339
340           for (k = 1; k <= l; k++)
341           {
342             g += (getValue(i - 1, k - 1) * getValue(k - 1, j - 1));
343           }
344
345           for (k = 1; k <= l; k++)
346           {
347             addValue(k - 1, j - 1, -(g * getValue(k - 1, i - 1)));
348           }
349         }
350       }
351
352       d[i - 1] = getValue(i - 1, i - 1);
353       setValue(i - 1, i - 1, 1.0);
354
355       for (j = 1; j <= l; j++)
356       {
357         setValue(j - 1, i - 1, 0.0);
358         setValue(i - 1, j - 1, 0.0);
359       }
360     }
361   }
362
363   /**
364    * Adds f to the value at [i, j] and returns the new value
365    * 
366    * @param i
367    * @param j
368    * @param f
369    */
370   protected double addValue(int i, int j, double f)
371   {
372     double v = value[i][j] + f;
373     value[i][j] = v;
374     return v;
375   }
376
377   /**
378    * Divides the value at [i, j] by divisor and returns the new value. If d is
379    * zero, returns the unchanged value.
380    * 
381    * @param i
382    * @param j
383    * @param divisor
384    * @return
385    */
386   protected double divideValue(int i, int j, double divisor)
387   {
388     if (divisor == 0d)
389     {
390       return getValue(i, j);
391     }
392     double v = value[i][j];
393     v = v / divisor;
394     value[i][j] = v;
395     return v;
396   }
397
398   /**
399    * DOCUMENT ME!
400    */
401   @Override
402   public void tqli() throws Exception
403   {
404     int n = rows;
405
406     int m;
407     int l;
408     int iter;
409     int i;
410     int k;
411     double s;
412     double r;
413     double p;
414
415     double g;
416     double f;
417     double dd;
418     double c;
419     double b;
420
421     for (i = 2; i <= n; i++)
422     {
423       e[i - 2] = e[i - 1];
424     }
425
426     e[n - 1] = 0.0;
427
428     for (l = 1; l <= n; l++)
429     {
430       iter = 0;
431
432       do
433       {
434         for (m = l; m <= (n - 1); m++)
435         {
436           dd = Math.abs(d[m - 1]) + Math.abs(d[m]);
437
438           if ((Math.abs(e[m - 1]) + dd) == dd)
439           {
440             break;
441           }
442         }
443
444         if (m != l)
445         {
446           iter++;
447
448           if (iter == MAX_ITER)
449           {
450             throw new Exception(MessageManager.formatMessage(
451                     "exception.matrix_too_many_iteration", new String[]
452                     { "tqli", Integer.valueOf(MAX_ITER).toString() }));
453           }
454           else
455           {
456             // System.out.println("Iteration " + iter);
457           }
458
459           g = (d[l] - d[l - 1]) / (2.0 * e[l - 1]);
460           r = Math.sqrt((g * g) + 1.0);
461           g = d[m - 1] - d[l - 1] + (e[l - 1] / (g + sign(r, g)));
462           c = 1.0;
463           s = c;
464           p = 0.0;
465
466           for (i = m - 1; i >= l; i--)
467           {
468             f = s * e[i - 1];
469             b = c * e[i - 1];
470
471             if (Math.abs(f) >= Math.abs(g))
472             {
473               c = g / f;
474               r = Math.sqrt((c * c) + 1.0);
475               e[i] = f * r;
476               s = 1.0 / r;
477               c *= s;
478             }
479             else
480             {
481               s = f / g;
482               r = Math.sqrt((s * s) + 1.0);
483               e[i] = g * r;
484               c = 1.0 / r;
485               s *= c;
486             }
487
488             g = d[i] - p;
489             r = ((d[i - 1] - g) * s) + (2.0 * c * b);
490             p = s * r;
491             d[i] = g + p;
492             g = (c * r) - b;
493
494             for (k = 1; k <= n; k++)
495             {
496               f = getValue(k - 1, i);
497               setValue(k - 1, i, (s * getValue(k - 1, i - 1)) + (c * f));
498               setValue(k - 1, i - 1,
499                       (c * getValue(k - 1, i - 1)) - (s * f));
500             }
501           }
502
503           d[l - 1] = d[l - 1] - p;
504           e[l - 1] = g;
505           e[m - 1] = 0.0;
506         }
507       } while (m != l);
508     }
509   }
510
511   @Override
512   public double getValue(int i, int j)
513   {
514     return value[i][j];
515   }
516
517   @Override
518   public void setValue(int i, int j, double val)
519   {
520     value[i][j] = val;
521   }
522
523   /**
524    * DOCUMENT ME!
525    */
526   public void tred2()
527   {
528     int n = rows;
529     int l;
530     int k;
531     int j;
532     int i;
533
534     double scale;
535     double hh;
536     double h;
537     double g;
538     double f;
539
540     this.d = new double[rows];
541     this.e = new double[rows];
542
543     for (i = n - 1; i >= 1; i--)
544     {
545       l = i - 1;
546       h = 0.0;
547       scale = 0.0;
548
549       if (l > 0)
550       {
551         for (k = 0; k < l; k++)
552         {
553           scale += Math.abs(value[i][k]);
554         }
555
556         if (scale == 0.0)
557         {
558           e[i] = value[i][l];
559         }
560         else
561         {
562           for (k = 0; k < l; k++)
563           {
564             value[i][k] /= scale;
565             h += (value[i][k] * value[i][k]);
566           }
567
568           f = value[i][l];
569
570           if (f > 0)
571           {
572             g = -1.0 * Math.sqrt(h);
573           }
574           else
575           {
576             g = Math.sqrt(h);
577           }
578
579           e[i] = scale * g;
580           h -= (f * g);
581           value[i][l] = f - g;
582           f = 0.0;
583
584           for (j = 0; j < l; j++)
585           {
586             value[j][i] = value[i][j] / h;
587             g = 0.0;
588
589             for (k = 0; k < j; k++)
590             {
591               g += (value[j][k] * value[i][k]);
592             }
593
594             for (k = j; k < l; k++)
595             {
596               g += (value[k][j] * value[i][k]);
597             }
598
599             e[j] = g / h;
600             f += (e[j] * value[i][j]);
601           }
602
603           hh = f / (h + h);
604
605           for (j = 0; j < l; j++)
606           {
607             f = value[i][j];
608             g = e[j] - (hh * f);
609             e[j] = g;
610
611             for (k = 0; k < j; k++)
612             {
613               value[j][k] -= ((f * e[k]) + (g * value[i][k]));
614             }
615           }
616         }
617       }
618       else
619       {
620         e[i] = value[i][l];
621       }
622
623       d[i] = h;
624     }
625
626     d[0] = 0.0;
627     e[0] = 0.0;
628
629     for (i = 0; i < n; i++)
630     {
631       l = i - 1;
632
633       if (d[i] != 0.0)
634       {
635         for (j = 0; j < l; j++)
636         {
637           g = 0.0;
638
639           for (k = 0; k < l; k++)
640           {
641             g += (value[i][k] * value[k][j]);
642           }
643
644           for (k = 0; k < l; k++)
645           {
646             value[k][j] -= (g * value[k][i]);
647           }
648         }
649       }
650
651       d[i] = value[i][i];
652       value[i][i] = 1.0;
653
654       for (j = 0; j < l; j++)
655       {
656         value[j][i] = 0.0;
657         value[i][j] = 0.0;
658       }
659     }
660   }
661
662   /**
663    * DOCUMENT ME!
664    */
665   public void tqli2() throws Exception
666   {
667     int n = rows;
668
669     int m;
670     int l;
671     int iter;
672     int i;
673     int k;
674     double s;
675     double r;
676     double p;
677     ;
678
679     double g;
680     double f;
681     double dd;
682     double c;
683     double b;
684
685     for (i = 2; i <= n; i++)
686     {
687       e[i - 2] = e[i - 1];
688     }
689
690     e[n - 1] = 0.0;
691
692     for (l = 1; l <= n; l++)
693     {
694       iter = 0;
695
696       do
697       {
698         for (m = l; m <= (n - 1); m++)
699         {
700           dd = Math.abs(d[m - 1]) + Math.abs(d[m]);
701
702           if ((Math.abs(e[m - 1]) + dd) == dd)
703           {
704             break;
705           }
706         }
707
708         if (m != l)
709         {
710           iter++;
711
712           if (iter == MAX_ITER)
713           {
714             throw new Exception(MessageManager.formatMessage(
715                     "exception.matrix_too_many_iteration", new String[]
716                     { "tqli2", Integer.valueOf(MAX_ITER).toString() }));
717           }
718           else
719           {
720             // System.out.println("Iteration " + iter);
721           }
722
723           g = (d[l] - d[l - 1]) / (2.0 * e[l - 1]);
724           r = Math.sqrt((g * g) + 1.0);
725           g = d[m - 1] - d[l - 1] + (e[l - 1] / (g + sign(r, g)));
726           c = 1.0;
727           s = c;
728           p = 0.0;
729
730           for (i = m - 1; i >= l; i--)
731           {
732             f = s * e[i - 1];
733             b = c * e[i - 1];
734
735             if (Math.abs(f) >= Math.abs(g))
736             {
737               c = g / f;
738               r = Math.sqrt((c * c) + 1.0);
739               e[i] = f * r;
740               s = 1.0 / r;
741               c *= s;
742             }
743             else
744             {
745               s = f / g;
746               r = Math.sqrt((s * s) + 1.0);
747               e[i] = g * r;
748               c = 1.0 / r;
749               s *= c;
750             }
751
752             g = d[i] - p;
753             r = ((d[i - 1] - g) * s) + (2.0 * c * b);
754             p = s * r;
755             d[i] = g + p;
756             g = (c * r) - b;
757
758             for (k = 1; k <= n; k++)
759             {
760               f = value[k - 1][i];
761               value[k - 1][i] = (s * value[k - 1][i - 1]) + (c * f);
762               value[k - 1][i - 1] = (c * value[k - 1][i - 1]) - (s * f);
763             }
764           }
765
766           d[l - 1] = d[l - 1] - p;
767           e[l - 1] = g;
768           e[m - 1] = 0.0;
769         }
770       } while (m != l);
771     }
772   }
773
774   /**
775    * Answers the first argument with the sign of the second argument
776    * 
777    * @param a
778    * @param b
779    * 
780    * @return
781    */
782   static double sign(double a, double b)
783   {
784     if (b < 0)
785     {
786       return -Math.abs(a);
787     }
788     else
789     {
790       return Math.abs(a);
791     }
792   }
793
794   /**
795    * Returns an array containing the values in the specified column
796    * 
797    * @param col
798    * 
799    * @return
800    */
801   public double[] getColumn(int col)
802   {
803     double[] out = new double[rows];
804
805     for (int i = 0; i < rows; i++)
806     {
807       out[i] = value[i][col];
808     }
809
810     return out;
811   }
812
813   /**
814    * DOCUMENT ME!
815    * 
816    * @param ps
817    *          DOCUMENT ME!
818    * @param format
819    */
820   @Override
821   public void printD(PrintStream ps, String format)
822   {
823     for (int j = 0; j < rows; j++)
824     {
825       Format.print(ps, format, d[j]);
826     }
827   }
828
829   /**
830    * DOCUMENT ME!
831    * 
832    * @param ps
833    *          DOCUMENT ME!
834    * @param format
835    *          TODO
836    */
837   @Override
838   public void printE(PrintStream ps, String format)
839   {
840     for (int j = 0; j < rows; j++)
841     {
842       Format.print(ps, format, e[j]);
843     }
844   }
845
846   @Override
847   public double[] getD()
848   {
849     return d;
850   }
851
852   @Override
853   public double[] getE()
854   {
855     return e;
856   }
857
858   @Override
859   public int height()
860   {
861     return rows;
862   }
863
864   @Override
865   public int width()
866   {
867     return cols;
868   }
869
870   @Override
871   public double[] getRow(int i)
872   {
873     double[] row = new double[cols];
874     System.arraycopy(value[i], 0, row, 0, cols);
875     return row;
876   }
877
878   /**
879    * Returns a length 2 array of {minValue, maxValue} of all values in the
880    * matrix. Returns null if the matrix is null or empty.
881    * 
882    * @return
883    */
884   double[] findMinMax()
885   {
886     if (value == null)
887     {
888       return null;
889     }
890     double min = Double.MAX_VALUE;
891     double max = -Double.MAX_VALUE;
892     boolean empty = true;
893     for (double[] row : value)
894     {
895       if (row != null)
896       {
897         for (double x : row)
898         {
899           empty = false;
900           if (x > max)
901           {
902             max = x;
903           }
904           if (x < min)
905           {
906             min = x;
907           }
908         }
909       }
910     }
911     return empty ? null : new double[] { min, max };
912   }
913
914   /**
915    * {@inheritDoc}
916    */
917   @Override
918   public void reverseRange(boolean maxToZero)
919   {
920     if (value == null)
921     {
922       return;
923     }
924     double[] minMax = findMinMax();
925     if (minMax == null)
926     {
927       return; // empty matrix
928     }
929     double subtractFrom = maxToZero ? minMax[1] : minMax[0] + minMax[1];
930
931     for (double[] row : value)
932     {
933       if (row != null)
934       {
935         int j = 0;
936         for (double x : row)
937         {
938           row[j] = subtractFrom - x;
939           j++;
940         }
941       }
942     }
943   }
944
945   /**
946    * Multiplies every entry in the matrix by the given value.
947    * 
948    * @param
949    */
950   @Override
951   public void multiply(double by)
952   {
953     for (double[] row : value)
954     {
955       if (row != null)
956       {
957         for (int i = 0; i < row.length; i++)
958         {
959           row[i] *= by;
960         }
961       }
962     }
963   }
964 }