JAL-2599 HMMs can now be dropped onto the desktop
[jalview.git] / src / jalview / datamodel / HiddenMarkovModel.java
index e0f13d8..e74d826 100644 (file)
@@ -1,7 +1,6 @@
 package jalview.datamodel;
 
 import jalview.gui.AlignFrame;
-import jalview.schemes.ResidueProperties;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -34,7 +33,8 @@ public class HiddenMarkovModel
   // 0. Node 0 contains average emission probabilities for each symbol
   List<HMMNode> nodes = new ArrayList<>();
 
-  // contains the HMM node for each alignment column
+  // contains the HMM node for each alignment column, alignment columns start at
+  // index 0;
   Map<Integer, Integer> nodeLookup = new HashMap<>();
   
   // contains the symbol index for each symbol
@@ -125,13 +125,14 @@ public class HiddenMarkovModel
   public HiddenMarkovModel(HiddenMarkovModel hmm)
   {
     super();
-    this.fileProperties = hmm.fileProperties;
-    this.symbols = hmm.symbols;
-    this.nodes = hmm.nodes;
-    this.nodeLookup = hmm.nodeLookup;
-    this.symbolIndexLookup = hmm.symbolIndexLookup;
+    this.fileProperties = new HashMap<>(hmm.fileProperties);
+    this.symbols = new ArrayList<>(hmm.symbols);
+    this.nodes = new ArrayList<>(hmm.nodes);
+    this.nodeLookup = new HashMap<>(hmm.nodeLookup);
+    this.symbolIndexLookup = new HashMap<>(
+            hmm.symbolIndexLookup);
     this.numberOfSymbols = hmm.numberOfSymbols;
-    this.fileHeader = hmm.fileHeader;
+    this.fileHeader = new String(hmm.fileHeader);
   }
 
   /**
@@ -391,9 +392,9 @@ public class HiddenMarkovModel
       return 0d;
     }
     symbolIndex = symbolIndexLookup.get(symbol);
-    if (nodeLookup.containsKey(alignColumn + 1))
+    if (nodeLookup.containsKey(alignColumn))
     {
-      nodeIndex = nodeLookup.get(alignColumn + 1);
+      nodeIndex = nodeLookup.get(alignColumn);
       probability = getNode(nodeIndex).getMatchEmissions().get(symbolIndex);
       return probability;
     }
@@ -426,9 +427,9 @@ public class HiddenMarkovModel
       return 0d;
     }
     symbolIndex = symbolIndexLookup.get(symbol);
-    if (nodeLookup.containsKey(alignColumn + 1))
+    if (nodeLookup.containsKey(alignColumn))
     {
-      nodeIndex = nodeLookup.get(alignColumn + 1);
+      nodeIndex = nodeLookup.get(alignColumn);
       probability = getNode(nodeIndex).getInsertEmissions()
               .get(symbolIndex);
       return probability;
@@ -458,9 +459,9 @@ public class HiddenMarkovModel
     int transitionIndex;
     int nodeIndex;
     Double probability;
-    if (nodeLookup.containsKey(alignColumn + 1))
+    if (nodeLookup.containsKey(alignColumn))
     {
-      nodeIndex = nodeLookup.get(alignColumn + 1);
+      nodeIndex = nodeLookup.get(alignColumn);
       probability = getNode(nodeIndex).getStateTransitions()
               .get(transition);
       return probability;
@@ -483,7 +484,7 @@ public class HiddenMarkovModel
   public Integer getNodeAlignmentColumn(int nodeIndex)
   {
     Integer value = nodes.get(nodeIndex).getAlignmentColumn();
-    return value - 1;
+    return value;
   }
   
   /**
@@ -951,7 +952,7 @@ public class HiddenMarkovModel
   public Integer findNodeIndex(int alignmentColumn)
   {
     Integer index;
-    index = nodeLookup.get(alignmentColumn + 1);
+    index = nodeLookup.get(alignmentColumn);
     return index;
   }
 
@@ -973,81 +974,7 @@ public class HiddenMarkovModel
     }
   }
 
-  /**
-   * Creates the HMM Logo alignment annotation, and populates it with
-   * information content data.
-   * 
-   * @return The alignment annotation.
-   */
-  public AlignmentAnnotation createAnnotation(int length)
-  {
-    Annotation[] annotations = new Annotation[length];
-    float max = 0f;
-    for (int alignPos = 0; alignPos < length; alignPos++)
-    {
-      Float content = getInformationContent(alignPos);
-      if (content > max)
-      {
-        max = content;
-      }
-
-      Character cons;
-
-      cons = getConsensusAtAlignColumn(alignPos);
-
-      cons = Character.toUpperCase(cons);
-
-      String description = String.format("%.3f", content);
-      description += " bits";
-      annotations[alignPos] = new Annotation(cons.toString(), description,
-              ' ',
-              content);
-
-    }
-    AlignmentAnnotation annotation = new AlignmentAnnotation(
-            "Information",
-            "The information content of each column, measured in bits",
-            annotations,
-            0f, max, AlignmentAnnotation.BAR_GRAPH);
-    annotation.setHMM(this);
-    return annotation;
-  }
-
-  /**
-   * Returns the information content at a specified column.
-   * 
-   * @param column
-   *          Index of the column, starting from 0.
-   * @return
-   */
-  public float getInformationContent(int column)
-  {
-    float informationContent = 0f;
-
-    for (char symbol : symbols)
-    {
-      float freq = 0f;
-      if ("amino".equals(getAlphabetType()))
-      {
-        freq = ResidueProperties.aminoBackgroundFrequencies.get(symbol);
-      }
-      if ("DNA".equals(getAlphabetType()))
-      {
-        freq = ResidueProperties.dnaBackgroundFrequencies.get(symbol);
-      }
-      if ("RNA".equals(getAlphabetType()))
-      {
-        freq = ResidueProperties.rnaBackgroundFrequencies
-                .get(symbol);
-      }
-      Double hmmProb = getMatchEmissionProbability(column, symbol);
-      float prob = hmmProb.floatValue();
-      informationContent += prob * (Math.log(prob / freq) / Math.log(2));
 
-    }
-
-    return informationContent;
-  }
 
   /**
    * Returns the consensus sequence based on the most probable symbol at each
@@ -1058,7 +985,7 @@ public class HiddenMarkovModel
    *          The length of the longest sequence in the existing alignment.
    * @return
    */
