3 #include "distfunc.h"
\r
// Returns the smaller of two distances (d2 when they compare equal or
// when the comparison is false, e.g. for NaN).
static inline float Min(float d1, float d2)
{
    return d1 < d2 ? d1 : d2;
}
\r
// Returns the larger of two distances (d2 when they compare equal or
// when the comparison is false, e.g. for NaN).
static inline float Max(float d1, float d2)
{
    return d1 > d2 ? d1 : d2;
}
\r
// Returns the arithmetic mean of two distances. The sum is halved in
// double precision (2.0 is a double literal) and then narrowed to float.
static inline float Mean(float d1, float d2)
{
    return static_cast<float>((d1 + d2) / 2.0);
}
\r
// NOTE(review): the trailing "#define Validate(uNodeCount) // empty" only
// makes sense as the fallback arm of a conditional; the guard macro is
// assumed to be _DEBUG -- confirm against the project's build settings.
#if _DEBUG
// Consistency check of the disjoint list; calls Quit() on the first
// violation found.
//
// Checks performed:
//   - every node's prev/next links agree with its neighbours' links;
//   - a node with no predecessor is the list head (m_ptrDisjoints);
//   - the list holds no more than m_uNodeCount entries (i.e. no cycle);
//   - the list length equals the number of parentless nodes among the
//     first uNodeCount entries of m_Nodes.
//
// uNodeCount: how many entries of m_Nodes are inspected for the
//             parentless-node count.
//
// NOTE(review): the node dumps (LogMe) that follow each "mismatch"
// message are inferred from the message texts ("prev=", "next=") --
// confirm against the reference implementation.
void ClusterTree::Validate(unsigned uNodeCount)
{
    unsigned uDisjointListCount = 0;
    for (ClusterNode *pNode = m_ptrDisjoints; pNode; pNode = pNode->GetNextDisjoint())
    {
        ClusterNode *pPrev = pNode->GetPrevDisjoint();
        ClusterNode *pNext = pNode->GetNextDisjoint();

        if (0 != pPrev)
        {
            // Predecessor must point forward to this node.
            if (pPrev->GetNextDisjoint() != pNode)
            {
                Log("Prev->This mismatch, prev=\n");
                pPrev->LogMe();
                Log("This=\n");
                pNode->LogMe();
                Quit("ClusterTree::Validate()");
            }
        }
        else
        {
            // Only the list head may lack a predecessor.
            if (pNode != m_ptrDisjoints)
            {
                Log("[%u]->prev = 0 but != m_ptrDisjoints=%d\n",
                    pNode->GetIndex(),
                    m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);
                pNode->LogMe();
                Quit("ClusterTree::Validate()");
            }
        }

        if (0 != pNext)
        {
            // Successor must point back to this node.
            if (pNext->GetPrevDisjoint() != pNode)
            {
                Log("Next->This mismatch, next=\n");
                pNext->LogMe();
                Log("This=\n");
                pNode->LogMe();
                Quit("ClusterTree::Validate()");
            }
        }

        // More list entries than nodes means the list loops back on itself.
        ++uDisjointListCount;
        if (uDisjointListCount > m_uNodeCount)
            Quit("Loop in disjoint list");
    }

    // Every disjoint cluster is a parentless node and vice versa.
    unsigned uParentlessNodeCount = 0;
    for (unsigned n = 0; n < uNodeCount; ++n)
        if (0 == m_Nodes[n].GetParent())
            ++uParentlessNodeCount;

    if (uDisjointListCount != uParentlessNodeCount)
        Quit("Disjoints = %u Parentless = %u\n", uDisjointListCount,
             uParentlessNodeCount);
}
#else // !_DEBUG
// Outside debug builds, calls to Validate(...) expand to nothing.
#define Validate(uNodeCount) // empty
#endif
\r
81 void ClusterNode::LogMe() const
\r
83 unsigned uClusterSize = GetClusterSize();
\r
84 Log("[%02u] w=%5.3f CW=%5.3f LBW=%5.3f RBW=%5.3f LWT=%5.3f RWT=%5.3f L=%02d R=%02d P=%02d NxDj=%02d PvDj=%02d Sz=%02d {",
\r
88 GetLeftBranchWeight(),
\r
89 GetRightBranchWeight(),
\r
92 m_ptrLeft ? m_ptrLeft->GetIndex() : 0xffffffff,
\r
93 m_ptrRight ? m_ptrRight->GetIndex() : 0xffffffff,
\r
94 m_ptrParent ? m_ptrParent->GetIndex() : 0xffffffff,
\r
95 m_ptrNextDisjoint ? m_ptrNextDisjoint->GetIndex() : 0xffffffff,
\r
96 m_ptrPrevDisjoint ? m_ptrPrevDisjoint->GetIndex() : 0xffffffff,
\r
98 for (unsigned i = 0; i < uClusterSize; ++i)
\r
99 Log(" %u", GetClusterLeaf(i)->GetIndex());
\r
103 // How many leaves in the sub-tree under this node?
\r
104 unsigned ClusterNode::GetClusterSize() const
\r
106 unsigned uLeafCount = 0;
\r
108 if (0 == m_ptrLeft && 0 == m_ptrRight)
\r
111 if (0 != m_ptrLeft)
\r
112 uLeafCount += m_ptrLeft->GetClusterSize();
\r
113 if (0 != m_ptrRight)
\r
114 uLeafCount += m_ptrRight->GetClusterSize();
\r
115 assert(uLeafCount > 0);
\r
119 double ClusterNode::GetClusterWeight() const
\r
121 double dWeight = 0.0;
\r
122 if (0 != m_ptrLeft)
\r
123 dWeight += m_ptrLeft->GetClusterWeight();
\r
124 if (0 != m_ptrRight)
\r
125 dWeight += m_ptrRight->GetClusterWeight();
\r
126 return dWeight + GetWeight();
\r
129 double ClusterNode::GetLeftBranchWeight() const
\r
131 const ClusterNode *ptrLeft = GetLeft();
\r
135 return GetWeight() - ptrLeft->GetWeight();
\r
138 double ClusterNode::GetRightBranchWeight() const
\r
140 const ClusterNode *ptrRight = GetRight();
\r
144 return GetWeight() - ptrRight->GetWeight();
\r
147 double ClusterNode::GetRightWeight() const
\r
149 const ClusterNode *ptrRight = GetRight();
\r
152 return ptrRight->GetClusterWeight() + GetWeight();
\r
155 double ClusterNode::GetLeftWeight() const
\r
157 const ClusterNode *ptrLeft = GetLeft();
\r
160 return ptrLeft->GetClusterWeight() + GetWeight();
\r
163 // Return n'th leaf in the sub-tree under this node.
\r
164 const ClusterNode *ClusterNode::GetClusterLeaf(unsigned uLeafIndex) const
\r
166 if (0 != m_ptrLeft)
\r
168 if (0 == m_ptrRight)
\r
171 unsigned uLeftLeafCount = m_ptrLeft->GetClusterSize();
\r
173 if (uLeafIndex < uLeftLeafCount)
\r
174 return m_ptrLeft->GetClusterLeaf(uLeafIndex);
\r
176 assert(uLeafIndex >= uLeftLeafCount);
\r
177 return m_ptrRight->GetClusterLeaf(uLeafIndex - uLeftLeafCount);
\r
179 if (0 == m_ptrRight)
\r
181 return m_ptrRight->GetClusterLeaf(uLeafIndex);
\r
184 void ClusterTree::DeleteFromDisjoints(ClusterNode *ptrNode)
\r
186 ClusterNode *ptrPrev = ptrNode->GetPrevDisjoint();
\r
187 ClusterNode *ptrNext = ptrNode->GetNextDisjoint();
\r
190 ptrPrev->SetNextDisjoint(ptrNext);
\r
192 m_ptrDisjoints = ptrNext;
\r
195 ptrNext->SetPrevDisjoint(ptrPrev);
\r
198 // not algorithmically necessary, but improves clarity
\r
199 // and supports Validate().
\r
200 ptrNode->SetPrevDisjoint(0);
\r
201 ptrNode->SetNextDisjoint(0);
\r
205 void ClusterTree::AddToDisjoints(ClusterNode *ptrNode)
\r
207 ptrNode->SetNextDisjoint(m_ptrDisjoints);
\r
208 ptrNode->SetPrevDisjoint(0);
\r
209 if (0 != m_ptrDisjoints)
\r
210 m_ptrDisjoints->SetPrevDisjoint(ptrNode);
\r
211 m_ptrDisjoints = ptrNode;
\r
214 ClusterTree::ClusterTree()
\r
216 m_ptrDisjoints = 0;
\r
221 ClusterTree::~ClusterTree()
\r
226 void ClusterTree::LogMe() const
\r
228 Log("Disjoints=%d\n", m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);
\r
229 for (unsigned i = 0; i < m_uNodeCount; ++i)
\r
231 m_Nodes[i].LogMe();
\r
235 ClusterNode *ClusterTree::GetRoot() const
\r
237 return &m_Nodes[m_uNodeCount - 1];
\r
240 // This is the UPGMA algorithm as described in Durbin et al. p166.
\r
241 void ClusterTree::Create(const DistFunc &Dist)
\r
244 m_uLeafCount = Dist.GetCount();
\r
245 m_uNodeCount = 2*m_uLeafCount - 1;
\r
248 m_Nodes = new ClusterNode[m_uNodeCount];
\r
250 for (i = 0; i < m_uNodeCount; ++i)
\r
251 m_Nodes[i].SetIndex(i);
\r
253 for (i = 0; i < m_uLeafCount - 1; ++i)
\r
254 m_Nodes[i].SetNextDisjoint(&m_Nodes[i+1]);
\r
256 for (i = 1; i < m_uLeafCount; ++i)
\r
257 m_Nodes[i].SetPrevDisjoint(&m_Nodes[i-1]);
\r
259 m_ptrDisjoints = &m_Nodes[0];
\r
261 // Log("Initial state\n");
\r
265 DistFunc ClusterDist;
\r
266 ClusterDist.SetCount(m_uNodeCount);
\r
267 double dMaxDist = 0.0;
\r
268 for (i = 0; i < m_uLeafCount; ++i)
\r
269 for (unsigned j = 0; j < m_uLeafCount; ++j)
\r
271 float dDist = Dist.GetDist(i, j);
\r
272 ClusterDist.SetDist(i, j, dDist);
\r
275 Validate(m_uLeafCount);
\r
277 // Iteration. N-1 joins needed to create a binary tree from N leaves.
\r
278 for (unsigned uJoinIndex = m_uLeafCount; uJoinIndex < m_uNodeCount;
\r
281 // Find closest pair of clusters
\r
282 unsigned uIndexClosest1;
\r
283 unsigned uIndexClosest2;
\r
284 bool bFound = false;
\r
285 double dDistClosest = 9e99;
\r
286 for (ClusterNode *ptrNode1 = m_ptrDisjoints; ptrNode1;
\r
287 ptrNode1 = ptrNode1->GetNextDisjoint())
\r
289 for (ClusterNode *ptrNode2 = ptrNode1->GetNextDisjoint(); ptrNode2;
\r
290 ptrNode2 = ptrNode2->GetNextDisjoint())
\r
292 unsigned i1 = ptrNode1->GetIndex();
\r
293 unsigned i2 = ptrNode2->GetIndex();
\r
294 double dDist = ClusterDist.GetDist(i1, i2);
\r
295 if (dDist < dDistClosest)
\r
298 dDistClosest = dDist;
\r
299 uIndexClosest1 = i1;
\r
300 uIndexClosest2 = i2;
\r
306 ClusterNode &Join = m_Nodes[uJoinIndex];
\r
307 ClusterNode &Child1 = m_Nodes[uIndexClosest1];
\r
308 ClusterNode &Child2 = m_Nodes[uIndexClosest2];
\r
310 Join.SetLeft(&Child1);
\r
311 Join.SetRight(&Child2);
\r
312 Join.SetWeight(dDistClosest);
\r
314 Child1.SetParent(&Join);
\r
315 Child2.SetParent(&Join);
\r
317 DeleteFromDisjoints(&Child1);
\r
318 DeleteFromDisjoints(&Child2);
\r
319 AddToDisjoints(&Join);
\r
321 // Log("After join %d %d\n", uIndexClosest1, uIndexClosest2);
\r
324 // Calculate distance of every remaining disjoint cluster to the
\r
325 // new cluster created by the join
\r
326 for (ClusterNode *ptrNode = m_ptrDisjoints; ptrNode;
\r
327 ptrNode = ptrNode->GetNextDisjoint())
\r
329 unsigned uNodeIndex = ptrNode->GetIndex();
\r
330 float dDist1 = ClusterDist.GetDist(uNodeIndex, uIndexClosest1);
\r
331 float dDist2 = ClusterDist.GetDist(uNodeIndex, uIndexClosest2);
\r
332 float dDist = Min(dDist1, dDist2);
\r
333 ClusterDist.SetDist(uJoinIndex, uNodeIndex, dDist);
\r
335 Validate(uJoinIndex+1);
\r
337 GetRoot()->GetClusterWeight();
\r