JAL-4134 use average distance between PAE elements as distance for tree calculation
authorJames Procter <j.procter@dundee.ac.uk>
Mon, 27 Feb 2023 17:57:44 +0000 (17:57 +0000)
committerJames Procter <j.procter@dundee.ac.uk>
Mon, 27 Feb 2023 17:57:44 +0000 (17:57 +0000)
src/jalview/analysis/AverageDistanceEngine.java

index ed6f861..e6a763b 100644 (file)
@@ -39,71 +39,113 @@ import jalview.viewmodel.AlignmentViewport;
 public class AverageDistanceEngine extends TreeEngine
 {
   ContactMatrixI cm;
+
   AlignmentViewport av;
+
   AlignmentAnnotation aa;
+
   /**
-   * compute cosine distance matrix for a given contact matrix and create a UPGMA tree
+   * compute cosine distance matrix for a given contact matrix and create a
+   * UPGMA tree
+   * 
    * @param cm
    */
-  public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa, ContactMatrixI cm)
+  public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa,
+          ContactMatrixI cm)
   {
-    this.av =av;
+    this.av = av;
     this.aa = aa;
-    this.cm=cm;
+    this.cm = cm;
     calculate(cm);
 
   }
+
+  // 0 - normalised dot product
+  // 1 - L1 - ie (abs(v_1-v_2)/dim(v))
+  // L1 is more rational - since can reason about value of difference,
+  // normalised dot product might give cleaner clusters, but more difficult to
+  // understand.
+
+  int mode = 1;
+
   public void calculate(ContactMatrixI cm)
   {
     this.cm = cm;
     node = new Vector<BinaryNode>();
     clusters = new Vector<BitSet>();
     distances = new Matrix(new double[cm.getWidth()][cm.getWidth()]);
-    noseqs=cm.getWidth();
-    done  = new BitSet();
-    double moduli[]=new double[cm.getWidth()];
-    
-    
-    for (int i=0;i<cm.getWidth();i++)
+    noseqs = cm.getWidth();
+    done = new BitSet();
+    double moduli[] = new double[cm.getWidth()];
+    double max;
+    if (mode == 0)
+    {
+      max = 1;
+    }
+    else
+    {
+      max = cm.getMax() * cm.getMax();
+    }
+
+    for (int i = 0; i < cm.getWidth(); i++)
     {
       // init the tree engine node for this column
       BinaryNode cnode = new BinaryNode();
       cnode.setElement(Integer.valueOf(i));
-      cnode.setName("c"+i);
+      cnode.setName("c" + i);
       node.addElement(cnode);
       BitSet bs = new BitSet();
       bs.set(i);
       clusters.addElement(bs);
 
       // compute distance matrix element
-      ContactListI ith=cm.getContactList(i);
-      
-      for (int j=0;j<i;j++)
+      ContactListI ith = cm.getContactList(i);
+
+      for (int j = 0; j < i; j++)
       {
-        distances.setValue(i,i,0);
+        distances.setValue(i, i, 0);
         ContactListI jth = cm.getContactList(j);
-        double prd=0;
-        for (int indx=0;indx<cm.getHeight();indx++)
+        double prd = 0;
+        for (int indx = 0; indx < cm.getHeight(); indx++)
+        {
+          if (mode == 0)
+          {
+            if (j == 0)
+            {
+              moduli[i] += ith.getContactAt(indx) * ith.getContactAt(indx);
+            }
+            prd += ith.getContactAt(indx) * jth.getContactAt(indx);
+          }
+          else
+          {
+            prd += Math
+                    .abs(ith.getContactAt(indx) - jth.getContactAt(indx));
+          }
+        }
+        if (mode == 0)
         {
-          if (j==0)
+          if (j == 0)
           {
-            moduli[i]+=ith.getContactAt(indx)*ith.getContactAt(indx);
+            moduli[i] = Math.sqrt(moduli[i]);
           }
-          prd+=ith.getContactAt(indx)*jth.getContactAt(indx);
+          prd = (moduli[i] != 0 && moduli[j] != 0)
+                  ? prd / (moduli[i] * moduli[j])
+                  : 0;
+          prd = 1 - prd;
         }
-        if (j==0)
+        else
         {
-          moduli[i]=Math.sqrt(moduli[i]);
+          prd /= cm.getHeight();
         }
-        prd=(moduli[i]!=0 && moduli[j]!=0) ? prd/(moduli[i]*moduli[j]) : 0;
-        distances.setValue(i, j, 1-prd);
-        distances.setValue(j, i, 1-prd);
+        distances.setValue(i, j, prd);
+        distances.setValue(j, i, prd);
       }
     }
 
     noClus = clusters.size();
     cluster();
   }
+
   /**
    * Calculates and saves the distance between the combination of cluster(i) and
    * cluster(j) and all other clusters. An average of the distances from
@@ -197,6 +239,7 @@ public class AverageDistanceEngine extends TreeEngine
     nodei.dist = ((dist / 2) - ih);
     nodej.dist = ((dist / 2) - jh);
   }
+
   /***
    * not the right place - OH WELL!
    */
@@ -234,8 +277,8 @@ public class AverageDistanceEngine extends TreeEngine
     }
     else
     {
-      _groupNodes(groups,  nd.left(), threshold);
-      _groupNodes(groups,  nd.right(), threshold);
+      _groupNodes(groups, nd.left(), threshold);
+      _groupNodes(groups, nd.right(), threshold);
     }
   }
 
@@ -286,7 +329,6 @@ public class AverageDistanceEngine extends TreeEngine
     return maxheight;
   }
 
-
   /**
    * Search for leaf nodes below (or at) the given node
    * 
@@ -312,8 +354,7 @@ public class AverageDistanceEngine extends TreeEngine
    * 
    * @return Vector of leaf nodes on binary tree
    */
-  Vector<BinaryNode> findLeaves(BinaryNode nd,
-          Vector<BinaryNode> leaves)
+  Vector<BinaryNode> findLeaves(BinaryNode nd, Vector<BinaryNode> leaves)
   {
     if (nd == null)
     {
@@ -333,12 +374,11 @@ public class AverageDistanceEngine extends TreeEngine
        * TODO: Identify internal nodes... if (node.isSequenceLabel()) {
        * leaves.addElement(node); }
        */
-      findLeaves( nd.left(), leaves);
-      findLeaves( nd.right(), leaves);
+      findLeaves(nd.left(), leaves);
+      findLeaves(nd.right(), leaves);
     }
 
     return leaves;
   }
 
-
 }