-  public Sequence getConsensusSequence(int length)
+  public Sequence getConsensusSequence()
   {
     int start;
     int end;
@@ -1066,8 +993,8 @@ public class HiddenMarkovModel
     start = getNodeAlignmentColumn(1);
     modelLength = getLength();
     end = getNodeAlignmentColumn(modelLength);
-    char[] sequence = new char[length];
-    for (int index = 0; index < length; index++)
+    char[] sequence = new char[end];
+    for (int index = 0; index < end; index++)
     {
       Character character;
 
@@ -1084,7 +1011,7 @@ public class HiddenMarkovModel
       }
 
 
-    Sequence seq = new Sequence("HMM CONSENSUS", sequence, start, end);
+    Sequence seq = new Sequence(getName() + "_HMM", sequence, start, end);
     return seq;
   }
 
@@ -1093,7 +1020,7 @@ public class HiddenMarkovModel
    * Maps the nodes of the hidden Markov model to the reference annotation and
    * then deletes this annotation.
    */
-  public void mapToReferenceAnnotation(AlignFrame af)
+  public void mapToReferenceAnnotation(AlignFrame af, SequenceI seq)
   {
     AlignmentAnnotation annotArray[] = af.getViewport().getAlignment()
             .getAlignmentAnnotation();
@@ -1112,6 +1039,58 @@ public class HiddenMarkovModel
       return;
     }
 
+    mapToReferenceAnnotation(reference, seq);
+    af.getViewport().getAlignment().deleteAnnotation(reference);
+  }
+
+  public void mapToReferenceAnnotation(AlignmentAnnotation reference,
+          SequenceI seq)
+  {
+    HiddenMarkovModel hmm = seq.getHMM();
+    Annotation[] annots = reference.annotations;
+    {
+      int nodeIndex = 0;
+      for (int col = 0; col < annots.length; col++)
+      {
+        String character = annots[col].displayCharacter;
+        if ("x".equals(character) || "X".equals(character))
+        {
+          nodeIndex++;
+          if (nodeIndex < hmm.getNodes().size())
+          {
+            HMMNode node = hmm.getNode(nodeIndex);
+            int alignPos = getNodeAlignmentColumn(nodeIndex);
+            char seqCharacter = seq.getCharAt(alignPos);
+            if (alignPos >= seq.getLength() || col >= seq.getLength())
+            {
+              seq.insertCharAt(seq.getLength(),
+                      (alignPos + 1) - seq.getLength(),
+                      '-');
+            }
+            seq.getSequence()[alignPos] = '-';
+            seq.getSequence()[col] = seqCharacter;
+            node.setAlignmentColumn(col);
+            hmm.nodeLookup.put(col, nodeIndex);
+          }
+          else
+          {
+            System.out.println(
+                    "The reference annotation contains more consensus columns than the hidden Markov model");
+            break;
+          }
+        }
+        else
+        {
+          hmm.nodeLookup.remove(col);
+        }
+      }
+
+    }
+
+  }
+
+  public void mapToReferenceAnnotation(AlignmentAnnotation reference)
+  {
     Annotation[] annots = reference.annotations;
     {
       int nodeIndex = 0;
@@ -1123,8 +1102,9 @@ public class HiddenMarkovModel
           nodeIndex++;
           if (nodeIndex < nodes.size())
           {
-            nodes.get(nodeIndex).setAlignmentColumn(col + 1);
-            nodeLookup.put(col + 1, nodeIndex);
+            HMMNode node = nodes.get(nodeIndex);
+            node.setAlignmentColumn(col + 1);
+            nodeLookup.put(col, nodeIndex);
           }
           else
           {
@@ -1135,25 +1115,22 @@ public class HiddenMarkovModel
         }
         else
         {
-          nodeLookup.remove(col + 1);
+          nodeLookup.remove(col);
         }
       }
 
     }
-    af.getViewport().getAlignment().deleteAnnotation(reference);
+
   }
 
-  public void initPlaceholder(AlignFrame af)
+  public SequenceI initHMMSequence()
   {
-    AlignmentI alignment = af.getViewport().getAlignment();
-    int length = alignment.getWidth();
-    Sequence consensus = getConsensusSequence(length);
+    Sequence consensus = getConsensusSequence();
+    consensus.setIsHMMConsensusSequence(true);
     consensus.setHMM(this);
-    SequenceI[] consensusArr = new Sequence[] { consensus };
-    AlignmentI newAlignment = new Alignment(consensusArr);
-    newAlignment.append(alignment);
-    af.getViewport().setAlignment(newAlignment);
+    return consensus;
   }
 
+
 }