Wrapper for Clustal Omega.
[jabaws.git] / binaries / src / clustalo / src / kmpp / KmTree.cpp
1 // See KmTree.cpp
2 //
3 // Author: David Arthur (darthur@gmail.com), 2009
4
5 // Includes
6 #include "KmTree.h"
7 #include <iostream>
8 #include <stdlib.h>
9 #include <stdio.h>
10 using namespace std;
11
12 KmTree::KmTree(int n, int d, Scalar *points): n_(n), d_(d), points_(points) {
13   // Initialize memory
14   // DD: need to cast to long otherwise malloc will fail
15   // if we need more than 2 gigabytes or so
16   int node_size = sizeof(Node) + d_ * 3 * sizeof(Scalar);
17   node_data_ = (char*)malloc((2*(long unsigned int)n-1) * node_size);
18   point_indices_ = (int*)malloc(n * sizeof(int));
19   for (int i = 0; i < n; i++)
20     point_indices_[i] = i;
21   KM_ASSERT(node_data_ != 0 && point_indices_ != 0);
22
23   // Calculate the bounding box for the points
24   Scalar *bound_v1 = PointAllocate(d_);
25   Scalar *bound_v2 = PointAllocate(d_);
26   KM_ASSERT(bound_v1 != 0 && bound_v2 != 0);
27   PointCopy(bound_v1, points, d_);
28   PointCopy(bound_v2, points, d_);
29   for (int i = 1; i < n; i++)
30   for (int j = 0; j < d; j++) {
31     if (bound_v1[j] > points[i*d_ + j]) bound_v1[j] = points[i*d_ + j];
32     if (bound_v2[j] < points[i*d_ + j]) bound_v1[j] = points[i*d_ + j];
33   }
34
35   // Build the tree
36   char *temp_node_data = node_data_;
37   top_node_ = BuildNodes(points, 0, n-1, &temp_node_data);
38
39   // Cleanup
40   PointFree(bound_v1);
41   PointFree(bound_v2);
42 }
43
44 KmTree::~KmTree() {
45   free(point_indices_);
46   free(node_data_);
47 }
48
49 Scalar KmTree::DoKMeansStep(int k, Scalar *centers, int *assignment) const {
50   // Create an invalid center for comparison purposes
51   Scalar *bad_center = PointAllocate(d_);
52   KM_ASSERT(bad_center != 0);
53   memset(bad_center, 0xff, d_ * sizeof(Scalar));
54
55   // Allocate data
56   Scalar *sums = (Scalar*)calloc(k * d_, sizeof(Scalar));
57   int *counts = (int*)calloc(k, sizeof(int));
58   int num_candidates = 0;
59   int *candidates = (int*)malloc(k * sizeof(int));
60   KM_ASSERT(sums != 0 && counts != 0 && candidates != 0);
61   for (int i = 0; i < k; i++)
62   if (memcmp(centers + i*d_, bad_center, d_ * sizeof(Scalar)) != 0)
63     candidates[num_candidates++] = i;
64
65   // Find nodes
66   Scalar result = DoKMeansStepAtNode(top_node_, num_candidates, candidates, centers, sums,
67                                      counts, assignment);
68
69   // Set the new centers
70   for (int i = 0; i < k; i++) {
71     if (counts[i] > 0) {
72       PointScale(sums + i*d_, Scalar(1) / counts[i], d_);
73       PointCopy(centers + i*d_, sums + i*d_, d_);
74     } else {
75       memcpy(centers + i*d_, bad_center, d_ * sizeof(Scalar));
76     }
77   }
78
79   // Cleanup memory
80   PointFree(bad_center);
81   free(candidates);
82   free(counts);
83   free(sums);
84   return result;
85 }
86
87 // Helper functions for constructor
88 // ================================
89
90 // Build a kd tree from the given set of points
91 KmTree::Node *KmTree::BuildNodes(Scalar *points, int first_index, int last_index,
92                                  char **next_node_data) {
93   // Allocate the node
94   Node *node = (Node*)(*next_node_data);
95   (*next_node_data) += sizeof(Node);
96   node->sum = (Scalar*)(*next_node_data);
97   (*next_node_data) += sizeof(Scalar) * d_;
98   node->median = (Scalar*)(*next_node_data);
99   (*next_node_data) += sizeof(Scalar) * d_;
100   node->radius = (Scalar*)(*next_node_data);
101   (*next_node_data) += sizeof(Scalar) * d_;
102
103   // Fill in basic info
104   node->num_points = (last_index - first_index + 1);
105   node->first_point_index = first_index;
106
107   // Calculate the bounding box
108   Scalar *first_point = points + point_indices_[first_index] * d_;
109   Scalar *bound_p1 = PointAllocate(d_);
110   Scalar *bound_p2 = PointAllocate(d_);
111   KM_ASSERT(bound_p1 != 0 && bound_p2 != 0);
112   PointCopy(bound_p1, first_point, d_);
113   PointCopy(bound_p2, first_point, d_);
114   for (int i = first_index+1; i <= last_index; i++)
115   for (int j = 0; j < d_; j++) {
116     Scalar c = points[point_indices_[i]*d_ + j];
117     if (bound_p1[j] > c) bound_p1[j] = c;
118     if (bound_p2[j] < c) bound_p2[j] = c;
119   }
120
121   // Calculate bounding box stats and delete the bounding box memory
122   Scalar max_radius = -1;
123   int split_d = -1;
124   for (int j = 0; j < d_; j++) {
125     node->median[j] = (bound_p1[j] + bound_p2[j]) / 2;
126     node->radius[j] = (bound_p2[j] - bound_p1[j]) / 2;
127     if (node->radius[j] > max_radius) {
128       max_radius = node->radius[j];
129       split_d = j;
130     }
131   }
132   PointFree(bound_p2);
133   PointFree(bound_p1);
134
135   // If the max spread is 0, make this a leaf node
136   if (max_radius == 0) {
137     node->lower_node = node->upper_node = 0;
138     PointCopy(node->sum, first_point, d_);
139     if (last_index != first_index)
140       PointScale(node->sum, Scalar(last_index - first_index + 1), d_);
141     node->opt_cost = 0;
142     return node;
143   }
144
145   // Partition the points around the midpoint in this dimension. The partitioning is done in-place
146   // by iterating from left-to-right and right-to-left in the same way that partioning is done for
147   // quicksort.
148   Scalar split_pos = node->median[split_d];
149   int i1 = first_index, i2 = last_index, size1 = 0;
150   while (i1 <= i2) {
151     bool is_i1_good = (points[point_indices_[i1]*d_ + split_d] < split_pos);
152     bool is_i2_good = (points[point_indices_[i2]*d_ + split_d] >= split_pos);
153     if (!is_i1_good && !is_i2_good) {
154       int temp = point_indices_[i1];
155       point_indices_[i1] = point_indices_[i2];
156       point_indices_[i2] = temp;
157       is_i1_good = is_i2_good = true;
158     }
159     if (is_i1_good) {
160       i1++;
161       size1++;
162     }
163     if (is_i2_good) {
164       i2--;
165     }
166   }
167
168   // Create the child nodes
169   KM_ASSERT(size1 >= 1 && size1 <= last_index - first_index);
170   node->lower_node = BuildNodes(points, first_index, first_index + size1 - 1, next_node_data);
171   node->upper_node = BuildNodes(points, first_index + size1, last_index, next_node_data);
172
173   // Calculate the new sum and opt cost
174   PointCopy(node->sum, node->lower_node->sum, d_);
175   PointAdd(node->sum, node->upper_node->sum, d_);
176   Scalar *center = PointAllocate(d_);
177   KM_ASSERT(center != 0);
178   PointCopy(center, node->sum, d_);
179   PointScale(center, Scalar(1) / node->num_points, d_);
180   node->opt_cost = GetNodeCost(node->lower_node, center) + GetNodeCost(node->upper_node, center);
181   PointFree(center);
182   return node;
183 }
184
185 // Returns the total contribution of all points in the given kd-tree node, assuming they are all
186 // assigned to a center at the given location. We need to return:
187 //
188 //   sum_{x \in node} ||x - center||^2.
189 //
190 // If c denotes the center of mass of the points in this node and n denotes the number of points in
191 // it, then this quantity is given by
192 //
193 //   n * ||c - center||^2 + sum_{x \in node} ||x - c||^2
194 //
195 // The sum is precomputed for each node as opt_cost. This formula follows from expanding both sides
196 // as dot products. See Kanungo/Mount for more info.
197 Scalar KmTree::GetNodeCost(const Node *node, Scalar *center) const {
198   Scalar dist_sq = 0;
199   for (int i = 0; i < d_; i++) {
200     Scalar x = (node->sum[i] / node->num_points) - center[i];
201     dist_sq += x*x;
202   }
203   return node->opt_cost + node->num_points * dist_sq;
204 }
205
206 // Helper functions for DoKMeans step
207 // ==================================
208
209 // A recursive version of DoKMeansStep. This determines which clusters all points that are rooted
210 // node will be assigned to, and updates sums, counts and assignment (if not null) accordingly.
211 // candidates maintains the set of cluster indices which could possibly be the closest clusters
212 // for points in this subtree.
213 Scalar KmTree::DoKMeansStepAtNode(const Node *node, int k, int *candidates, Scalar *centers,
214                                   Scalar *sums, int *counts, int *assignment) const {
215   // Determine which center the node center is closest to
216   Scalar min_dist_sq = PointDistSq(node->median, centers + candidates[0]*d_, d_);
217   int closest_i = candidates[0];
218   for (int i = 1; i < k; i++) {
219     Scalar dist_sq = PointDistSq(node->median, centers + candidates[i]*d_, d_);
220     if (dist_sq < min_dist_sq) {
221       min_dist_sq = dist_sq;
222       closest_i = candidates[i];
223     }
224   }
225
226   // If this is a non-leaf node, recurse if necessary
227   if (node->lower_node != 0) {
228     // Build the new list of candidates
229     int new_k = 0;
230     int *new_candidates = (int*)malloc(k * sizeof(int));
231     KM_ASSERT(new_candidates != 0);
232     for (int i = 0; i < k; i++)
233     if (!ShouldBePruned(node->median, node->radius, centers, closest_i, candidates[i]))
234       new_candidates[new_k++] = candidates[i];
235
236     // Recurse if there's at least two
237     if (new_k > 1) {
238       Scalar result = DoKMeansStepAtNode(node->lower_node, new_k, new_candidates, centers,
239                                          sums, counts, assignment) +
240                       DoKMeansStepAtNode(node->upper_node, new_k, new_candidates, centers,
241                                          sums, counts, assignment);
242       free(new_candidates);
243       return result;
244     } else {
245       free(new_candidates);
246     }
247   }
248
249   // Assigns all points within this node to a single center
250   PointAdd(sums + closest_i*d_, node->sum, d_);
251   counts[closest_i] += node->num_points;
252   if (assignment != 0) {
253     for (int i = node->first_point_index; i < node->first_point_index + node->num_points; i++)
254       assignment[point_indices_[i]] = closest_i;
255   }
256   return GetNodeCost(node, centers + closest_i*d_);
257 }
258
259 // Determines whether every point in the box is closer to centers[best_index] than to
260 // centers[test_index].
261 //
262 // If x is a point, c_0 = centers[best_index], c = centers[test_index], then:
263 //       (x-c).(x-c) < (x-c_0).(x-c_0)
264 //   <=> (c-c_0).(c-c_0) < 2(x-c_0).(c-c_0)
265 //
266 // The right-hand side is maximized for a vertex of the box where for each dimension, we choose
267 // the low or high value based on the sign of x-c_0 in that dimension.
268 bool KmTree::ShouldBePruned(Scalar *box_median, Scalar *box_radius, Scalar *centers,
269                             int best_index, int test_index) const {
270   if (best_index == test_index)
271     return false;
272   
273   Scalar *best = centers + best_index*d_;
274   Scalar *test = centers + test_index*d_;
275   Scalar lhs = 0, rhs = 0;
276   for (int i = 0; i < d_; i++) {
277     Scalar component = test[i] - best[i];
278     lhs += component * component;
279     if (component > 0)
280       rhs += (box_median[i] + box_radius[i] - best[i]) * component;
281     else
282       rhs += (box_median[i] - box_radius[i] - best[i]) * component;
283   }
284   return (lhs >= 2*rhs);
285 }
286
287 Scalar KmTree::SeedKMeansPlusPlus(int k, Scalar *centers) const {
288   Scalar *dist_sq = (Scalar*)malloc(n_ * sizeof(Scalar));
289   KM_ASSERT(dist_sq != 0);
290
291   // Choose an initial center uniformly at random
292   SeedKmppSetClusterIndex(top_node_, 0);
293   int i = GetRandom(n_);
294   memcpy(centers, points_ + point_indices_[i]*d_, d_*sizeof(Scalar));
295   Scalar total_cost = 0;
296   for (int j = 0; j < n_; j++) {
297     dist_sq[j] = PointDistSq(points_ + point_indices_[j]*d_, centers, d_);
298     total_cost += dist_sq[j];
299   }
300
301   // Repeatedly choose more centers
302   for (int new_cluster = 1; new_cluster < k; new_cluster++) {
303     while (1) {
304       Scalar cutoff = (rand() / Scalar(RAND_MAX)) * total_cost;
305       Scalar cur_cost = 0;
306       for (i = 0; i < n_; i++) {
307         cur_cost += dist_sq[i];
308         if (cur_cost >= cutoff)
309           break;
310       }
311       if (i < n_)
312         break;
313     }
314     memcpy(centers + new_cluster*d_, points_ + point_indices_[i]*d_, d_*sizeof(Scalar));
315     total_cost = SeedKmppUpdateAssignment(top_node_, new_cluster, centers, dist_sq);
316   }
317
318   // Clean up and return
319   free(dist_sq);
320   return total_cost;
321 }
322
323 // Helper functions for SeedKMeansPlusPlus
324 // =======================================
325
326 // Sets kmpp_cluster_index to 0 for all nodes
327 void KmTree::SeedKmppSetClusterIndex(const Node *node, int value) const {
328   node->kmpp_cluster_index = value;
329   if (node->lower_node != 0) {
330     SeedKmppSetClusterIndex(node->lower_node, value);
331     SeedKmppSetClusterIndex(node->upper_node, value);
332   }
333 }
334
335 Scalar KmTree::SeedKmppUpdateAssignment(const Node *node, int new_cluster, Scalar *centers,
336                                         Scalar *dist_sq) const {
337   // See if we can assign all points in this node to one cluster
338   if (node->kmpp_cluster_index >= 0) {
339     if (ShouldBePruned(node->median, node->radius, centers, node->kmpp_cluster_index, new_cluster))
340       return GetNodeCost(node, centers + node->kmpp_cluster_index*d_);
341     if (ShouldBePruned(node->median, node->radius, centers, new_cluster,
342                        node->kmpp_cluster_index)) {
343       SeedKmppSetClusterIndex(node, new_cluster);
344       for (int i = node->first_point_index; i < node->first_point_index + node->num_points; i++)
345         dist_sq[i] = PointDistSq(points_ + point_indices_[i]*d_, centers + new_cluster*d_, d_);
346       return GetNodeCost(node, centers + new_cluster*d_);
347     }
348     
349     // It may be that the a leaf-node point is equidistant from the new center or old
350     if (node->lower_node == 0)
351       return GetNodeCost(node, centers + node->kmpp_cluster_index*d_);
352   }
353
354   // Recurse
355   Scalar cost = SeedKmppUpdateAssignment(node->lower_node, new_cluster, centers, dist_sq) +
356                 SeedKmppUpdateAssignment(node->upper_node, new_cluster, centers, dist_sq);
357   int i1 = node->lower_node->kmpp_cluster_index, i2 = node->upper_node->kmpp_cluster_index;
358   if (i1 == i2 && i1 != -1)
359     node->kmpp_cluster_index = i1;
360   else
361     node->kmpp_cluster_index = -1;
362   return cost;
363 }