Merge branch 'alpha/JAL-3362_Jalview_212_alpha' into alpha/merge_212_JalviewJS_2112
[jalview.git] / src / jalview / datamodel / HiddenMarkovModel.java
diff --git a/src/jalview/datamodel/HiddenMarkovModel.java b/src/jalview/datamodel/HiddenMarkovModel.java
new file mode 100644 (file)
index 0000000..6b8b095
--- /dev/null
@@ -0,0 +1,672 @@
+package jalview.datamodel;
+
+import jalview.io.HMMFile;
+import jalview.schemes.ResidueProperties;
+import jalview.util.Comparison;
+import jalview.util.MapList;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Data structure which stores a hidden Markov model
+ * 
+ * @author TZVanaalten
+ * 
+ */
+public class HiddenMarkovModel
+{
+  private static final char GAP_DASH = '-';
+
+  public final static String YES = "yes";
+
+  public final static String NO = "no";
+
+  public static final int MATCHTOMATCH = 0;
+
+  public static final int MATCHTOINSERT = 1;
+
+  public static final int MATCHTODELETE = 2;
+
+  public static final int INSERTTOMATCH = 3;
+
+  public static final int INSERTTOINSERT = 4;
+
+  public static final int DELETETOMATCH = 5;
+
+  public static final int DELETETODELETE = 6;
+
+  private static final double LOG2 = Math.log(2);
+
+  /*
+   * properties read from HMM file header lines
+   */
+  private Map<String, String> fileProperties = new HashMap<>();
+
+  private String fileHeader;
+  
+  /*
+   * the symbols used in this model e.g. "ACGT"
+   */
+  private String alphabet;
+
+  /*
+   * symbol lookup index into the alphabet for 'A' to 'Z'
+   */
+  private int[] symbolIndexLookup = new int['Z' - 'A' + 1];
+
+  /*
+   * Nodes in the model. The begin node is at index 0, and contains 
+   * average emission probabilities for each symbol.
+   */
+  private List<HMMNode> nodes = new ArrayList<>();
+
+  /*
+   * the aligned HMM consensus sequence extracted from the HMM profile
+   */
+  private SequenceI hmmSeq;
+
+  /*
+   * mapping from HMM nodes to residues of the hmm consensus sequence
+   */
+  private Mapping mapToHmmConsensus;
+
+  // stores background frequencies of alignment from which this model came
+  private Map<Character, Float> backgroundFrequencies;
+
+  /**
+   * Constructor
+   */
+  public HiddenMarkovModel()
+  {
+  }
+
+  /**
+   * Copy constructor given a new aligned sequence with which to associate the
+   * HMM profile
+   * 
+   * @param hmm
+   * @param sq
+   */
+  public HiddenMarkovModel(HiddenMarkovModel hmm, SequenceI sq)
+  {
+    super();
+    this.fileProperties = new HashMap<>(hmm.fileProperties);
+    this.alphabet = hmm.alphabet;
+    this.nodes = new ArrayList<>(hmm.nodes);
+    this.symbolIndexLookup = hmm.symbolIndexLookup;
+    this.fileHeader = new String(hmm.fileHeader);
+    this.hmmSeq = sq;
+    this.backgroundFrequencies = hmm.getBackgroundFrequencies();
+    if (sq.getDatasetSequence() == hmm.mapToHmmConsensus.getTo())
+    {
+      // same dataset sequence e.g. after realigning search results
+      this.mapToHmmConsensus = hmm.mapToHmmConsensus;
+    }
+    else
+    {
+      // different dataset sequence e.g. after loading HMM from project
+      this.mapToHmmConsensus = new Mapping(sq.getDatasetSequence(),
+              hmm.mapToHmmConsensus.getMap());
+    }
+  }
+
+  /**
+   * Returns the information content at a specified column, calculated as the
+   * sum (over possible symbols) of the log ratio
+   * 
+   * <pre>
+   *  log(emission probability / background probability) / log(2)
+   * </pre>
+   * 
+   * @param column
+   *          column position (base 0)
+   * @return
+   */
+  public float getInformationContent(int column)
+  {
+    float informationContent = 0f;
+
+    for (char symbol : getSymbols().toCharArray())
+    {
+      float freq = ResidueProperties.backgroundFrequencies
+              .get(getAlphabetType()).get(symbol);
+      float prob = (float) getMatchEmissionProbability(column, symbol);
+      informationContent += prob * Math.log(prob / freq);
+    }
+
+    informationContent = informationContent / (float) LOG2;
+
+    return informationContent;
+  }
+
+  /**
+   * Gets the file header of the .hmm file this model came from
+   * 
+   * @return
+   */
+  public String getFileHeader()
+  {
+    return fileHeader;
+  }
+
+  /**
+   * Sets the file header of this model.
+   * 
+   * @param header
+   */
+  public void setFileHeader(String header)
+  {
+    fileHeader = header;
+  }
+
+  /**
+   * Returns the symbols used in this hidden Markov model
+   * 
+   * @return
+   */
+  public String getSymbols()
+  {
+    return alphabet;
+  }
+  
+  /**
+   * Gets the node in the hidden Markov model at the specified position.
+   * 
+   * @param nodeIndex
+   *          The index of the node requested. Node 0 optionally contains the
+   *          average match emission probabilities across the entire model, and
+   *          always contains the insert emission probabilities and state
+   *          transition probabilities for the begin node. Node 1 contains the
+   *          first node in the HMM that can correspond to a column in the
+   *          alignment.
+   * @return
+   */
+  public HMMNode getNode(int nodeIndex)
+  {
+    return nodes.get(nodeIndex);
+  }
+
+  /**
+   * Returns the name of the sequence alignment on which the HMM is based.
+   * 
+   * @return
+   */
+  public String getName()
+  {
+    return fileProperties.get(HMMFile.NAME);
+  }
+  
+  /**
+   * Answers the string value of the property (parsed from an HMM file) for the
+   * given key, or null if the property is not present
+   * 
+   * @param key
+   * @return
+   */
+  public String getProperty(String key)
+  {
+    return fileProperties.get(key);
+  }
+
+  /**
+   * Answers true if the property with the given key is present with a value of
+   * "yes" (not case-sensitive), else false
+   * 
+   * @param key
+   * @return
+   */
+  public boolean getBooleanProperty(String key)
+  {
+    return YES.equalsIgnoreCase(fileProperties.get(key));
+  }
+
+  /**
+   * Returns the length of the hidden Markov model. The value returned is the
+   * LENG property if specified, else the number of nodes, excluding the begin
+   * node (which should be the same thing).
+   * 
+   * @return
+   */
+  public int getLength()
+  {
+    if (fileProperties.get(HMMFile.LENGTH) == null)
+    {
+      return nodes.size() - 1; // not counting BEGIN node
+    }
+    return Integer.parseInt(fileProperties.get(HMMFile.LENGTH));
+  }
+
+  /**
+   * Returns the value of mandatory property "ALPH" - "amino", "DNA", "RNA" are
+   * the options. Other alphabets may be added.
+   * 
+   * @return
+   */
+  public String getAlphabetType()
+  {
+    return fileProperties.get(HMMFile.ALPHABET);
+  }
+
+  /**
+   * Sets the model alphabet to the symbols in the given string (ignoring any
+   * whitespace), and returns the number of symbols
+   * 
+   * @param symbols
+   */
+  public int setAlphabet(String symbols)
+  {
+    String trimmed = symbols.toUpperCase().replaceAll("\\s", "");
+    int count = trimmed.length();
+    alphabet = trimmed;
+    symbolIndexLookup = new int['Z' - 'A' + 1];
+    Arrays.fill(symbolIndexLookup, -1);
+    int ignored = 0;
+
+    /*
+     * save the symbols in order, and a quick lookup of symbol position
+     */
+    for (short i = 0; i < count; i++)
+    {
+      char symbol = trimmed.charAt(i);
+      if (symbol >= 'A' && symbol <= 'Z'
+              && symbolIndexLookup[symbol - 'A'] == -1)
+      {
+        symbolIndexLookup[symbol - 'A'] = i;
+      }
+      else
+      {
+        System.err
+                .println(
+                        "Unexpected or duplicated character in HMM ALPHabet: "
+                                + symbol);
+        ignored++;
+      }
+    }
+    return count - ignored;
+  }
+
+  /**
+   * Answers the node of the model corresponding to an aligned column position
+   * (0...), or null if there is no such node
+   * 
+   * @param column
+   * @return
+   */
+  HMMNode getNodeForColumn(int column)
+  {
+    /*
+     * if the hmm consensus is gapped at the column,
+     * there is no corresponding node
+     */
+    if (Comparison.isGap(hmmSeq.getCharAt(column)))
+    {
+      return null;
+    }
+
+    /*
+     * find the node (if any) that is mapped to the
+     * consensus sequence residue position at the column
+     */
+    int seqPos = hmmSeq.findPosition(column);
+    int[] nodeNo = mapToHmmConsensus.getMap().locateInFrom(seqPos, seqPos);
+    if (nodeNo != null)
+    {
+      return getNode(nodeNo[0]);
+    }
+    return null;
+  }
+
+  /**
+   * Gets the match emission probability for a given symbol at a column in the
+   * alignment.
+   * 
+   * @param alignColumn
+   *          The index of the alignment column, starting at index 0. Index 0
+   *          usually corresponds to index 1 in the HMM.
+   * @param symbol
+   *          The symbol for which the desired probability is being requested.
+   * @return
+   * 
+   */
+  public double getMatchEmissionProbability(int alignColumn, char symbol)
+  {
+    HMMNode node = getNodeForColumn(alignColumn);
+    int symbolIndex = getSymbolIndex(symbol);
+    if (node != null && symbolIndex != -1)
+    {
+      return node.getMatchEmission(symbolIndex);
+    }
+    return 0D;
+  }
+
+  /**
+   * Gets the insert emission probability for a given symbol at a column in the
+   * alignment.
+   * 
+   * @param alignColumn
+   *          The index of the alignment column, starting at index 0. Index 0
+   *          usually corresponds to index 1 in the HMM.
+   * @param symbol
+   *          The symbol for which the desired probability is being requested.
+   * @return
+   * 
+   */
+  public double getInsertEmissionProbability(int alignColumn, char symbol)
+  {
+    HMMNode node = getNodeForColumn(alignColumn);
+    int symbolIndex = getSymbolIndex(symbol);
+    if (node != null && symbolIndex != -1)
+    {
+      return node.getInsertEmission(symbolIndex);
+    }
+    return 0D;
+  }
+  
+  /**
+   * Gets the state transition probability for a given symbol at a column in the
+   * alignment.
+   * 
+   * @param alignColumn
+   *          The index of the alignment column, starting at index 0. Index 0
+   *          usually corresponds to index 1 in the HMM.
+   * @param symbol
+   *          The symbol for which the desired probability is being requested.
+   * @return
+   * 
+   */
+  public double getStateTransitionProbability(int alignColumn,
+          int transition)
+  {
+    HMMNode node = getNodeForColumn(alignColumn);
+    if (node != null)
+    {
+      return node.getStateTransition(transition);
+    }
+    return 0D;
+  }
+  
+  /**
+   * Returns the sequence position linked to the node at the given index. This
+   * corresponds to an aligned column position (counting from 1).
+   * 
+   * @param nodeIndex
+   *          The index of the node, starting from index 1. Index 0 is the begin
+   *          node, which does not correspond to a column in the alignment.
+   * @return
+   */
+  public int getNodeMapPosition(int nodeIndex)
+  {
+    return nodes.get(nodeIndex).getResidueNumber();
+  }
+  
+  /**
+   * Returns the consensus residue at the specified node.
+   * 
+   * @param nodeIndex
+   *          The index of the specified node.
+   * @return
+   */
+  public char getConsensusResidue(int nodeIndex)
+  {
+   char value = nodes.get(nodeIndex).getConsensusResidue();
+   return value;
+  }
+  
+  /**
+   * Returns the reference annotation at the specified node.
+   * 
+   * @param nodeIndex
+   *          The index of the specified node.
+   * @return
+   */
+  public char getReferenceAnnotation(int nodeIndex)
+  {
+   char value = nodes.get(nodeIndex).getReferenceAnnotation();
+   return value;
+  }
+  
+  /**
+   * Returns the mask value at the specified node.
+   * 
+   * @param nodeIndex
+   *          The index of the specified node.
+   * @return
+   */
+  public char getMaskedValue(int nodeIndex)
+  {
+   char value = nodes.get(nodeIndex).getMaskValue();
+   return value;
+  }
+  
+  /**
+   * Returns the consensus structure at the specified node.
+   * 
+   * @param nodeIndex
+   *          The index of the specified node.
+   * @return
+   */
+  public char getConsensusStructure(int nodeIndex)
+  {
+   char value = nodes.get(nodeIndex).getConsensusStructure();
+   return value;
+  }
+  
+  /**
+   * Sets a property read from an HMM file
+   * 
+   * @param key
+   * @param value
+   */
+  public void setProperty(String key, String value)
+  {
+    fileProperties.put(key, value);
+  }
+
+  /**
+   * Temporary implementation, should not be used.
+   * 
+   * @return
+   */
+  public String getViterbi()
+  {
+    String value;
+    value = fileProperties.get(HMMFile.VITERBI);
+    return value;
+  }
+
+  /**
+   * Temporary implementation, should not be used.
+   * 
+   * @return
+   */
+  public String getMSV()
+  {
+    String value;
+    value = fileProperties.get(HMMFile.MSV);
+    return value;
+  }
+
+  /**
+   * Temporary implementation, should not be used.
+   * 
+   * @return
+   */
+  public String getForward()
+  {
+    String value;
+    value = fileProperties.get(HMMFile.FORWARD);
+    return value;
+  }
+
+  /**
+   * Constructs the consensus sequence based on the most probable symbol at each
+   * position. Gap characters are inserted for discontinuities in the node map
+   * numbering (if provided), else an ungapped sequence is generated.
+   * <p>
+   * A mapping between the HMM nodes and residue positions of the sequence is
+   * also built and saved.
+   * 
+   * @return
+   */
+  void buildConsensusSequence()
+  {
+    List<int[]> toResidues = new ArrayList<>();
+
+    /*
+     * if the HMM provided a map to sequence, use those start/end values,
+     * else just treat it as for a contiguous sequence numbered from 1
+     */
+    boolean hasMap = getBooleanProperty(HMMFile.MAP);
+    int start = hasMap ? getNode(1).getResidueNumber() : 1;
+    int endResNo = hasMap ? getNode(nodes.size() - 1).getResidueNumber()
+            : (start + getLength() - 1);
+    char[] sequence = new char[endResNo];
+
+    int lastResNo = start - 1;
+    int seqOffset = -1;
+    int gapCount = 0;
+
+
+    for (int seqN = 0; seqN < start; seqN++)
+    {
+      sequence[seqN] = GAP_DASH;
+      seqOffset++;
+    }
+    
+    for (int nodeNo = 1; nodeNo < nodes.size(); nodeNo++)
+    {
+      HMMNode node = nodes.get(nodeNo);
+      final int resNo = hasMap ? node.getResidueNumber() : lastResNo + 1;
+
+      /*
+       * insert gaps if map numbering is not continuous
+       */
+      while (resNo > lastResNo + 1)
+      {
+        sequence[seqOffset++] = GAP_DASH;
+        lastResNo++;
+        gapCount++;
+      }
+      char consensusResidue = node.getConsensusResidue();
+      if (GAP_DASH == consensusResidue)
+      {
+        /*
+         * no residue annotation in HMM - scan for the symbol
+         * with the highest match emission probability
+         */
+        int symbolIndex = node.getMaxMatchEmissionIndex();
+        consensusResidue = alphabet.charAt(symbolIndex);
+        if (node.getMatchEmission(symbolIndex) < 0.5D)
+        {
+          // follow convention of lower case if match emission prob < 0.5
+          consensusResidue = Character.toLowerCase(consensusResidue);
+        }
+      }
+      sequence[seqOffset++] = consensusResidue;
+      lastResNo = resNo;
+    }
+
+    Sequence seq = new Sequence(getName(), sequence, start,
+            lastResNo - gapCount);
+    seq.createDatasetSequence();
+    seq.setHMM(this);
+    this.hmmSeq = seq;
+
+    /*
+     * construct and store Mapping of nodes to residues
+     * note as constructed this is just an identity mapping, 
+     * but it allows for greater flexibility in future
+     */
+    List<int[]> fromNodes = new ArrayList<>();
+    fromNodes.add(new int[] { 1, getLength() });
+    toResidues.add(new int[] { seq.getStart(), seq.getEnd() });
+    MapList mapList = new MapList(fromNodes, toResidues, 1, 1);
+    mapToHmmConsensus = new Mapping(seq.getDatasetSequence(), mapList);
+  }
+
+
+  /**
+   * Answers the aligned consensus sequence for the profile. Note this will
+   * return null if called before <code>setNodes</code> has been called.
+   * 
+   * @return
+   */
+  public SequenceI getConsensusSequence()
+  {
+    return hmmSeq;
+  }
+
+  /**
+   * Answers the index position (0...) of the given symbol, or -1 if not a valid
+   * symbol for this HMM
+   * 
+   * @param symbol
+   * @return
+   */
+  private int getSymbolIndex(char symbol)
+  {
+    /*
+     * symbolIndexLookup holds the index for 'A' to 'Z'
+     */
+    char c = Character.toUpperCase(symbol);
+    if ('A' <= c && c <= 'Z')
+    {
+      return symbolIndexLookup[c - 'A'];
+    }
+    return -1;
+  }
+
+  /**
+   * Sets the nodes of this HMM, and also extracts the HMM consensus sequence
+   * and a mapping between node numbers and sequence positions
+   * 
+   * @param nodeList
+   */
+  public void setNodes(List<HMMNode> nodeList)
+  {
+    nodes = nodeList;
+    if (nodes.size() > 1)
+    {
+      buildConsensusSequence();
+    }
+  }
+
+  /**
+   * Sets the aligned consensus sequence this HMM is the model for
+   * 
+   * @param hmmSeq
+   */
+  public void setHmmSeq(SequenceI hmmSeq)
+  {
+    this.hmmSeq = hmmSeq;
+  }
+
+  public void setBackgroundFrequencies(Map<Character, Float> bkgdFreqs)
+  {
+    backgroundFrequencies = bkgdFreqs;
+  }
+
+  public void setBackgroundFrequencies(ResidueCount bkgdFreqs)
+  {
+    backgroundFrequencies = new HashMap<>();
+
+    int total = bkgdFreqs.getTotalResidueCount();
+
+    for (char c : bkgdFreqs.getSymbolCounts().symbols)
+    {
+      backgroundFrequencies.put(c, bkgdFreqs.getCount(c) * 1f / total);
+    }
+
+  }
+
+  public Map<Character, Float> getBackgroundFrequencies()
+  {
+    return backgroundFrequencies;
+  }
+
+}
+