Next version of JABA
[jabaws.git] / binaries / src / probcons / SparseMatrix.h
1 /////////////////////////////////////////////////////////////////
2 // SparseMatrix.h
3 //
4 // Sparse matrix computations
5 /////////////////////////////////////////////////////////////////
6
7 #ifndef SPARSEMATRIX_H
8 #define SPARSEMATRIX_H
9
10 #include <iostream>
11
12 using namespace std;
13
14 const float POSTERIOR_CUTOFF = 0.01;         // minimum posterior probability
15                                              // value that is maintained in the
16                                              // sparse matrix representation
17
18 typedef pair<int,float> PIF;                 // Sparse matrix entry type
19                                              //   first --> column
20                                              //   second --> value
21
22 /////////////////////////////////////////////////////////////////
23 // SparseMatrix
24 //
25 // Class for sparse matrix computations
26 /////////////////////////////////////////////////////////////////
27
28 class SparseMatrix {
29
30   int seq1Length, seq2Length;                     // dimensions of matrix
31   VI rowSize;                                     // rowSize[i] = # of cells in row i
32   SafeVector<PIF> data;                           // data values
33   SafeVector<SafeVector<PIF>::iterator> rowPtrs;  // pointers to the beginning of each row
34
35   /////////////////////////////////////////////////////////////////
36   // SparseMatrix::SparseMatrix()
37   //
38   // Private constructor.
39   /////////////////////////////////////////////////////////////////
40
41   SparseMatrix (){}
42
43  public:
44
45   /////////////////////////////////////////////////////////////////
46   // SparseMatrix::SparseMatrix()
47   //
48   // Constructor.  Builds a sparse matrix from a posterior matrix.
49   // Note that the expected format for the posterior matrix is as
50   // a (seq1Length+1) x (seq2Length+1) matrix where the 0th row
51   // and 0th column are ignored (they should contain all zeroes).
52   /////////////////////////////////////////////////////////////////
53
54   SparseMatrix (int seq1Length, int seq2Length, const VF &posterior) :
55     seq1Length (seq1Length), seq2Length (seq2Length) {
56
57     int numCells = 0;
58
59     assert (seq1Length > 0);
60     assert (seq2Length > 0);
61
62     // calculate memory required; count the number of cells in the
63     // posterior matrix above the threshold
64     VF::const_iterator postPtr = posterior.begin();
65     for (int i = 0; i <= seq1Length; i++){
66       for (int j = 0; j <= seq2Length; j++){
67         if (*(postPtr++) >= POSTERIOR_CUTOFF){
68           assert (i != 0 && j != 0);
69           numCells++;
70         }
71       }
72     }
73     
74     // allocate memory
75     data.resize(numCells);
76     rowSize.resize (seq1Length + 1); rowSize[0] = -1;
77     rowPtrs.resize (seq1Length + 1); rowPtrs[0] = data.end();
78
79     // build sparse matrix
80     postPtr = posterior.begin() + seq2Length + 1;           // note that we're skipping the first row here
81     SafeVector<PIF>::iterator dataPtr = data.begin();
82     for (int i = 1; i <= seq1Length; i++){
83       postPtr++;                                            // and skipping the first column of each row
84       rowPtrs[i] = dataPtr;
85       for (int j = 1; j <= seq2Length; j++){
86         if (*postPtr >= POSTERIOR_CUTOFF){
87           dataPtr->first = j;
88           dataPtr->second = *postPtr;
89           dataPtr++;
90         }
91         postPtr++;
92       }
93       rowSize[i] = dataPtr - rowPtrs[i];
94     }
95   }
96
97   /////////////////////////////////////////////////////////////////
98   // SparseMatrix::GetRowPtr()
99   //
100   // Returns the pointer to a particular row in the sparse matrix.
101   /////////////////////////////////////////////////////////////////
102
103   SafeVector<PIF>::iterator GetRowPtr (int row) const {
104     assert (row >= 1 && row <= seq1Length);
105     return rowPtrs[row];
106   }
107
108   /////////////////////////////////////////////////////////////////
109   // SparseMatrix::GetValue()
110   //
111   // Returns value at a particular row, column.
112   /////////////////////////////////////////////////////////////////
113
114   float GetValue (int row, int col){
115     assert (row >= 1 && row <= seq1Length);
116     assert (col >= 1 && col <= seq2Length);
117     for (int i = 0; i < rowSize[row]; i++){
118       if (rowPtrs[row][i].first == col) return rowPtrs[row][i].second;
119     }
120     return 0;
121   }
122
123   /////////////////////////////////////////////////////////////////
124   // SparseMatrix::GetRowSize()
125   //
126   // Returns the number of entries in a particular row.
127   /////////////////////////////////////////////////////////////////
128
129   int GetRowSize (int row) const {
130     assert (row >= 1 && row <= seq1Length);
131     return rowSize[row];
132   }
133
134   /////////////////////////////////////////////////////////////////
135   // SparseMatrix::GetSeq1Length()
136   //
137   // Returns the first dimension of the matrix.
138   /////////////////////////////////////////////////////////////////
139
140   int GetSeq1Length () const {
141     return seq1Length;
142   }
143
144   /////////////////////////////////////////////////////////////////
145   // SparseMatrix::GetSeq2Length()
146   //
147   // Returns the second dimension of the matrix.
148   /////////////////////////////////////////////////////////////////
149
150   int GetSeq2Length () const {
151     return seq2Length;
152   }
153
154   /////////////////////////////////////////////////////////////////
155   // SparseMatrix::GetRowPtr
156   //
157   // Returns the pointer to a particular row in the sparse matrix.
158   /////////////////////////////////////////////////////////////////
159
160   int GetNumCells () const {
161     return data.size();
162   }
163
164   /////////////////////////////////////////////////////////////////
165   // SparseMatrix::Print()
166   //
167   // Prints out a sparse matrix.
168   /////////////////////////////////////////////////////////////////
169
170   void Print (ostream &outfile) const {
171     outfile << "Sparse Matrix:" << endl;
172     for (int i = 1; i <= seq1Length; i++){
173       outfile << "  " << i << ":";
174       for (int j = 0; j < rowSize[i]; j++){
175         outfile << " (" << rowPtrs[i][j].first << "," << rowPtrs[i][j].second << ")";
176       }
177       outfile << endl;
178     }
179   }
180
181   /////////////////////////////////////////////////////////////////
182   // SparseMatrix::ComputeTranspose()
183   //
184   // Returns a new sparse matrix containing the transpose of the
185   // current matrix.
186   /////////////////////////////////////////////////////////////////
187
188   SparseMatrix *ComputeTranspose () const {
189
190     // create a new sparse matrix
191     SparseMatrix *ret = new SparseMatrix();
192     int numCells = data.size();
193
194     ret->seq1Length = seq2Length;
195     ret->seq2Length = seq1Length;
196
197     // allocate memory
198     ret->data.resize (numCells);
199     ret->rowSize.resize (seq2Length + 1); ret->rowSize[0] = -1;
200     ret->rowPtrs.resize (seq2Length + 1); ret->rowPtrs[0] = ret->data.end();
201
202     // compute row sizes
203     for (int i = 1; i <= seq2Length; i++) ret->rowSize[i] = 0;
204     for (int i = 0; i < numCells; i++)
205       ret->rowSize[data[i].first]++;
206
207     // compute row ptrs
208     for (int i = 1; i <= seq2Length; i++){
209       ret->rowPtrs[i] = (i == 1) ? ret->data.begin() : ret->rowPtrs[i-1] + ret->rowSize[i-1];
210     }
211
212     // now fill in data
213     SafeVector<SafeVector<PIF>::iterator> currPtrs = ret->rowPtrs;
214
215     for (int i = 1; i <= seq1Length; i++){
216       SafeVector<PIF>::iterator row = rowPtrs[i];
217       for (int j = 0; j < rowSize[i]; j++){
218         currPtrs[row[j].first]->first = i;
219         currPtrs[row[j].first]->second = row[j].second;
220         currPtrs[row[j].first]++;
221       }
222     }
223
224     return ret;
225   }
226
227   /////////////////////////////////////////////////////////////////
228   // SparseMatrix::GetPosterior()
229   //
230   // Return the posterior representation of the sparse matrix.
231   /////////////////////////////////////////////////////////////////
232
233   VF *GetPosterior () const {
234
235     // create a new posterior matrix
236     VF *posteriorPtr = new VF((seq1Length+1) * (seq2Length+1)); assert (posteriorPtr);
237     VF &posterior = *posteriorPtr;
238
239     // build the posterior matrix
240     for (int i = 0; i < (seq1Length+1) * (seq2Length+1); i++) posterior[i] = 0;
241     for (int i = 1; i <= seq1Length; i++){
242       VF::iterator postPtr = posterior.begin() + i * (seq2Length+1);
243       for (int j = 0; j < rowSize[i]; j++){
244         postPtr[rowPtrs[i][j].first] = rowPtrs[i][j].second;
245       }
246     }
247
248     return posteriorPtr;
249   }
250
251 };
252
253 #endif