JAL-3291 issue fixed, passing tests added
[jalview.git] / src / jalview / datamodel / HiddenMarkovModel.java
index 581f481..e917474 100644 (file)
@@ -3,6 +3,7 @@ package jalview.datamodel;
 import jalview.io.HMMFile;
 import jalview.schemes.ResidueProperties;
 import jalview.util.Comparison;
+import jalview.util.MapList;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -18,6 +19,8 @@ import java.util.Map;
  */
 public class HiddenMarkovModel
 {
+  private static final char GAP_DASH = '-';
+
   public final static String YES = "yes";
 
   public final static String NO = "no";
@@ -41,30 +44,35 @@ public class HiddenMarkovModel
   /*
    * properties read from HMM file header lines
    */
-  Map<String, String> fileProperties = new HashMap<>();
+  private Map<String, String> fileProperties = new HashMap<>();
 
-  String fileHeader;
+  private String fileHeader;
   
   /*
    * the symbols used in this model e.g. "ACGT"
    */
-  String alphabet;
+  private String alphabet;
 
   /*
    * symbol lookup index into the alphabet for 'A' to 'Z'
    */
-  int[] symbolIndexLookup = new int['Z' - 'A' + 1];
+  private int[] symbolIndexLookup = new int['Z' - 'A' + 1];
 
   /*
    * Nodes in the model. The begin node is at index 0, and contains 
    * average emission probabilities for each symbol.
    */
-  List<HMMNode> nodes = new ArrayList<>();
+  private List<HMMNode> nodes = new ArrayList<>();
 
   /*
-   * lookup of the HMM node for each alignment column (from 0)
+   * the aligned HMM consensus sequence extracted from the HMM profile
    */
-  Map<Integer, HMMNode> nodeLookup = new HashMap<>();
+  private SequenceI hmmSeq;
+
+  /*
+   * mapping from HMM nodes to residues of the hmm consensus sequence
+   */
+  private Mapping mapToHmmConsensus;
 
   /**
    * Constructor
@@ -73,15 +81,33 @@ public class HiddenMarkovModel
   {
   }
 
-  public HiddenMarkovModel(HiddenMarkovModel hmm)
+  /**
+   * Copy constructor given a new aligned sequence with which to associate the
+   * HMM profile
+   * 
+   * @param hmm
+   * @param sq
+   */
+  public HiddenMarkovModel(HiddenMarkovModel hmm, SequenceI sq)
   {
     super();
     this.fileProperties = new HashMap<>(hmm.fileProperties);
     this.alphabet = hmm.alphabet;
     this.nodes = new ArrayList<>(hmm.nodes);
-    this.nodeLookup = new HashMap<>(hmm.nodeLookup);
     this.symbolIndexLookup = hmm.symbolIndexLookup;
     this.fileHeader = new String(hmm.fileHeader);
+    this.hmmSeq = sq;
+    if (sq.getDatasetSequence() == hmm.mapToHmmConsensus.getTo())
+    {
+      // same dataset sequence e.g. after realigning search results
+      this.mapToHmmConsensus = hmm.mapToHmmConsensus;
+    }
+    else
+    {
+      // different dataset sequence e.g. after loading HMM from project
+      this.mapToHmmConsensus = new Mapping(sq.getDatasetSequence(),
+              hmm.mapToHmmConsensus.getMap());
+    }
   }
 
   /**
@@ -195,7 +221,9 @@ public class HiddenMarkovModel
   }
 
   /**
-   * Returns the length of the hidden Markov model, or 0 if not known
+   * Returns the length of the hidden Markov model. The value returned is the
+   * LENG property if specified, else the number of nodes, excluding the begin
+   * node (which should be the same thing).
    * 
    * @return
    */
@@ -203,7 +231,7 @@ public class HiddenMarkovModel
   {
     if (fileProperties.get(HMMFile.LENGTH) == null)
     {
-      return 0;
+      return nodes.size() - 1; // not counting BEGIN node
     }
     return Integer.parseInt(fileProperties.get(HMMFile.LENGTH));
   }
@@ -258,17 +286,36 @@ public class HiddenMarkovModel
   }
 
   /**
-   * Sets the list of nodes in this HMM to the given list.
+   * Answers the node of the model corresponding to an aligned column position
+   * (0...), or null if there is no such node
    * 
-   * @param nodes
-   *          The list of nodes to which the current list of nodes is being
-   *          changed.
+   * @param column
+   * @return
    */
