--- /dev/null
+#include "muscle.h"\r
+#include "cluster.h"\r
+#include "distfunc.h"\r
+\r
+static inline float Min(float d1, float d2)\r
+ {\r
+ return d1 < d2 ? d1 : d2;\r
+ }\r
+\r
+static inline float Max(float d1, float d2)\r
+ {\r
+ return d1 > d2 ? d1 : d2;\r
+ }\r
+\r
+static inline float Mean(float d1, float d2)\r
+ {\r
+ return (float) ((d1 + d2)/2.0);\r
+ }\r
+\r
+#if _DEBUG\r
+void ClusterTree::Validate(unsigned uNodeCount)\r
+ {\r
+ unsigned n;\r
+ ClusterNode *pNode;\r
+ unsigned uDisjointListCount = 0;\r
+ for (pNode = m_ptrDisjoints; pNode; pNode = pNode->GetNextDisjoint())\r
+ {\r
+ ClusterNode *pPrev = pNode->GetPrevDisjoint();\r
+ ClusterNode *pNext = pNode->GetNextDisjoint();\r
+ if (0 != pPrev)\r
+ {\r
+ if (pPrev->GetNextDisjoint() != pNode)\r
+ {\r
+ Log("Prev->This mismatch, prev=\n");\r
+ pPrev->LogMe();\r
+ Log("This=\n");\r
+ pNode->LogMe();\r
+ Quit("ClusterTree::Validate()");\r
+ }\r
+ }\r
+ else\r
+ {\r
+ if (pNode != m_ptrDisjoints)\r
+ {\r
+ Log("[%u]->prev = 0 but != m_ptrDisjoints=%d\n",\r
+ pNode->GetIndex(),\r
+ m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);\r
+ pNode->LogMe();\r
+ Quit("ClusterTree::Validate()");\r
+ }\r
+ }\r
+ if (0 != pNext)\r
+ {\r
+ if (pNext->GetPrevDisjoint() != pNode)\r
+ {\r
+ Log("Next->This mismatch, next=\n");\r
+ pNext->LogMe();\r
+ Log("This=\n");\r
+ pNode->LogMe();\r
+ Quit("ClusterTree::Validate()");\r
+ }\r
+ }\r
+ ++uDisjointListCount;\r
+ if (uDisjointListCount > m_uNodeCount)\r
+ Quit("Loop in disjoint list");\r
+ }\r
+\r
+ unsigned uParentlessNodeCount = 0;\r
+ for (n = 0; n < uNodeCount; ++n)\r
+ if (0 == m_Nodes[n].GetParent())\r
+ ++uParentlessNodeCount;\r
+ \r
+ if (uDisjointListCount != uParentlessNodeCount)\r
+ Quit("Disjoints = %u Parentless = %u\n", uDisjointListCount,\r
+ uParentlessNodeCount);\r
+ }\r
+#else // !_DEBUG\r
+#define Validate(uNodeCount) // empty\r
+#endif\r
+\r
+void ClusterNode::LogMe() const\r
+ {\r
+ unsigned uClusterSize = GetClusterSize();\r
+ Log("[%02u] w=%5.3f CW=%5.3f LBW=%5.3f RBW=%5.3f LWT=%5.3f RWT=%5.3f L=%02d R=%02d P=%02d NxDj=%02d PvDj=%02d Sz=%02d {",\r
+ m_uIndex,\r
+ m_dWeight,\r
+ GetClusterWeight(),\r
+ GetLeftBranchWeight(),\r
+ GetRightBranchWeight(),\r
+ GetLeftWeight(),\r
+ GetRightWeight(),\r
+ m_ptrLeft ? m_ptrLeft->GetIndex() : 0xffffffff,\r
+ m_ptrRight ? m_ptrRight->GetIndex() : 0xffffffff,\r
+ m_ptrParent ? m_ptrParent->GetIndex() : 0xffffffff,\r
+ m_ptrNextDisjoint ? m_ptrNextDisjoint->GetIndex() : 0xffffffff,\r
+ m_ptrPrevDisjoint ? m_ptrPrevDisjoint->GetIndex() : 0xffffffff,\r
+ uClusterSize);\r
+ for (unsigned i = 0; i < uClusterSize; ++i)\r
+ Log(" %u", GetClusterLeaf(i)->GetIndex());\r
+ Log(" }\n");\r
+ }\r
+\r
+// How many leaves in the sub-tree under this node?\r
+unsigned ClusterNode::GetClusterSize() const\r
+ {\r
+ unsigned uLeafCount = 0;\r
+\r
+ if (0 == m_ptrLeft && 0 == m_ptrRight)\r
+ return 1;\r
+\r
+ if (0 != m_ptrLeft)\r
+ uLeafCount += m_ptrLeft->GetClusterSize();\r
+ if (0 != m_ptrRight)\r
+ uLeafCount += m_ptrRight->GetClusterSize();\r
+ assert(uLeafCount > 0);\r
+ return uLeafCount;\r
+ }\r
+\r
+double ClusterNode::GetClusterWeight() const\r
+ {\r
+ double dWeight = 0.0;\r
+ if (0 != m_ptrLeft)\r
+ dWeight += m_ptrLeft->GetClusterWeight();\r
+ if (0 != m_ptrRight)\r
+ dWeight += m_ptrRight->GetClusterWeight();\r
+ return dWeight + GetWeight();\r
+ }\r
+\r
+double ClusterNode::GetLeftBranchWeight() const\r
+ {\r
+ const ClusterNode *ptrLeft = GetLeft();\r
+ if (0 == ptrLeft)\r
+ return 0.0;\r
+\r
+ return GetWeight() - ptrLeft->GetWeight();\r
+ }\r
+\r
+double ClusterNode::GetRightBranchWeight() const\r
+ {\r
+ const ClusterNode *ptrRight = GetRight();\r
+ if (0 == ptrRight)\r
+ return 0.0;\r
+\r
+ return GetWeight() - ptrRight->GetWeight();\r
+ }\r
+\r
+double ClusterNode::GetRightWeight() const\r
+ {\r
+ const ClusterNode *ptrRight = GetRight();\r
+ if (0 == ptrRight)\r
+ return 0.0;\r
+ return ptrRight->GetClusterWeight() + GetWeight();\r
+ }\r
+\r
+double ClusterNode::GetLeftWeight() const\r
+ {\r
+ const ClusterNode *ptrLeft = GetLeft();\r
+ if (0 == ptrLeft)\r
+ return 0.0;\r
+ return ptrLeft->GetClusterWeight() + GetWeight();\r
+ }\r
+\r
+// Return n'th leaf in the sub-tree under this node.\r
+const ClusterNode *ClusterNode::GetClusterLeaf(unsigned uLeafIndex) const\r
+ {\r
+ if (0 != m_ptrLeft)\r
+ {\r
+ if (0 == m_ptrRight)\r
+ return this;\r
+\r
+ unsigned uLeftLeafCount = m_ptrLeft->GetClusterSize();\r
+\r
+ if (uLeafIndex < uLeftLeafCount)\r
+ return m_ptrLeft->GetClusterLeaf(uLeafIndex);\r
+\r
+ assert(uLeafIndex >= uLeftLeafCount);\r
+ return m_ptrRight->GetClusterLeaf(uLeafIndex - uLeftLeafCount);\r
+ }\r
+ if (0 == m_ptrRight)\r
+ return this;\r
+ return m_ptrRight->GetClusterLeaf(uLeafIndex);\r
+ }\r
+\r
+void ClusterTree::DeleteFromDisjoints(ClusterNode *ptrNode)\r
+ {\r
+ ClusterNode *ptrPrev = ptrNode->GetPrevDisjoint();\r
+ ClusterNode *ptrNext = ptrNode->GetNextDisjoint();\r
+\r
+ if (0 != ptrPrev)\r
+ ptrPrev->SetNextDisjoint(ptrNext);\r
+ else\r
+ m_ptrDisjoints = ptrNext;\r
+\r
+ if (0 != ptrNext)\r
+ ptrNext->SetPrevDisjoint(ptrPrev);\r
+\r
+#if _DEBUG\r
+// not algorithmically necessary, but improves clarity\r
+// and supports Validate().\r
+ ptrNode->SetPrevDisjoint(0);\r
+ ptrNode->SetNextDisjoint(0);\r
+#endif\r
+ }\r
+\r
+void ClusterTree::AddToDisjoints(ClusterNode *ptrNode)\r
+ {\r
+ ptrNode->SetNextDisjoint(m_ptrDisjoints);\r
+ ptrNode->SetPrevDisjoint(0);\r
+ if (0 != m_ptrDisjoints)\r
+ m_ptrDisjoints->SetPrevDisjoint(ptrNode);\r
+ m_ptrDisjoints = ptrNode;\r
+ }\r
+\r
+ClusterTree::ClusterTree()\r
+ {\r
+ m_ptrDisjoints = 0;\r
+ m_Nodes = 0;\r
+ m_uNodeCount = 0;\r
+ }\r
+\r
+ClusterTree::~ClusterTree()\r
+ {\r
+ delete[] m_Nodes;\r
+ }\r
+\r
+void ClusterTree::LogMe() const\r
+ {\r
+ Log("Disjoints=%d\n", m_ptrDisjoints ? m_ptrDisjoints->GetIndex() : 0xffffffff);\r
+ for (unsigned i = 0; i < m_uNodeCount; ++i)\r
+ {\r
+ m_Nodes[i].LogMe();\r
+ }\r
+ }\r
+\r
+ClusterNode *ClusterTree::GetRoot() const\r
+ {\r
+ return &m_Nodes[m_uNodeCount - 1];\r
+ }\r
+\r
+// This is the UPGMA algorithm as described in Durbin et al. p166.\r
+void ClusterTree::Create(const DistFunc &Dist)\r
+ {\r
+ unsigned i;\r
+ m_uLeafCount = Dist.GetCount();\r
+ m_uNodeCount = 2*m_uLeafCount - 1;\r
+\r
+ delete[] m_Nodes;\r
+ m_Nodes = new ClusterNode[m_uNodeCount];\r
+\r
+ for (i = 0; i < m_uNodeCount; ++i)\r
+ m_Nodes[i].SetIndex(i);\r
+\r
+ for (i = 0; i < m_uLeafCount - 1; ++i)\r
+ m_Nodes[i].SetNextDisjoint(&m_Nodes[i+1]);\r
+\r
+ for (i = 1; i < m_uLeafCount; ++i)\r
+ m_Nodes[i].SetPrevDisjoint(&m_Nodes[i-1]);\r
+ \r
+ m_ptrDisjoints = &m_Nodes[0];\r
+\r
+// Log("Initial state\n");\r
+// LogMe();\r
+// Log("\n");\r
+\r
+ DistFunc ClusterDist;\r
+ ClusterDist.SetCount(m_uNodeCount);\r
+ double dMaxDist = 0.0;\r
+ for (i = 0; i < m_uLeafCount; ++i)\r
+ for (unsigned j = 0; j < m_uLeafCount; ++j)\r
+ {\r
+ float dDist = Dist.GetDist(i, j);\r
+ ClusterDist.SetDist(i, j, dDist);\r
+ }\r
+\r
+ Validate(m_uLeafCount);\r
+\r
+// Iteration. N-1 joins needed to create a binary tree from N leaves.\r
+ for (unsigned uJoinIndex = m_uLeafCount; uJoinIndex < m_uNodeCount;\r
+ ++uJoinIndex)\r
+ {\r
+ // Find closest pair of clusters\r
+ unsigned uIndexClosest1;\r
+ unsigned uIndexClosest2;\r
+ bool bFound = false;\r
+ double dDistClosest = 9e99;\r
+ for (ClusterNode *ptrNode1 = m_ptrDisjoints; ptrNode1;\r
+ ptrNode1 = ptrNode1->GetNextDisjoint())\r
+ {\r
+ for (ClusterNode *ptrNode2 = ptrNode1->GetNextDisjoint(); ptrNode2;\r
+ ptrNode2 = ptrNode2->GetNextDisjoint())\r
+ {\r
+ unsigned i1 = ptrNode1->GetIndex();\r
+ unsigned i2 = ptrNode2->GetIndex();\r
+ double dDist = ClusterDist.GetDist(i1, i2);\r
+ if (dDist < dDistClosest)\r
+ {\r
+ bFound = true;\r
+ dDistClosest = dDist;\r
+ uIndexClosest1 = i1;\r
+ uIndexClosest2 = i2;\r
+ }\r
+ }\r
+ }\r
+ assert(bFound);\r
+\r
+ ClusterNode &Join = m_Nodes[uJoinIndex];\r
+ ClusterNode &Child1 = m_Nodes[uIndexClosest1];\r
+ ClusterNode &Child2 = m_Nodes[uIndexClosest2];\r
+\r
+ Join.SetLeft(&Child1);\r
+ Join.SetRight(&Child2);\r
+ Join.SetWeight(dDistClosest);\r
+\r
+ Child1.SetParent(&Join);\r
+ Child2.SetParent(&Join);\r
+\r
+ DeleteFromDisjoints(&Child1);\r
+ DeleteFromDisjoints(&Child2);\r
+ AddToDisjoints(&Join);\r
+\r
+// Log("After join %d %d\n", uIndexClosest1, uIndexClosest2);\r
+// LogMe();\r
+\r
+ // Calculate distance of every remaining disjoint cluster to the\r
+ // new cluster created by the join\r
+ for (ClusterNode *ptrNode = m_ptrDisjoints; ptrNode;\r
+ ptrNode = ptrNode->GetNextDisjoint())\r
+ {\r
+ unsigned uNodeIndex = ptrNode->GetIndex();\r
+ float dDist1 = ClusterDist.GetDist(uNodeIndex, uIndexClosest1);\r
+ float dDist2 = ClusterDist.GetDist(uNodeIndex, uIndexClosest2);\r
+ float dDist = Min(dDist1, dDist2);\r
+ ClusterDist.SetDist(uJoinIndex, uNodeIndex, dDist);\r
+ }\r
+ Validate(uJoinIndex+1);\r
+ }\r
+ GetRoot()->GetClusterWeight();\r
+// LogMe();\r
+ }\r