initial commit
[jalview.git] / forester / java / src / org / forester / evoinference / parsimony / SankoffParsimony.java
1 // $Id:
2 //
3 // FORESTER -- software libraries and applications
4 // for evolutionary biology research and applications.
5 //
6 // Copyright (C) 2008-2009 Christian M. Zmasek
7 // Copyright (C) 2008-2009 Burnham Institute for Medical Research
8 // All rights reserved
9 // 
10 // This library is free software; you can redistribute it and/or
11 // modify it under the terms of the GNU Lesser General Public
12 // License as published by the Free Software Foundation; either
13 // version 2.1 of the License, or (at your option) any later version.
14 //
15 // This library is distributed in the hope that it will be useful,
16 // but WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18 // Lesser General Public License for more details.
19 // 
20 // You should have received a copy of the GNU Lesser General Public
21 // License along with this library; if not, write to the Free Software
22 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
23 //
24 // Contact: phylosoft @ gmail . com
25 // WWW: www.phylosoft.org
26
27 package org.forester.evoinference.parsimony;
28
29 import java.util.ArrayList;
30 import java.util.HashMap;
31 import java.util.Iterator;
32 import java.util.List;
33 import java.util.Map;
34 import java.util.Random;
35 import java.util.SortedSet;
36 import java.util.TreeSet;
37
38 import org.forester.evoinference.matrix.character.BasicCharacterStateMatrix;
39 import org.forester.evoinference.matrix.character.CharacterStateMatrix;
40 import org.forester.evoinference.matrix.character.CharacterStateMatrix.BinaryStates;
41 import org.forester.evoinference.matrix.character.CharacterStateMatrix.GainLossStates;
42 import org.forester.phylogeny.Phylogeny;
43 import org.forester.phylogeny.PhylogenyNode;
44 import org.forester.phylogeny.iterators.PhylogenyNodeIterator;
45 import org.forester.util.ForesterUtil;
46
47 /**
48  * 
49  * IN PROGRESS!
50  * DO NOT USE!
51  * 
52  * 
53  * @param <STATE_TYPE>
54  */
55 public class SankoffParsimony<STATE_TYPE> {
56
57     final static private BinaryStates              PRESENT                         = BinaryStates.PRESENT;
58     final static private BinaryStates              ABSENT                          = BinaryStates.ABSENT;
59     final static private GainLossStates            LOSS                            = GainLossStates.LOSS;
60     final static private GainLossStates            GAIN                            = GainLossStates.GAIN;
61     final static private GainLossStates            UNCHANGED_PRESENT               = GainLossStates.UNCHANGED_PRESENT;
62     final static private GainLossStates            UNCHANGED_ABSENT                = GainLossStates.UNCHANGED_ABSENT;
63     private static final boolean                   RETURN_INTERNAL_STATES_DEFAULT  = false;
64     private static final boolean                   RETURN_GAIN_LOSS_MATRIX_DEFAULT = false;
65     private static final boolean                   RANDOMIZE_DEFAULT               = false;
66     private static final long                      RANDOM_NUMBER_SEED_DEFAULT      = 21;
67     private static final boolean                   USE_LAST_DEFAULT                = false;
68     private boolean                                _return_internal_states         = false;
69     private boolean                                _return_gain_loss               = false;
70     private int                                    _total_gains;
71     private int                                    _total_losses;
72     private int                                    _total_unchanged;
73     private CharacterStateMatrix<List<STATE_TYPE>> _internal_states_matrix_prior_to_traceback;
74     private CharacterStateMatrix<STATE_TYPE>       _internal_states_matrix_after_traceback;
75     private CharacterStateMatrix<GainLossStates>   _gain_loss_matrix;
76     private boolean                                _randomize;
77     private boolean                                _use_last;
78     private int                                    _cost;
79     private long                                   _random_number_seed;
80     private Random                                 _random_generator;
81
82     public SankoffParsimony() {
83         init();
84     }
85
86     private int determineIndex( final SortedSet<STATE_TYPE> current_node_states, int i ) {
87         if ( isRandomize() ) {
88             i = getRandomGenerator().nextInt( current_node_states.size() );
89         }
90         else if ( isUseLast() ) {
91             i = current_node_states.size() - 1;
92         }
93         return i;
94     }
95
96     public void execute( final Phylogeny p, final CharacterStateMatrix<STATE_TYPE> external_node_states_matrix ) {
97         if ( !p.isRooted() ) {
98             throw new IllegalArgumentException( "attempt to execute Fitch parsimony on unroored phylogeny" );
99         }
100         if ( external_node_states_matrix.isEmpty() ) {
101             throw new IllegalArgumentException( "character matrix is empty" );
102         }
103         if ( external_node_states_matrix.getNumberOfIdentifiers() != p.getNumberOfExternalNodes() ) {
104             throw new IllegalArgumentException( "number of external nodes in phylogeny ["
105                     + p.getNumberOfExternalNodes() + "] and number of indentifiers ["
106                     + external_node_states_matrix.getNumberOfIdentifiers() + "] in matrix are not equal" );
107         }
108         reset();
109         if ( isReturnInternalStates() ) {
110             initializeInternalStates( p, external_node_states_matrix );
111         }
112         if ( isReturnGainLossMatrix() ) {
113             initializeGainLossMatrix( p, external_node_states_matrix );
114         }
115         for( int character_index = 0; character_index < external_node_states_matrix.getNumberOfCharacters(); ++character_index ) {
116             executeForOneCharacter( p,
117                                     getStatesForCharacter( p, external_node_states_matrix, character_index ),
118                                     getStatesForCharacterForTraceback( p, external_node_states_matrix, character_index ),
119                                     character_index );
120         }
121         if ( external_node_states_matrix.getState( 0, 0 ) instanceof BinaryStates ) {
122             if ( ( external_node_states_matrix.getNumberOfCharacters() * p.getNumberOfBranches() ) != ( getTotalGains()
123                     + getTotalLosses() + getTotalUnchanged() ) ) {
124                 throw new RuntimeException( "this should not have happened: something is deeply wrong with Fitch parsimony implementation" );
125             }
126         }
127     }
128
129     private void executeForOneCharacter( final Phylogeny p,
130                                          final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
131                                          final Map<PhylogenyNode, STATE_TYPE> traceback_states,
132                                          final int character_state_column ) {
133         postOrderTraversal( p, states );
134         preOrderTraversal( p, states, traceback_states, character_state_column );
135     }
136
137     public int getCost() {
138         return _cost;
139     }
140
141     public CharacterStateMatrix<CharacterStateMatrix.GainLossStates> getGainLossMatrix() {
142         if ( !isReturnGainLossMatrix() ) {
143             throw new RuntimeException( "creation of gain-loss matrix has not been enabled" );
144         }
145         return _gain_loss_matrix;
146     }
147
148     public CharacterStateMatrix<STATE_TYPE> getInternalStatesMatrix() {
149         if ( !isReturnInternalStates() ) {
150             throw new RuntimeException( "creation of internal state matrix has not been enabled" );
151         }
152         return _internal_states_matrix_after_traceback;
153     }
154
155     /**
156      * Returns a view of the internal states prior to trace-back.
157      * 
158      * @return
159      */
160     public CharacterStateMatrix<List<STATE_TYPE>> getInternalStatesMatrixPriorToTraceback() {
161         if ( !isReturnInternalStates() ) {
162             throw new RuntimeException( "creation of internal state matrix has not been enabled" );
163         }
164         return _internal_states_matrix_prior_to_traceback;
165     }
166
167     private SortedSet<STATE_TYPE> getIntersectionOfStatesOfChildNodes( final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
168                                                                        final PhylogenyNode node ) throws AssertionError {
169         final SortedSet<STATE_TYPE> states_in_child_nodes = new TreeSet<STATE_TYPE>();
170         for( int i = 0; i < node.getNumberOfDescendants(); ++i ) {
171             final PhylogenyNode node_child = node.getChildNode( i );
172             if ( !states.containsKey( node_child ) ) {
173                 throw new AssertionError( "this should not have happened: node [" + node_child.getName()
174                         + "] not found in node state map" );
175             }
176             if ( i == 0 ) {
177                 states_in_child_nodes.addAll( states.get( node_child ) );
178             }
179             else {
180                 states_in_child_nodes.retainAll( states.get( node_child ) );
181             }
182         }
183         return states_in_child_nodes;
184     }
185
186     private Random getRandomGenerator() {
187         return _random_generator;
188     }
189
190     private long getRandomNumberSeed() {
191         return _random_number_seed;
192     }
193
194     private STATE_TYPE getStateAt( final int i, final SortedSet<STATE_TYPE> states ) {
195         final Iterator<STATE_TYPE> it = states.iterator();
196         for( int j = 0; j < i; ++j ) {
197             it.next();
198         }
199         return it.next();
200     }
201
202     private Map<PhylogenyNode, SortedSet<STATE_TYPE>> getStatesForCharacter( final Phylogeny p,
203                                                                              final CharacterStateMatrix<STATE_TYPE> matrix,
204                                                                              final int character_index ) {
205         final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states = new HashMap<PhylogenyNode, SortedSet<STATE_TYPE>>( matrix
206                 .getNumberOfIdentifiers() );
207         for( int indentifier_index = 0; indentifier_index < matrix.getNumberOfIdentifiers(); ++indentifier_index ) {
208             final STATE_TYPE state = matrix.getState( indentifier_index, character_index );
209             if ( state == null ) {
210                 throw new IllegalArgumentException( "value at [" + indentifier_index + ", " + character_index
211                         + "] is null" );
212             }
213             final SortedSet<STATE_TYPE> l = new TreeSet<STATE_TYPE>();
214             l.add( state );
215             states.put( p.getNode( matrix.getIdentifier( indentifier_index ) ), l );
216         }
217         return states;
218     }
219
220     private Map<PhylogenyNode, STATE_TYPE> getStatesForCharacterForTraceback( final Phylogeny p,
221                                                                               final CharacterStateMatrix<STATE_TYPE> matrix,
222                                                                               final int character_index ) {
223         final Map<PhylogenyNode, STATE_TYPE> states = new HashMap<PhylogenyNode, STATE_TYPE>( matrix
224                 .getNumberOfIdentifiers() );
225         for( int indentifier_index = 0; indentifier_index < matrix.getNumberOfIdentifiers(); ++indentifier_index ) {
226             final STATE_TYPE state = matrix.getState( indentifier_index, character_index );
227             if ( state == null ) {
228                 throw new IllegalArgumentException( "value at [" + indentifier_index + ", " + character_index
229                         + "] is null" );
230             }
231             states.put( p.getNode( matrix.getIdentifier( indentifier_index ) ), state );
232         }
233         return states;
234     }
235
236     public int getTotalGains() {
237         return _total_gains;
238     }
239
240     public int getTotalLosses() {
241         return _total_losses;
242     }
243
244     public int getTotalUnchanged() {
245         return _total_unchanged;
246     }
247
248     private SortedSet<STATE_TYPE> getUnionOfStatesOfChildNodes( final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
249                                                                 final PhylogenyNode node ) throws AssertionError {
250         final SortedSet<STATE_TYPE> states_in_child_nodes = new TreeSet<STATE_TYPE>();
251         for( int i = 0; i < node.getNumberOfDescendants(); ++i ) {
252             final PhylogenyNode node_child = node.getChildNode( i );
253             if ( !states.containsKey( node_child ) ) {
254                 throw new AssertionError( "this should not have happened: node [" + node_child.getName()
255                         + "] not found in node state map" );
256             }
257             states_in_child_nodes.addAll( states.get( node_child ) );
258         }
259         return states_in_child_nodes;
260     }
261
262     private void increaseCost() {
263         ++_cost;
264     }
265
266     private void init() {
267         setReturnInternalStates( RETURN_INTERNAL_STATES_DEFAULT );
268         setReturnGainLossMatrix( RETURN_GAIN_LOSS_MATRIX_DEFAULT );
269         setRandomize( RANDOMIZE_DEFAULT );
270         setUseLast( USE_LAST_DEFAULT );
271         _random_number_seed = RANDOM_NUMBER_SEED_DEFAULT;
272         reset();
273     }
274
275     private void initializeGainLossMatrix( final Phylogeny p,
276                                            final CharacterStateMatrix<STATE_TYPE> external_node_states_matrix ) {
277         final List<PhylogenyNode> nodes = new ArrayList<PhylogenyNode>();
278         for( final PhylogenyNodeIterator postorder = p.iteratorPostorder(); postorder.hasNext(); ) {
279             nodes.add( postorder.next() );
280         }
281         setGainLossMatrix( new BasicCharacterStateMatrix<CharacterStateMatrix.GainLossStates>( nodes.size(),
282                                                                                                external_node_states_matrix
283                                                                                                        .getNumberOfCharacters() ) );
284         int identifier_index = 0;
285         for( final PhylogenyNode node : nodes ) {
286             getGainLossMatrix().setIdentifier( identifier_index++,
287                                                ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node
288                                                        .getName() );
289         }
290         for( int character_index = 0; character_index < external_node_states_matrix.getNumberOfCharacters(); ++character_index ) {
291             getGainLossMatrix().setCharacter( character_index,
292                                               external_node_states_matrix.getCharacter( character_index ) );
293         }
294     }
295
296     private void initializeInternalStates( final Phylogeny p,
297                                            final CharacterStateMatrix<STATE_TYPE> external_node_states_matrix ) {
298         final List<PhylogenyNode> internal_nodes = new ArrayList<PhylogenyNode>();
299         for( final PhylogenyNodeIterator postorder = p.iteratorPostorder(); postorder.hasNext(); ) {
300             final PhylogenyNode node = postorder.next();
301             if ( node.isInternal() ) {
302                 internal_nodes.add( node );
303             }
304         }
305         setInternalStatesMatrixPriorToTraceback( new BasicCharacterStateMatrix<List<STATE_TYPE>>( internal_nodes.size(),
306                                                                                                   external_node_states_matrix
307                                                                                                           .getNumberOfCharacters() ) );
308         setInternalStatesMatrixTraceback( new BasicCharacterStateMatrix<STATE_TYPE>( internal_nodes.size(),
309                                                                                      external_node_states_matrix
310                                                                                              .getNumberOfCharacters() ) );
311         int identifier_index = 0;
312         for( final PhylogenyNode node : internal_nodes ) {
313             getInternalStatesMatrix().setIdentifier( identifier_index,
314                                                      ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node
315                                                              .getName() );
316             getInternalStatesMatrixPriorToTraceback().setIdentifier( identifier_index,
317                                                                      ForesterUtil.isEmpty( node.getName() ) ? node
318                                                                              .getId()
319                                                                              + "" : node.getName() );
320             ++identifier_index;
321         }
322         for( int character_index = 0; character_index < external_node_states_matrix.getNumberOfCharacters(); ++character_index ) {
323             getInternalStatesMatrix().setCharacter( character_index,
324                                                     external_node_states_matrix.getCharacter( character_index ) );
325             getInternalStatesMatrixPriorToTraceback().setCharacter( character_index,
326                                                                     external_node_states_matrix
327                                                                             .getCharacter( character_index ) );
328         }
329     }
330
331     private boolean isRandomize() {
332         return _randomize;
333     }
334
335     private boolean isReturnGainLossMatrix() {
336         return _return_gain_loss;
337     }
338
339     private boolean isReturnInternalStates() {
340         return _return_internal_states;
341     }
342
343     private boolean isUseLast() {
344         return _use_last;
345     }
346
347     private void postOrderTraversal( final Phylogeny p, final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states )
348             throws AssertionError {
349         for( final PhylogenyNodeIterator postorder = p.iteratorPostorder(); postorder.hasNext(); ) {
350             final PhylogenyNode node = postorder.next();
351             if ( !node.isExternal() ) {
352                 SortedSet<STATE_TYPE> states_in_children = getIntersectionOfStatesOfChildNodes( states, node );
353                 if ( states_in_children.isEmpty() ) {
354                     states_in_children = getUnionOfStatesOfChildNodes( states, node );
355                 }
356                 states.put( node, states_in_children );
357             }
358         }
359     }
360
361     private void preOrderTraversal( final Phylogeny p,
362                                     final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
363                                     final Map<PhylogenyNode, STATE_TYPE> traceback_states,
364                                     final int character_state_column ) throws AssertionError {
365         for( final PhylogenyNodeIterator preorder = p.iteratorPreorder(); preorder.hasNext(); ) {
366             final PhylogenyNode current_node = preorder.next();
367             final SortedSet<STATE_TYPE> current_node_states = states.get( current_node );
368             STATE_TYPE parent_state = null;
369             if ( current_node.isRoot() ) {
370                 int i = 0;
371                 i = determineIndex( current_node_states, i );
372                 traceback_states.put( current_node, getStateAt( i, current_node_states ) );
373             }
374             else {
375                 parent_state = traceback_states.get( current_node.getParent() );
376                 if ( current_node_states.contains( parent_state ) ) {
377                     traceback_states.put( current_node, parent_state );
378                 }
379                 else {
380                     increaseCost();
381                     int i = 0;
382                     i = determineIndex( current_node_states, i );
383                     traceback_states.put( current_node, getStateAt( i, current_node_states ) );
384                 }
385             }
386             if ( isReturnInternalStates() ) {
387                 if ( !current_node.isExternal() ) {
388                     setInternalNodeStatePriorToTraceback( states, character_state_column, current_node );
389                     setInternalNodeState( traceback_states, character_state_column, current_node );
390                 }
391             }
392             if ( isReturnGainLossMatrix() && !current_node.isRoot() ) {
393                 if ( !( parent_state instanceof BinaryStates ) ) {
394                     throw new RuntimeException( "attempt to create gain loss matrix for not binary states" );
395                 }
396                 final BinaryStates parent_binary_state = ( BinaryStates ) parent_state;
397                 final BinaryStates current_binary_state = ( BinaryStates ) traceback_states.get( current_node );
398                 if ( ( parent_binary_state == PRESENT ) && ( current_binary_state == ABSENT ) ) {
399                     ++_total_losses;
400                     setGainLossState( character_state_column, current_node, LOSS );
401                 }
402                 else if ( ( ( parent_binary_state == ABSENT ) || ( parent_binary_state == null ) )
403                         && ( current_binary_state == PRESENT ) ) {
404                     ++_total_gains;
405                     setGainLossState( character_state_column, current_node, GAIN );
406                 }
407                 else {
408                     ++_total_unchanged;
409                     if ( current_binary_state == PRESENT ) {
410                         setGainLossState( character_state_column, current_node, UNCHANGED_PRESENT );
411                     }
412                     else if ( current_binary_state == ABSENT ) {
413                         setGainLossState( character_state_column, current_node, UNCHANGED_ABSENT );
414                     }
415                 }
416             }
417             else if ( isReturnGainLossMatrix() && current_node.isRoot() ) {
418                 final BinaryStates current_binary_state = ( BinaryStates ) traceback_states.get( current_node );
419                 ++_total_unchanged; //new
420                 if ( current_binary_state == PRESENT ) {//new
421                     setGainLossState( character_state_column, current_node, UNCHANGED_PRESENT );//new
422                 }//new
423                 else if ( current_binary_state == ABSENT ) {//new
424                     setGainLossState( character_state_column, current_node, UNCHANGED_ABSENT );//new
425                 }//new
426                 // setGainLossState( character_state_column, current_node, UNKNOWN_GAIN_LOSS );
427             }
428         }
429     }
430
431     private void reset() {
432         setCost( 0 );
433         setTotalLosses( 0 );
434         setTotalGains( 0 );
435         setTotalUnchanged( 0 );
436         setRandomGenerator( new Random( getRandomNumberSeed() ) );
437     }
438
439     private void setCost( final int cost ) {
440         _cost = cost;
441     }
442
443     private void setGainLossMatrix( final CharacterStateMatrix<GainLossStates> gain_loss_matrix ) {
444         _gain_loss_matrix = gain_loss_matrix;
445     }
446
447     private void setGainLossState( final int character_state_column,
448                                    final PhylogenyNode node,
449                                    final GainLossStates state ) {
450         getGainLossMatrix().setState( ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node.getName(),
451                                       character_state_column,
452                                       state );
453     }
454
455     private void setInternalNodeState( final Map<PhylogenyNode, STATE_TYPE> states,
456                                        final int character_state_column,
457                                        final PhylogenyNode node ) {
458         getInternalStatesMatrix()
459                 .setState( ForesterUtil.isEmpty( node.getName() ) ? node.getId() + "" : node.getName(),
460                            character_state_column,
461                            states.get( node ) );
462     }
463
464     private void setInternalNodeStatePriorToTraceback( final Map<PhylogenyNode, SortedSet<STATE_TYPE>> states,
465                                                        final int character_state_column,
466                                                        final PhylogenyNode node ) {
467         getInternalStatesMatrixPriorToTraceback().setState( ForesterUtil.isEmpty( node.getName() ) ? node.getId() + ""
468                                                                     : node.getName(),
469                                                             character_state_column,
470                                                             toListSorted( states.get( node ) ) );
471     }
472
473     private void setInternalStatesMatrixPriorToTraceback( final CharacterStateMatrix<List<STATE_TYPE>> internal_states_matrix_prior_to_traceback ) {
474         _internal_states_matrix_prior_to_traceback = internal_states_matrix_prior_to_traceback;
475     }
476
477     private void setInternalStatesMatrixTraceback( final CharacterStateMatrix<STATE_TYPE> internal_states_matrix_after_traceback ) {
478         _internal_states_matrix_after_traceback = internal_states_matrix_after_traceback;
479     }
480
481     private void setRandomGenerator( final Random random_generator ) {
482         _random_generator = random_generator;
483     }
484
485     public void setRandomize( final boolean randomize ) {
486         if ( randomize && isUseLast() ) {
487             throw new IllegalArgumentException( "attempt to allways use last state (ordered) if more than one choices and randomization at the same time" );
488         }
489         _randomize = randomize;
490     }
491
492     public void setRandomNumberSeed( final long random_number_seed ) {
493         if ( !isRandomize() ) {
494             throw new IllegalArgumentException( "attempt to set random number generator seed without randomization enabled" );
495         }
496         _random_number_seed = random_number_seed;
497     }
498
499     public void setReturnGainLossMatrix( final boolean return_gain_loss ) {
500         _return_gain_loss = return_gain_loss;
501     }
502
503     public void setReturnInternalStates( final boolean return_internal_states ) {
504         _return_internal_states = return_internal_states;
505     }
506
507     private void setTotalGains( final int total_gains ) {
508         _total_gains = total_gains;
509     }
510
511     private void setTotalLosses( final int total_losses ) {
512         _total_losses = total_losses;
513     }
514
515     private void setTotalUnchanged( final int total_unchanged ) {
516         _total_unchanged = total_unchanged;
517     }
518
519     /**
520      * This sets whether to use the first or last state in the sorted
521      * states at the undecided internal nodes.
522      * For randomized choices set randomize to true (and this to false).
523      * 
524      * Note. It might be advisable to set this to false
525      * for BinaryStates if absence at the root is preferred
526      * (given the enum BinaryStates sorts in the following order: 
527      * ABSENT, UNKNOWN, PRESENT).
528      * 
529      * 
530      * @param use_last
531      */
532     public void setUseLast( final boolean use_last ) {
533         if ( use_last && isRandomize() ) {
534             throw new IllegalArgumentException( "attempt to allways use last state (ordered) if more than one choices and randomization at the same time" );
535         }
536         _use_last = use_last;
537     }
538
539     private List<STATE_TYPE> toListSorted( final SortedSet<STATE_TYPE> states ) {
540         final List<STATE_TYPE> l = new ArrayList<STATE_TYPE>( states.size() );
541         for( final STATE_TYPE state : states ) {
542             l.add( state );
543         }
544         return l;
545     }
546 }