Merge branch 'develop' of https://source.jalview.org/git/jalview.git into develop
[jalview.git] / src / jalview / analysis / AverageDistanceTree.java
diff --git a/src/jalview/analysis/AverageDistanceTree.java b/src/jalview/analysis/AverageDistanceTree.java
new file mode 100644 (file)
index 0000000..907109e
--- /dev/null
@@ -0,0 +1,121 @@
+package jalview.analysis;
+
+import jalview.api.analysis.ScoreModelI;
+import jalview.api.analysis.SimilarityParamsI;
+import jalview.datamodel.SequenceNode;
+import jalview.viewmodel.AlignmentViewport;
+
+/**
+ * This class implements distance calculations used in constructing a Average
+ * Distance tree (also known as UPGMA)
+ */
+public class AverageDistanceTree extends TreeBuilder
+{
+  /**
+   * Constructor
+   * 
+   * @param av
+   * @param sm
+   * @param scoreParameters
+   */
+  public AverageDistanceTree(AlignmentViewport av, ScoreModelI sm,
+          SimilarityParamsI scoreParameters)
+  {
+    super(av, sm, scoreParameters);
+  }
+
+  /**
+   * Calculates and saves the distance between the combination of cluster(i) and
+   * cluster(j) and all other clusters. An average of the distances from
+   * cluster(i) and cluster(j) is calculated, weighted by the sizes of each
+   * cluster.
+   * 
+   * @param i
+   * @param j
+   */
+  @Override
+  protected void findClusterDistance(int i, int j)
+  {
+    int noi = clusters.elementAt(i).cardinality();
+    int noj = clusters.elementAt(j).cardinality();
+
+    // New distances from cluster i to others
+    double[] newdist = new double[noseqs];
+
+    for (int l = 0; l < noseqs; l++)
+    {
+      if ((l != i) && (l != j))
+      {
+        newdist[l] = ((distances.getValue(i, l) * noi) + (distances
+                .getValue(j, l) * noj)) / (noi + noj);
+      }
+      else
+      {
+        newdist[l] = 0;
+      }
+    }
+
+    for (int ii = 0; ii < noseqs; ii++)
+    {
+      distances.setValue(i, ii, newdist[ii]);
+      distances.setValue(ii, i, newdist[ii]);
+    }
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  protected double findMinDistance()
+  {
+    double min = Double.MAX_VALUE;
+
+    for (int i = 0; i < (noseqs - 1); i++)
+    {
+      for (int j = i + 1; j < noseqs; j++)
+      {
+        if (!done.get(i) && !done.get(j))
+        {
+          if (distances.getValue(i, j) < min)
+          {
+            mini = i;
+            minj = j;
+
+            min = distances.getValue(i, j);
+          }
+        }
+      }
+    }
+    return min;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  protected void findNewDistances(SequenceNode nodei, SequenceNode nodej,
+          double dist)
+  {
+    double ih = 0;
+    double jh = 0;
+
+    SequenceNode sni = nodei;
+    SequenceNode snj = nodej;
+
+    while (sni != null)
+    {
+      ih = ih + sni.dist;
+      sni = (SequenceNode) sni.left();
+    }
+
+    while (snj != null)
+    {
+      jh = jh + snj.dist;
+      snj = (SequenceNode) snj.left();
+    }
+
+    nodei.dist = ((dist / 2) - ih);
+    nodej.dist = ((dist / 2) - jh);
+  }
+
+}