new mafft v 6.857 with extensions
[jabaws.git] / binaries / src / mafft / extensions / mxscarna_src / probconsRNA / EvolutionaryTree.h
1 /////////////////////////////////////////////////////////////////
2 // EvolutionaryTree.hpp
3 //
4 // Utilities for reading/writing multiple sequence data.
5 /////////////////////////////////////////////////////////////////
6
7 #ifndef __EVOLUTIONARYTREE_HPP__
8 #define __EVOLUTIONARYTREE_HPP__
9
10 #include <string>
11 #include <list>
12 #include <stdio.h>
13 #include "SafeVector.h"
14 #include "MultiSequence.h"
15 #include "Sequence.h"
16 #include "Util.hpp"
17
18 using namespace std;
19
20
21 /////////////////////////////////////////////////////////////////
22 // TreeNode
23 //
24 // The fundamental unit for representing an alignment tree.  The
25 // guide tree is represented as a binary tree.
26 /////////////////////////////////////////////////////////////////
27 namespace MXSCARNA {
28 class TreeNode {
29   int sequenceLabel;                  // sequence label
30   float sequenceIdentity;             // sequence identity
31   TreeNode *left, *right, *parent;    // pointers to left, right children
32   float leftLength, rightLength;      // the length of left and right edge
33   /////////////////////////////////////////////////////////////////
34   // TreeNode::PrintNode()
35   //
36   // Internal routine used to print out the sequence comments
37   // associated with the evolutionary tree, using a hierarchical
38   // parenthesized format.
39   /////////////////////////////////////////////////////////////////
40
41   void PrintNode (ostream &outfile, const MultiSequence *sequences) const {
42
43     // if this is a leaf node, print out the associated sequence comment
44     if (sequenceLabel >= 0)
45         //outfile << sequences->GetSequence (sequenceLabel)->GetHeader();
46         outfile << sequences->GetSequence (sequenceLabel)->GetLabel();
47
48     // otherwise, it must have two children; print out their subtrees recursively
49     else {
50       assert (left);
51       assert (right);
52
53       outfile << "(";
54       left->PrintNode (outfile, sequences);
55       outfile << ",";
56       right->PrintNode (outfile, sequences);
57       outfile << ")";
58     }
59   }
60
61  public:
62
63   /////////////////////////////////////////////////////////////////
64   // TreeNode::TreeNode()
65   //
66   // Constructor for a tree node.  Note that sequenceLabel = -1
67   // implies that the current node is not a leaf in the tree.
68   /////////////////////////////////////////////////////////////////
69
70   TreeNode (int sequenceLabel) : sequenceLabel (sequenceLabel),
71     left (NULL), right (NULL), parent (NULL) {
72     assert (sequenceLabel >= -1);
73   }
74
75   /////////////////////////////////////////////////////////////////
76   // TreeNode::~TreeNode()
77   //
78   // Destructor for a tree node.  Recursively deletes all children.
79   /////////////////////////////////////////////////////////////////
80
81   ~TreeNode (){
82     if (left){ delete left; left = NULL; }
83     if (right){ delete right; right = NULL; }
84     parent = NULL;
85   }
86
87
88   // getters
89   int GetSequenceLabel () const { return sequenceLabel; }
90   TreeNode *GetLeftChild () const { return left; }
91   TreeNode *GetRightChild () const { return right; }
92   TreeNode *GetParent () const { return parent; }
93   float GetIdentity () const { return sequenceIdentity; }
94   float GetLeftLength () const { return leftLength; }
95   float GetRightLength () const { return rightLength; }
96   // setters
97   void SetSequenceLabel (int sequenceLabel){ this->sequenceLabel = sequenceLabel; assert (sequenceLabel >= -1); }
98   void SetLeftChild (TreeNode *left){ this->left = left; }
99   void SetRightChild (TreeNode *right){ this->right = right; }
100   void SetParent (TreeNode *parent){ this->parent = parent; }
101   void SetIdentity (float identity) { this->sequenceIdentity = identity; }
102   void SetLeftLength (float identity) { this->leftLength = identity; }
103   void SetRightLength (float identity) {this->rightLength = identity; }
104   /////////////////////////////////////////////////////////////////
105   // TreeNode::ComputeTree()
106   //
107   // Routine used to compute an evolutionary tree based on the
108   // given distance matrix.  We assume the distance matrix has the
109   // form, distMatrix[i][j] = expected accuracy of aligning i with j.
110   /////////////////////////////////////////////////////////////////
111
112   static TreeNode *ComputeTree (const VVF &distMatrix, const VVF &identityMatrix){
113
114     int numSeqs = distMatrix.size();                 // number of sequences in distance matrix
115     VVF distances (numSeqs, VF (numSeqs));           // a copy of the distance matrix
116     SafeVector<TreeNode *> nodes (numSeqs, NULL);    // list of nodes for each sequence
117     SafeVector<int> valid (numSeqs, 1);              // valid[i] tells whether or not the ith
118                                                      // nodes in the distances and nodes array
119                                                      // are valid
120     VVF identities (numSeqs, VF (numSeqs));
121     SafeVector<int> countCluster (numSeqs, 1);
122
123     // initialization: make a copy of the distance matrix
124     for (int i = 0; i < numSeqs; i++) {
125         for (int j = 0; j < numSeqs; j++) {
126             distances[i][j]  = distMatrix[i][j];
127             identities[i][j] = identityMatrix[i][j];
128         }
129     }
130
131     // initialization: create all the leaf nodes
132     for (int i = 0; i < numSeqs; i++){
133       nodes[i] = new TreeNode (i);
134       assert (nodes[i]);
135     }
136
137     // repeat until only a single node left
138     for (int numNodesLeft = numSeqs; numNodesLeft > 1; numNodesLeft--){
139         float bestProb = -1;
140       pair<int,int> bestPair;
141
142       // find the closest pair
143       for (int i = 0; i < numSeqs; i++) if (valid[i]){
144         for (int j = i+1; j < numSeqs; j++) if (valid[j]){
145           if (distances[i][j] > bestProb){
146             bestProb = distances[i][j];
147             bestPair = make_pair(i, j);
148           }
149         }
150       }
151
152       // merge the closest pair
153       TreeNode *newParent = new TreeNode (-1);
154       newParent->SetLeftChild (nodes[bestPair.first]);
155       newParent->SetRightChild (nodes[bestPair.second]);
156       nodes[bestPair.first]->SetParent (newParent);
157       nodes[bestPair.second]->SetParent (newParent);
158       nodes[bestPair.first] = newParent;
159       nodes[bestPair.second] = NULL;
160       newParent->SetIdentity(identities[bestPair.first][bestPair.second]);
161
162
163       // now update the distance matrix
164       for (int i = 0; i < numSeqs; i++) if (valid[i]){
165           distances[bestPair.first][i] = distances[i][bestPair.first]
166             = (distances[i][bestPair.first]*countCluster[bestPair.first] 
167                + distances[i][bestPair.second]*countCluster[bestPair.second])
168                / (countCluster[bestPair.first] + countCluster[bestPair.second]);
169 //        distances[bestPair.first][i] = distances[i][bestPair.first]
170 //            = (distances[i][bestPair.first] + distances[i][bestPair.second]) * bestProb / 2;
171         identities[bestPair.first][i] = identities[i][bestPair.first]
172             = (identities[i][bestPair.first]*countCluster[bestPair.first] 
173                + identities[i][bestPair.second]*countCluster[bestPair.second])
174                / (countCluster[bestPair.first] + countCluster[bestPair.second]);
175       }
176
177       // finally, mark the second node entry as no longer valid
178       countCluster[bestPair.first] += countCluster[bestPair.second];
179       valid[bestPair.second] = 0;
180     }
181
182     assert (nodes[0]);
183     return nodes[0];
184   }
185
186   /////////////////////////////////////////////////////////////////
187   // TreeNode::Print()
188   //
189   // Print out the subtree associated with this node in a
190   // parenthesized representation.
191   /////////////////////////////////////////////////////////////////
192
193   void Print (ostream &outfile, const MultiSequence *sequences) const {
194 //    outfile << "Alignment tree: ";
195     PrintNode (outfile, sequences);
196     outfile << endl;
197   }
198 };
199 }
200 #endif //__EVOLUTIONARYTREE_HPP__