-  public void setNodes(List<HMMNode> nodes)
+  HMMNode getNodeForColumn(int column)
   {
-    this.nodes = nodes;
+    /*
+     * if the hmm consensus is gapped at the column,
+     * there is no corresponding node
+     */
+    if (Comparison.isGap(hmmSeq.getCharAt(column)))
+    {
+      return null;
+    }
+
+    /*
+     * find the node (if any) that is mapped to the
+     * consensus sequence residue position at the column
+     */
+    int seqPos = hmmSeq.findPosition(column);
+    int[] nodeNo = mapToHmmConsensus.getMap().locateInFrom(seqPos, seqPos);
+    if (nodeNo != null)
+    {
+      return getNode(nodeNo[0]);
+    }
+    return null;
   }
-  
+
   /**
    * Gets the match emission probability for a given symbol at a column in the
    * alignment.
@@ -283,14 +330,13 @@ public class HiddenMarkovModel
    */
   public double getMatchEmissionProbability(int alignColumn, char symbol)
   {
+    HMMNode node = getNodeForColumn(alignColumn);
     int symbolIndex = getSymbolIndex(symbol);
-    double probability = 0d;
-    if (symbolIndex != -1 && nodeLookup.containsKey(alignColumn))
+    if (node != null && symbolIndex != -1)
     {
-      HMMNode node = nodeLookup.get(alignColumn);
-      probability = node.getMatchEmission(symbolIndex);
+      return node.getMatchEmission(symbolIndex);
     }
-    return probability;
+    return 0D;
   }
 
   /**
@@ -307,14 +353,13 @@ public class HiddenMarkovModel
    */
   public double getInsertEmissionProbability(int alignColumn, char symbol)
   {
+    HMMNode node = getNodeForColumn(alignColumn);
     int symbolIndex = getSymbolIndex(symbol);
-    double probability = 0d;
-    if (symbolIndex != -1 && nodeLookup.containsKey(alignColumn))
+    if (node != null && symbolIndex != -1)
     {
-      HMMNode node = nodeLookup.get(alignColumn);
-      probability = node.getInsertEmission(symbolIndex);
+      return node.getInsertEmission(symbolIndex);
     }
-    return probability;
+    return 0D;
   }
   
   /**
@@ -329,30 +374,29 @@ public class HiddenMarkovModel
    * @return
    * 
    */
