initial commit
[jalview.git] / forester / java / src / org / forester / evoinference / parsimony / FitchParsimony.java
1 // $Id:
2 //
3 // FORESTER -- software libraries and applications
4 // for evolutionary biology research and applications.
5 //
6 // Copyright (C) 2008-2009 Christian M. Zmasek
7 // Copyright (C) 2008-2009 Burnham Institute for Medical Research
8 // All rights reserved
9 // 
10 // This library is free software; you can redistribute it and/or
11 // modify it under the terms of the GNU Lesser General Public
12 // License as published by the Free Software Foundation; either
13 // version 2.1 of the License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful,
16 // but WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 // 
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
23 //
24 // Contact: phylosoft @ gmail . com
25 // WWW: www.phylosoft.org/forester
26
27 package org.forester.evoinference.parsimony;
28
29 import java.util.ArrayList;
30 import java.util.HashMap;
31 import java.util.Iterator;
32 import java.util.List;
33 import java.util.Map;
34 import java.util.Random;
35 import java.util.SortedSet;
36 import java.util.TreeSet;
37
38 import org.forester.evoinference.matrix.character.BasicCharacterStateMatrix;
39 import org.forester.evoinference.matrix.character.CharacterStateMatrix;
40 import org.forester.evoinference.matrix.character.CharacterStateMatrix.BinaryStates;
41 import org.forester.evoinference.matrix.character.CharacterStateMatrix.GainLossStates;
42 import org.forester.phylogeny.Phylogeny;
43 import org.forester.phylogeny.PhylogenyNode;
44 import org.forester.phylogeny.iterators.PhylogenyNodeIterator;
45 import org.forester.util.FailedConditionCheckException;
46 import org.forester.util.ForesterUtil;
47
48 public class FitchParsimony<STATE_TYPE> {
49
50     final static private BinaryStates              PRESENT                         = BinaryStates.PRESENT;
51     final static private BinaryStates              ABSENT                          = BinaryStates.ABSENT;
52     final static private GainLossStates            LOSS                            = GainLossStates.LOSS;
53     final static private GainLossStates            GAIN                            = GainLossStates.GAIN;
54     final static private GainLossStates            UNCHANGED_PRESENT               = GainLossStates.UNCHANGED_PRESENT;
55     final static private GainLossStates            UNCHANGED_ABSENT                = GainLossStates.UNCHANGED_ABSENT;
56     private static final boolean                   RETURN_INTERNAL_STATES_DEFAULT  = false;
57     private static final boolean                   RETURN_GAIN_LOSS_MATRIX_DEFAULT = false;
58     private static final boolean                   RANDOMIZE_DEFAULT               = false;
59     private static final long                      RANDOM_NUMBER_SEED_DEFAULT      = 21;
60     private static final boolean                   USE_LAST_DEFAULT                = false;
61     private boolean                                _return_internal_states         = false;
62     private boolean                                _return_gain_loss               = false;
63     private int                                    _total_gains;
64     private int                                    _total_losses;
65     private int                                    _total_unchanged;
66     private CharacterStateMatrix<List<STATE_TYPE>> _internal_states_matrix_prior_to_traceback;
67     private CharacterStateMatrix<STATE_TYPE>       _internal_states_matrix_after_traceback;
68     private CharacterStateMatrix<GainLossStates>   _gain_loss_matrix;
69     private boolean                                _randomize;
70     private boolean                                _use_last;
71     private int                                    _cost;
72     private long                                   _random_number_seed;
73     private Random                                 _random_generator;
74
75     public FitchParsimony() {
76         init();
77     }
78
79     private int determineIndex( final SortedSet<STATE_TYPE> current_node_states, int i ) {
80         if ( isRandomize() ) {
81             i = getRandomGenerator().nextInt( current_node_states.size() );
82         }
83         else if ( isUseLast() ) {
84             i = current_node_states.size() - 1;
85         }
86         return i;
87     }
88
89     public void execute( final Phylogeny p, final CharacterStateMatrix<STATE_TYPE> external_node_states_matrix ) {
90         if ( !p.isRooted() ) {
91             throw new IllegalArgumentException( "attempt to execute Fitch parsimony on unroored phylogeny" );
92         }
93         if ( external_node_states_matrix.isEmpty() ) {
94             throw new IllegalArgumentException( "character matrix is empty" );
95         }
96         if ( external_node_states_matrix.getNumberOfIdentifiers() != p.getNumberOfExternalNodes() ) {
97             throw new IllegalArgumentException( "number of external nodes in phylogeny ["
98                     + p.getNumberOfExternalNodes() + "] and number of indentifiers ["
99                     + external_node_states_matrix.getNumberOfIdentifiers() + "] in matrix are not equal" );
100         }
101         reset();
102         if ( isReturnInternalStates() ) {
103             initializeInternalStates( p, external_node_states_matrix );
104         }
105         if ( isReturnGainLossMatrix() ) {
106             initializeGainLossMatrix( p, external_node_states_matrix );
107         }
108         for( int character_index = 0; character_index < external_node_states_matrix.getNumberOfCharacters(); ++character_index ) {
109             executeForOneCharacter( p,
110                                     getStatesForCharacter( p, external_node_states_matrix, character_index ),
111                                     getStatesForCharacterForTraceback( p, external_node_states_matrix, character_index ),
112                                     character_index );
113         }
114         if ( external_node_states_matrix.getState( 0, 0 ) instanceof BinaryStates ) {
115             if ( ( external_node_states_matrix.getNumberOfCharacters() * p.getNumberOfBranches() ) != ( getTotalGains()
116                     + getTotalLosses() + getTotalUnchanged() ) ) {
117                 throw new FailedConditionCheckException( "this should not have happened: something is deeply wrong with Fitch parsimony implementation" );
118             }
119         }
120     }
121
122     private void executeForOneCharacter( final Phylogeny p,
123                                          final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
124                                          final Map<PhylogenyNode, STATE_TYPE> traceback_states,
125                                          final int character_state_column ) {
126         postOrderTraversal( p, states );
127         preOrderTraversal( p, states, traceback_states, character_state_column );
128     }
129
130     public int getCost() {
131         return _cost;
132     }
133
134     public CharacterStateMatrix<CharacterStateMatrix.GainLossStates> getGainLossMatrix() {
135         if ( !isReturnGainLossMatrix() ) {
136             throw new RuntimeException( "creation of gain-loss matrix has not been enabled" );
137         }
138         return _gain_loss_matrix;
139     }
140
141     public CharacterStateMatrix<STATE_TYPE> getInternalStatesMatrix() {
142         if ( !isReturnInternalStates() ) {
143             throw new RuntimeException( "creation of internal state matrix has not been enabled" );
144         }
145         return _internal_states_matrix_after_traceback;
146     }
147
148     /**
149      * Returns a view of the internal states prior to trace-back.
150      * 
151      * @return
152      */
153     public CharacterStateMatrix<List<STATE_TYPE>> getInternalStatesMatrixPriorToTraceback() {
154         if ( !isReturnInternalStates() ) {
155             throw new RuntimeException( "creation of internal state matrix has not been enabled" );
156         }
157         return _internal_states_matrix_prior_to_traceback;
158     }
159
160     private SortedSet<STATE_TYPE> getIntersectionOfStatesOfChildNodes( final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
161                                                                        final PhylogenyNode node ) throws AssertionError {
162         final SortedSet<STATE_TYPE> states_in_child_nodes = new TreeSet<STATE_TYPE>();
163         for( int i = 0; i < node.getNumberOfDescendants(); ++i ) {
164             final PhylogenyNode node_child = node.getChildNode( i );
165             if ( !states.containsKey( node_child ) ) {
166                 throw new AssertionError( "this should not have happened: node [" + node_child.getName()
167                         + "] not found in node state map" );
168             }
169             if ( i == 0 ) {
170                 states_in_child_nodes.addAll( states.get( node_child ) );
171             }
172             else {
173                 states_in_child_nodes.retainAll( states.get( node_child ) );
174             }
175         }
176         return states_in_child_nodes;
177     }
178
179     private Random getRandomGenerator() {
180         return _random_generator;
181     }
182
183     private long getRandomNumberSeed() {
184         return _random_number_seed;
185     }
186
187     private STATE_TYPE getStateAt( final int i, final SortedSet<STATE_TYPE> states ) {
188         final Iterator<STATE_TYPE> it = states.iterator();
189         for( int j = 0; j < i; ++j ) {
190             it.next();
191         }
192         return it.next();
193     }
194
195     private Map<PhylogenyNode, SortedSet<STATE_TYPE>> getStatesForCharacter( final Phylogeny p,
196                                                                              final CharacterStateMatrix<STATE_TYPE> matrix,
197                                                                              final int character_index ) {
198         final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states = new HashMap<PhylogenyNode, SortedSet<STATE_TYPE>>( matrix
199                 .getNumberOfIdentifiers() );
200         for( int indentifier_index = 0; indentifier_index < matrix.getNumberOfIdentifiers(); ++indentifier_index ) {
201             final STATE_TYPE state = matrix.getState( indentifier_index, character_index );
202             if ( state == null ) {
203                 throw new IllegalArgumentException( "value at [" + indentifier_index + ", " + character_index
204                         + "] is null" );
205             }
206             final SortedSet<STATE_TYPE> l = new TreeSet<STATE_TYPE>();
207             l.add( state );
208             states.put( p.getNode( matrix.getIdentifier( indentifier_index ) ), l );
209         }
210         return states;
211     }
212
213     private Map<PhylogenyNode, STATE_TYPE> getStatesForCharacterForTraceback( final Phylogeny p,
214                                                                               final CharacterStateMatrix<STATE_TYPE> matrix,
215                                                                               final int character_index ) {
216         final Map<PhylogenyNode, STATE_TYPE> states = new HashMap<PhylogenyNode, STATE_TYPE>( matrix
217                 .getNumberOfIdentifiers() );
218         for( int indentifier_index = 0; indentifier_index < matrix.getNumberOfIdentifiers(); ++indentifier_index ) {
219             final STATE_TYPE state = matrix.getState( indentifier_index, character_index );
220             if ( state == null ) {
221                 throw new IllegalArgumentException( "value at [" + indentifier_index + ", " + character_index
222                         + "] is null" );
223             }
224             states.put( p.getNode( matrix.getIdentifier( indentifier_index ) ), state );
225         }
226         return states;
227     }
228
229     public int getTotalGains() {
230         return _total_gains;
231     }
232
233     public int getTotalLosses() {
234         return _total_losses;
235     }
236
237     public int getTotalUnchanged() {
238         return _total_unchanged;
239     }
240
241     private SortedSet<STATE_TYPE> getUnionOfStatesOfChildNodes( final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
242                                                                 final PhylogenyNode node ) throws AssertionError {
243         final SortedSet<STATE_TYPE> states_in_child_nodes = new TreeSet<STATE_TYPE>();
244         for( int i = 0; i < node.getNumberOfDescendants(); ++i ) {
245             final PhylogenyNode node_child = node.getChildNode( i );
246             if ( !states.containsKey( node_child ) ) {
247                 throw new AssertionError( "this should not have happened: node [" + node_child.getName()
248                         + "] not found in node state map" );
249             }
250             states_in_child_nodes.addAll( states.get( node_child ) );
251         }
252         return states_in_child_nodes;
253     }
254
255     private void increaseCost() {
256         ++_cost;
257     }
258
259     private void init() {
260         setReturnInternalStates( RETURN_INTERNAL_STATES_DEFAULT );
261         setReturnGainLossMatrix( RETURN_GAIN_LOSS_MATRIX_DEFAULT );
262         setRandomize( RANDOMIZE_DEFAULT );
263         setUseLast( USE_LAST_DEFAULT );
264         _random_number_seed = RANDOM_NUMBER_SEED_DEFAULT;
265         reset();
266     }
267
268     private void initializeGainLossMatrix( final Phylogeny p,
269                                            final CharacterStateMatrix<STATE_TYPE> external_node_states_matrix ) {
270         final List<PhylogenyNode> nodes = new ArrayList<PhylogenyNode>();
271         for( final PhylogenyNodeIterator postorder = p.iteratorPostorder(); postorder.hasNext(); ) {
272             nodes.add( postorder.next() );
273         }
274         setGainLossMatrix( new BasicCharacterStateMatrix<CharacterStateMatrix.GainLossStates>( nodes.size(),
275                                                                                                external_node_states_matrix
276                                                                                                        .getNumberOfCharacters() ) );
277         int identifier_index = 0;
278         for( final PhylogenyNode node : nodes ) {
279             getGainLossMatrix().setIdentifier( identifier_index++,
280                                                ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node
281                                                        .getName() );
282         }
283         for( int character_index = 0; character_index < external_node_states_matrix.getNumberOfCharacters(); ++character_index ) {
284             getGainLossMatrix().setCharacter( character_index,
285                                               external_node_states_matrix.getCharacter( character_index ) );
286         }
287     }
288
289     private void initializeInternalStates( final Phylogeny p,
290                                            final CharacterStateMatrix<STATE_TYPE> external_node_states_matrix ) {
291         final List<PhylogenyNode> internal_nodes = new ArrayList<PhylogenyNode>();
292         for( final PhylogenyNodeIterator postorder = p.iteratorPostorder(); postorder.hasNext(); ) {
293             final PhylogenyNode node = postorder.next();
294             if ( node.isInternal() ) {
295                 internal_nodes.add( node );
296             }
297         }
298         setInternalStatesMatrixPriorToTraceback( new BasicCharacterStateMatrix<List<STATE_TYPE>>( internal_nodes.size(),
299                                                                                                   external_node_states_matrix
300                                                                                                           .getNumberOfCharacters() ) );
301         setInternalStatesMatrixTraceback( new BasicCharacterStateMatrix<STATE_TYPE>( internal_nodes.size(),
302                                                                                      external_node_states_matrix
303                                                                                              .getNumberOfCharacters() ) );
304         int identifier_index = 0;
305         for( final PhylogenyNode node : internal_nodes ) {
306             getInternalStatesMatrix().setIdentifier( identifier_index,
307                                                      ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node
308                                                              .getName() );
309             getInternalStatesMatrixPriorToTraceback().setIdentifier( identifier_index,
310                                                                      ForesterUtil.isEmpty( node.getName() ) ? node
311                                                                              .getId()
312                                                                              + "" : node.getName() );
313             ++identifier_index;
314         }
315         for( int character_index = 0; character_index < external_node_states_matrix.getNumberOfCharacters(); ++character_index ) {
316             getInternalStatesMatrix().setCharacter( character_index,
317                                                     external_node_states_matrix.getCharacter( character_index ) );
318             getInternalStatesMatrixPriorToTraceback().setCharacter( character_index,
319                                                                     external_node_states_matrix
320                                                                             .getCharacter( character_index ) );
321         }
322     }
323
324     private boolean isRandomize() {
325         return _randomize;
326     }
327
328     private boolean isReturnGainLossMatrix() {
329         return _return_gain_loss;
330     }
331
332     private boolean isReturnInternalStates() {
333         return _return_internal_states;
334     }
335
336     private boolean isUseLast() {
337         return _use_last;
338     }
339
340     private void postOrderTraversal( final Phylogeny p, final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states )
341             throws AssertionError {
342         for( final PhylogenyNodeIterator postorder = p.iteratorPostorder(); postorder.hasNext(); ) {
343             final PhylogenyNode node = postorder.next();
344             if ( !node.isExternal() ) {
345                 SortedSet<STATE_TYPE> states_in_children = getIntersectionOfStatesOfChildNodes( states, node );
346                 if ( states_in_children.isEmpty() ) {
347                     states_in_children = getUnionOfStatesOfChildNodes( states, node );
348                 }
349                 states.put( node, states_in_children );
350             }
351         }
352     }
353
354     private void preOrderTraversal( final Phylogeny p,
355                                     final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
356                                     final Map<PhylogenyNode, STATE_TYPE> traceback_states,
357                                     final int character_state_column ) throws AssertionError {
358         for( final PhylogenyNodeIterator preorder = p.iteratorPreorder(); preorder.hasNext(); ) {
359             final PhylogenyNode current_node = preorder.next();
360             final SortedSet<STATE_TYPE> current_node_states = states.get( current_node );
361             STATE_TYPE parent_state = null;
362             if ( current_node.isRoot() ) {
363                 int i = 0;
364                 i = determineIndex( current_node_states, i );
365                 traceback_states.put( current_node, getStateAt( i, current_node_states ) );
366             }
367             else {
368                 parent_state = traceback_states.get( current_node.getParent() );
369                 if ( current_node_states.contains( parent_state ) ) {
370                     traceback_states.put( current_node, parent_state );
371                 }
372                 else {
373                     increaseCost();
374                     int i = 0;
375                     i = determineIndex( current_node_states, i );
376                     traceback_states.put( current_node, getStateAt( i, current_node_states ) );
377                 }
378             }
379             if ( isReturnInternalStates() ) {
380                 if ( !current_node.isExternal() ) {
381                     setInternalNodeStatePriorToTraceback( states, character_state_column, current_node );
382                     setInternalNodeState( traceback_states, character_state_column, current_node );
383                 }
384             }
385             if ( isReturnGainLossMatrix() && !current_node.isRoot() ) {
386                 if ( !( parent_state instanceof BinaryStates ) ) {
387                     throw new RuntimeException( "attempt to create gain loss matrix for not binary states" );
388                 }
389                 final BinaryStates parent_binary_state = ( BinaryStates ) parent_state;
390                 final BinaryStates current_binary_state = ( BinaryStates ) traceback_states.get( current_node );
391                 if ( ( parent_binary_state == PRESENT ) && ( current_binary_state == ABSENT ) ) {
392                     ++_total_losses;
393                     setGainLossState( character_state_column, current_node, LOSS );
394                 }
395                 else if ( ( ( parent_binary_state == ABSENT ) || ( parent_binary_state == null ) )
396                         && ( current_binary_state == PRESENT ) ) {
397                     ++_total_gains;
398                     setGainLossState( character_state_column, current_node, GAIN );
399                 }
400                 else {
401                     ++_total_unchanged;
402                     if ( current_binary_state == PRESENT ) {
403                         setGainLossState( character_state_column, current_node, UNCHANGED_PRESENT );
404                     }
405                     else if ( current_binary_state == ABSENT ) {
406                         setGainLossState( character_state_column, current_node, UNCHANGED_ABSENT );
407                     }
408                 }
409             }
410             else if ( isReturnGainLossMatrix() && current_node.isRoot() ) {
411                 final BinaryStates current_binary_state = ( BinaryStates ) traceback_states.get( current_node );
412                 ++_total_unchanged; //new
413                 if ( current_binary_state == PRESENT ) {//new
414                     setGainLossState( character_state_column, current_node, UNCHANGED_PRESENT );//new
415                 }//new
416                 else if ( current_binary_state == ABSENT ) {//new
417                     setGainLossState( character_state_column, current_node, UNCHANGED_ABSENT );//new
418                 }//new
419                 // setGainLossState( character_state_column, current_node, UNKNOWN_GAIN_LOSS );
420             }
421         }
422     }
423
424     private void reset() {
425         setCost( 0 );
426         setTotalLosses( 0 );
427         setTotalGains( 0 );
428         setTotalUnchanged( 0 );
429         setRandomGenerator( new Random( getRandomNumberSeed() ) );
430     }
431
432     private void setCost( final int cost ) {
433         _cost = cost;
434     }
435
436     private void setGainLossMatrix( final CharacterStateMatrix<GainLossStates> gain_loss_matrix ) {
437         _gain_loss_matrix = gain_loss_matrix;
438     }
439
440     private void setGainLossState( final int character_state_column,
441                                    final PhylogenyNode node,
442                                    final GainLossStates state ) {
443         getGainLossMatrix().setState( ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node.getName(),
444                                       character_state_column,
445                                       state );
446     }
447
448     private void setInternalNodeState( final Map<PhylogenyNode, STATE_TYPE> states,
449                                        final int character_state_column,
450                                        final PhylogenyNode node ) {
451         getInternalStatesMatrix()
452                 .setState( ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node.getName(),
453                            character_state_column,
454                            states.get( node ) );
455     }
456
457     private void setInternalNodeStatePriorToTraceback( final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
458                                                        final int character_state_column,
459                                                        final PhylogenyNode node ) {
460         getInternalStatesMatrixPriorToTraceback().setState( ForesterUtil.isEmpty( node.getName() ) ? node.getId() + ""
461                                                                     : node.getName(),
462                                                             character_state_column,
463                                                             toListSorted( states.get( node ) ) );
464     }
465
466     private void setInternalStatesMatrixPriorToTraceback( final CharacterStateMatrix<List<STATE_TYPE>> internal_states_matrix_prior_to_traceback ) {
467         _internal_states_matrix_prior_to_traceback = internal_states_matrix_prior_to_traceback;
468     }
469
470     private void setInternalStatesMatrixTraceback( final CharacterStateMatrix<STATE_TYPE> internal_states_matrix_after_traceback ) {
471         _internal_states_matrix_after_traceback = internal_states_matrix_after_traceback;
472     }
473
474     private void setRandomGenerator( final Random random_generator ) {
475         _random_generator = random_generator;
476     }
477
478     public void setRandomize( final boolean randomize ) {
479         if ( randomize && isUseLast() ) {
480             throw new IllegalArgumentException( "attempt to allways use last state (ordered) if more than one choices and randomization at the same time" );
481         }
482         _randomize = randomize;
483     }
484
485     public void setRandomNumberSeed( final long random_number_seed ) {
486         if ( !isRandomize() ) {
487             throw new IllegalArgumentException( "attempt to set random number generator seed without randomization enabled" );
488         }
489         _random_number_seed = random_number_seed;
490     }
491
492     public void setReturnGainLossMatrix( final boolean return_gain_loss ) {
493         _return_gain_loss = return_gain_loss;
494     }
495
496     public void setReturnInternalStates( final boolean return_internal_states ) {
497         _return_internal_states = return_internal_states;
498     }
499
500     private void setTotalGains( final int total_gains ) {
501         _total_gains = total_gains;
502     }
503
504     private void setTotalLosses( final int total_losses ) {
505         _total_losses = total_losses;
506     }
507
508     private void setTotalUnchanged( final int total_unchanged ) {
509         _total_unchanged = total_unchanged;
510     }
511
512     /**
513      * This sets whether to use the first or last state in the sorted
514      * states at the undecided internal nodes.
515      * For randomized choices set randomize to true (and this to false).
516      * 
517      * Note. It might be advisable to set this to false
518      * for BinaryStates if absence at the root is preferred
519      * (given the enum BinaryStates sorts in the following order: 
520      * ABSENT, UNKNOWN, PRESENT).
521      * 
522      * 
523      * @param use_last
524      */
525     public void setUseLast( final boolean use_last ) {
526         if ( use_last && isRandomize() ) {
527             throw new IllegalArgumentException( "attempt to allways use last state (ordered) if more than one choices and randomization at the same time" );
528         }
529         _use_last = use_last;
530     }
531
532     private List<STATE_TYPE> toListSorted( final SortedSet<STATE_TYPE> states ) {
533         final List<STATE_TYPE> l = new ArrayList<STATE_TYPE>( states.size() );
534         for( final STATE_TYPE state : states ) {
535             l.add( state );
536         }
537         return l;
538     }
539 }