Next version of JABA
[jabaws.git] / binaries / src / probcons / EvolutionaryTree.h
1 /////////////////////////////////////////////////////////////////
2 // EvolutionaryTree.h
3 //
4 // Utilities for reading/writing multiple sequence data.
5 /////////////////////////////////////////////////////////////////
6
7 #ifndef EVOLUTIONARYTREE_H
8 #define EVOLUTIONARYTREE_H
9
10 #include <string>
11 #include <list>
12 #include <stdio.h>
13 #include "SafeVector.h"
14 #include "MultiSequence.h"
15 #include "Sequence.h"
16
17 using namespace std;
18
19 /////////////////////////////////////////////////////////////////
20 // TreeNode
21 //
22 // The fundamental unit for representing an alignment tree.  The
23 // guide tree is represented as a binary tree.
24 /////////////////////////////////////////////////////////////////
25
26 class TreeNode {
27   int sequenceLabel;                  // sequence label
28   TreeNode *left, *right, *parent;    // pointers to left, right children
29
30   /////////////////////////////////////////////////////////////////
31   // TreeNode::PrintNode()
32   //
33   // Internal routine used to print out the sequence comments
34   // associated with the evolutionary tree, using a hierarchical
35   // parenthesized format.
36   /////////////////////////////////////////////////////////////////
37
38   void PrintNode (ostream &outfile, const MultiSequence *sequences) const {
39
40     // if this is a leaf node, print out the associated sequence comment
41     if (sequenceLabel >= 0)
42       outfile << sequences->GetSequence (sequenceLabel)->GetHeader();
43
44     // otherwise, it must have two children; print out their subtrees recursively
45     else {
46       assert (left);
47       assert (right);
48
49       outfile << "(";
50       left->PrintNode (outfile, sequences);
51       outfile << " ";
52       right->PrintNode (outfile, sequences);
53       outfile << ")";
54     }
55   }
56
57  public:
58
59   /////////////////////////////////////////////////////////////////
60   // TreeNode::TreeNode()
61   //
62   // Constructor for a tree node.  Note that sequenceLabel = -1
63   // implies that the current node is not a leaf in the tree.
64   /////////////////////////////////////////////////////////////////
65
66   TreeNode (int sequenceLabel) : sequenceLabel (sequenceLabel),
67     left (NULL), right (NULL), parent (NULL) {
68     assert (sequenceLabel >= -1);
69   }
70
71   /////////////////////////////////////////////////////////////////
72   // TreeNode::~TreeNode()
73   //
74   // Destructor for a tree node.  Recursively deletes all children.
75   /////////////////////////////////////////////////////////////////
76
77   ~TreeNode (){
78     if (left){ delete left; left = NULL; }
79     if (right){ delete right; right = NULL; }
80     parent = NULL;
81   }
82
83
84   // getters
85   int GetSequenceLabel () const { return sequenceLabel; }
86   TreeNode *GetLeftChild () const { return left; }
87   TreeNode *GetRightChild () const { return right; }
88   TreeNode *GetParent () const { return parent; }
89
90   // setters
91   void SetSequenceLabel (int sequenceLabel){ this->sequenceLabel = sequenceLabel; assert (sequenceLabel >= -1); }
92   void SetLeftChild (TreeNode *left){ this->left = left; }
93   void SetRightChild (TreeNode *right){ this->right = right; }
94   void SetParent (TreeNode *parent){ this->parent = parent; }
95
96   /////////////////////////////////////////////////////////////////
97   // TreeNode::ComputeTree()
98   //
99   // Routine used to compute an evolutionary tree based on the
100   // given distance matrix.  We assume the distance matrix has the
101   // form, distMatrix[i][j] = expected accuracy of aligning i with j.
102   /////////////////////////////////////////////////////////////////
103
104   static TreeNode *ComputeTree (const VVF &distMatrix){
105
106     int numSeqs = distMatrix.size();                 // number of sequences in distance matrix
107     VVF distances (numSeqs, VF (numSeqs));           // a copy of the distance matrix
108     SafeVector<TreeNode *> nodes (numSeqs, NULL);    // list of nodes for each sequence
109     SafeVector<int> valid (numSeqs, 1);              // valid[i] tells whether or not the ith
110                                                      // nodes in the distances and nodes array
111                                                      // are valid
112
113     // initialization: make a copy of the distance matrix
114     for (int i = 0; i < numSeqs; i++)
115       for (int j = 0; j < numSeqs; j++)
116         distances[i][j] = distMatrix[i][j];
117
118     // initialization: create all the leaf nodes
119     for (int i = 0; i < numSeqs; i++){
120       nodes[i] = new TreeNode (i);
121       assert (nodes[i]);
122     }
123
124     // repeat until only a single node left
125     for (int numNodesLeft = numSeqs; numNodesLeft > 1; numNodesLeft--){
126       float bestProb = -1;
127       pair<int,int> bestPair;
128
129       // find the closest pair
130       for (int i = 0; i < numSeqs; i++) if (valid[i]){
131         for (int j = i+1; j < numSeqs; j++) if (valid[j]){
132           if (distances[i][j] > bestProb){
133             bestProb = distances[i][j];
134             bestPair = make_pair(i, j);
135           }
136         }
137       }
138
139       // merge the closest pair
140       TreeNode *newParent = new TreeNode (-1);
141       newParent->SetLeftChild (nodes[bestPair.first]);
142       newParent->SetRightChild (nodes[bestPair.second]);
143       nodes[bestPair.first]->SetParent (newParent);
144       nodes[bestPair.second]->SetParent (newParent);
145       nodes[bestPair.first] = newParent;
146       nodes[bestPair.second] = NULL;
147
148       // now update the distance matrix
149       for (int i = 0; i < numSeqs; i++) if (valid[i]){
150         distances[bestPair.first][i] = distances[i][bestPair.first]
151           = (distances[i][bestPair.first] + distances[i][bestPair.second]) * bestProb / 2;
152       }
153
154       // finally, mark the second node entry as no longer valid
155       valid[bestPair.second] = 0;
156     }
157
158     assert (nodes[0]);
159     return nodes[0];
160   }
161
162   /////////////////////////////////////////////////////////////////
163   // TreeNode::Print()
164   //
165   // Print out the subtree associated with this node in a
166   // parenthesized representation.
167   /////////////////////////////////////////////////////////////////
168
169   void Print (ostream &outfile, const MultiSequence *sequences) const {
170     outfile << "Alignment tree: ";
171     PrintNode (outfile, sequences);
172     outfile << endl;
173   }
174 };
175
176 #endif