-  public Double getStateTransitionProbability(int alignColumn,
+  public double getStateTransitionProbability(int alignColumn,
           int transition)
   {
-    double probability = 0d;
-    if (nodeLookup.containsKey(alignColumn))
+    HMMNode node = getNodeForColumn(alignColumn);
+    if (node != null)
     {
-      HMMNode node = nodeLookup.get(alignColumn);
-      probability = node.getStateTransition(transition);
+      return node.getStateTransition(transition);
     }
-    return probability;
+    return 0D;
   }
   
   /**
-   * Returns the alignment column linked to the node at the given index.
+   * Returns the sequence position linked to the node at the given index. This
+   * corresponds to an aligned column position (counting from 1).
    * 
    * @param nodeIndex
    *          The index of the node, starting from index 1. Index 0 is the begin
    *          node, which does not correspond to a column in the alignment.
    * @return
    */
-  public Integer getNodeAlignmentColumn(int nodeIndex)
+  public int getNodeMapPosition(int nodeIndex)
   {
-    Integer value = nodes.get(nodeIndex).getAlignmentColumn();
-    return value;
+    return nodes.get(nodeIndex).getResidueNumber();
   }
   
   /**
@@ -369,49 +413,6 @@ public class HiddenMarkovModel
   }
   
   /**
-   * Returns the consensus at a given alignment column. If the character is
-   * lower case, its emission probability is less than 0.5.
-   * 
-   * @param columnIndex
-   *          The index of the column in the alignment for which the consensus
-   *          is desired. The list of columns starts at index 0.
-   * @return
-   */
-  public char getConsensusAtAlignColumn(int columnIndex)
-  {
-    char mostLikely = '-';
-    if (getBooleanProperty(HMMFile.CONSENSUS_RESIDUE))
-    {
-      HMMNode node = nodeLookup.get(columnIndex);
-      if (node == null)
-      {
-        return '-';
-      }
-      mostLikely = node.getConsensusResidue();
-      return mostLikely;
-    }
-    else
-    {
-      double highestProb = 0;
-      for (char character : alphabet.toCharArray())
-      {
-        double prob = getMatchEmissionProbability(columnIndex, character);
-        if (prob > highestProb)
-        {
-          highestProb = prob;
-          mostLikely = character;
-        }
-      }
-      if (highestProb < 0.5)
-      {
-        mostLikely = Character.toLowerCase(mostLikely);
-      }
-      return mostLikely;
-    }
-
-  }
-
-  /**
    * Returns the reference annotation at the specified node.
    * 
    * @param nodeIndex
@@ -451,16 +452,6 @@ public class HiddenMarkovModel
   }
   
   /**
-   * Returns the number of symbols in the alphabet used in this HMM.
-   * 
-   * @return
-   */
-  public int getNumberOfSymbols()
-  {
-    return alphabet.length();
-  }
-
-  /**
    * Sets a property read from an HMM file
    * 
    * @param key
@@ -472,97 +463,6 @@ public class HiddenMarkovModel
   }
 
   /**
-   * Sets the alignment column of the specified node
-   * 
-   * @param nodeIndex
-   * 
-   * @param column
-   * 
-   */
-  public void setAlignmentColumn(HMMNode node, int column)
-  {
-    node.setAlignmentColumn(column);
-    nodeLookup.put(column, node);
-  }
-
-  public void updateMapping(char[] sequence)
-  {
-    int nodeNo = 1;
-    int column = 0;
-    synchronized (nodeLookup)
-    {
-      clearNodeLookup();
-      for (char residue : sequence)
-      {
-        if (!Comparison.isGap(residue))
-        {
-          HMMNode node = nodes.get(nodeNo);
-          if (node == null)
-          {
-            // error : too few nodes for sequence
-            break;
-          }
-          setAlignmentColumn(node, column);
-          nodeNo++;
-        }
-        column++;
-      }
-    }
-  }
-
-  /**
-   * Clears all data in the node lookup map
-   */
-  public void clearNodeLookup()
-  {
-    nodeLookup.clear();
-  }
-
-  /**
-   * Sets the reference annotation at a given node
-   * 
-   * @param nodeIndex
-   * @param value
-   */
-  public void setReferenceAnnotation(int nodeIndex, char value)
-  {
-    nodes.get(nodeIndex).setReferenceAnnotation(value);
-  }
-
-  /**
-   * Sets the consensus residue at a given node
-   * 
-   * @param nodeIndex
-   * @param value
-   */
-  public void setConsensusResidue(int nodeIndex, char value)
-  {
-    nodes.get(nodeIndex).setConsensusResidue(value);
-  }
-
-  /**
-   * Sets the consensus structure at a given node
-   * 
-   * @param nodeIndex
-   * @param value
-   */
-  public void setConsensusStructure(int nodeIndex, char value)
-  {
-    nodes.get(nodeIndex).setConsensusStructure(value);
-  }
-
-  /**
-   * Sets the mask value at a given node
-   * 
-   * @param nodeIndex
-   * @param value
-   */
-  public void setMaskValue(int nodeIndex, char value)
-  {
-    nodes.get(nodeIndex).setMaskValue(value);
-  }
-
-  /**
    * Temporary implementation, should not be used.
    * 
    * @return
@@ -599,66 +499,100 @@ public class HiddenMarkovModel
   }
 
   /**
-   * Answers the HMMNode mapped to the given alignment column (base 0), or null
-   * if none is mapped
-   * 
-   * @param alignmentColumn
-   */
-  public HMMNode getNodeForColumn(int alignmentColumn)
-  {
-    return nodeLookup.get(alignmentColumn);
-  }
-
-  /**
-   * Returns the consensus sequence based on the most probable symbol at each
-   * position. The sequence is adjusted to match the length of the existing
-   * sequence alignment. Gap characters are used as padding.
+   * Constructs the consensus sequence based on the most probable symbol at each
+   * position. Gap characters are inserted for discontinuities in the node map
+   * numbering (if provided), else an ungapped sequence is generated.
+   * <p>
+   * A mapping between the HMM nodes and residue positions of the sequence is
+   * also built and saved.
    * 
    * @return
    */
-  public Sequence getConsensusSequence()
+  void buildConsensusSequence()
   {
-    int start;
-    int end;
-    int modelLength;
-    start = getNodeAlignmentColumn(1);
-    modelLength = getLength();
-    end = getNodeAlignmentColumn(modelLength);
-    char[] sequence = new char[end + 1];
-    for (int index = 0; index < end + 1; index++)
+    List<int[]> toResidues = new ArrayList<>();
+
+    /*
+     * if the HMM provided a map to sequence, use those start/end values,
+     * else just treat it as for a contiguous sequence numbered from 1
+     */
+    boolean hasMap = getBooleanProperty(HMMFile.MAP);
+    int start = hasMap ? getNode(1).getResidueNumber() : 1;
+    int endResNo = hasMap ? getNode(nodes.size() - 1).getResidueNumber()
+            : (start + getLength() - 1);
+    char[] sequence = new char[endResNo + 1];
+
+    int lastResNo = start - 1;
+    int seqOffset = -1;
+    int gapCount = 0;
+
+    for (int seqN = 0; seqN < start; seqN++)
     {
-      Character character;
+      sequence[seqN] = GAP_DASH;
+      seqOffset++;
+    }
 
-        character = getConsensusAtAlignColumn(index);
+    for (int nodeNo = 1; nodeNo < nodes.size(); nodeNo++)
+    {
+      HMMNode node = nodes.get(nodeNo);
+      final int resNo = hasMap ? node.getResidueNumber() : lastResNo + 1;
 
-      if (character == null || character == '-')
+      /*
+       * insert gaps if map numbering is not continuous
+       */
+      while (resNo > lastResNo + 1)
       {
-        sequence[index] = '-';
+        sequence[seqOffset++] = GAP_DASH;
+        lastResNo++;
+        gapCount++;
       }
-      else
+      char consensusResidue = node.getConsensusResidue();
+      if (GAP_DASH == consensusResidue)
       {
-        sequence[index] = Character.toUpperCase(character);
-      }
+        /*
+         * no residue annotation in HMM - scan for the symbol
+         * with the highest match emission probability
+         */
+        int symbolIndex = node.getMaxMatchEmissionIndex();
+        consensusResidue = alphabet.charAt(symbolIndex);
+        if (node.getMatchEmission(symbolIndex) < 0.5D)
+        {
+          // follow convention of lower case if match emission prob < 0.5
+          consensusResidue = Character.toLowerCase(consensusResidue);
+        }
       }
-
+      sequence[seqOffset++] = consensusResidue;
+      lastResNo = resNo;
+    }
 
     Sequence seq = new Sequence(getName(), sequence, start,
-            end);
-    return seq;
+            lastResNo - gapCount);
+    seq.createDatasetSequence();
+    seq.setHMM(this);
+    this.hmmSeq = seq;
+
+    /*
+     * construct and store Mapping of nodes to residues
+     * note as constructed this is just an identity mapping, 
+     * but it allows for greater flexibility in future
+     */
+    List<int[]> fromNodes = new ArrayList<>();
+    fromNodes.add(new int[] { 1, getLength() });
+    toResidues.add(new int[] { seq.getStart(), seq.getEnd() });
+    MapList mapList = new MapList(fromNodes, toResidues, 1, 1);
+    mapToHmmConsensus = new Mapping(seq.getDatasetSequence(), mapList);
   }
 
 
   /**
-   * Initiates a HMM consensus sequence
+   * Answers the aligned consensus sequence for the profile. Note this will
+   * return null if called before <code>setNodes</code> has been called.
    * 
-   * @return A new HMM consensus sequence
+   * @return
    */
-  public SequenceI initHMMSequence()
+  public SequenceI getConsensusSequence()
   {
-    Sequence consensus = getConsensusSequence();
-    consensus.setIsHMMConsensusSequence(true);
-    consensus.setHMM(this);
-    return consensus;
+    return hmmSeq;
   }
 
   /**
@@ -668,7 +602,7 @@ public class HiddenMarkovModel
    * @param symbol
    * @return
    */
-  public int getSymbolIndex(char symbol)
+  private int getSymbolIndex(char symbol)
   {
     /*
      * symbolIndexLookup holds the index for 'A' to 'Z'
@@ -681,9 +615,29 @@ public class HiddenMarkovModel
     return -1;
   }
 
-  public void addNode(HMMNode node)
+  /**
+   * Sets the nodes of this HMM, and also extracts the HMM consensus sequence
+   * and a mapping between node numbers and sequence positions
+   * 
+   * @param nodeList
+   */
+  public void setNodes(List<HMMNode> nodeList)
+  {
+    nodes = nodeList;
+    if (nodes.size() > 1)
+    {
+      buildConsensusSequence();
+    }
+  }
+
+  /**
+   * Sets the aligned consensus sequence this HMM is the model for
+   * 
+   * @param hmmSeq
+   */
+  public void setHmmSeq(SequenceI hmmSeq)
   {
-    nodes.add(node);
+    this.hmmSeq = hmmSeq;
   }
 }