improving GSDI, under construction...
[jalview.git] / forester / java / src / org / forester / sdi / GSDI.java
index 780aba9..3fea313 100644 (file)
@@ -63,15 +63,14 @@ import org.forester.util.ForesterUtil;
  */
 public final class GSDI extends SDI {
 
-    private final HashMap<PhylogenyNode, Integer> _transversal_counts;
-    private final boolean                         _most_parsimonious_duplication_model;
-    private final boolean                         _strip_gene_tree;
-    private final boolean                         _strip_species_tree;
-    private int                                   _speciation_or_duplication_events_sum;
-    private int                                   _speciations_sum;
-    private final List<PhylogenyNode>             _stripped_gene_tree_nodes;
-    private final List<PhylogenyNode>             _stripped_species_tree_nodes;
-    private final Set<PhylogenyNode>              _mapped_species_tree_nodes;
+    private final boolean             _most_parsimonious_duplication_model;
+    private final boolean             _strip_gene_tree;
+    private final boolean             _strip_species_tree;
+    private int                       _speciation_or_duplication_events_sum;
+    private int                       _speciations_sum;
+    private final List<PhylogenyNode> _stripped_gene_tree_nodes;
+    private final List<PhylogenyNode> _stripped_species_tree_nodes;
+    private final Set<PhylogenyNode>  _mapped_species_tree_nodes;
 
     /**
      * Constructor which sets the gene tree and the species tree to be compared.
@@ -109,7 +108,6 @@ public final class GSDI extends SDI {
         _speciation_or_duplication_events_sum = 0;
         _speciations_sum = 0;
         _most_parsimonious_duplication_model = most_parsimonious_duplication_model;
-        _transversal_counts = new HashMap<PhylogenyNode, Integer>();
         _duplications_sum = 0;
         _strip_gene_tree = strip_gene_tree;
         _strip_species_tree = strip_species_tree;
@@ -118,7 +116,8 @@ public final class GSDI extends SDI {
         _mapped_species_tree_nodes = new HashSet<PhylogenyNode>();
         getSpeciesTree().preOrderReId();
         linkNodesOfG();
-        geneTreePostOrderTraversal( getGeneTree().getRoot() );
+        //geneTreePostOrderTraversal( getGeneTree().getRoot(), null );
+        geneTreePostOrderTraversal();
     }
 
     GSDI( final Phylogeny gene_tree, final Phylogeny species_tree, final boolean most_parsimonious_duplication_model )
@@ -126,24 +125,6 @@ public final class GSDI extends SDI {
         this( gene_tree, species_tree, most_parsimonious_duplication_model, false, false );
     }
 
-    private final Event createDuplicationEvent() {
-        final Event event = Event.createSingleDuplicationEvent();
-        ++_duplications_sum;
-        return event;
-    }
-
-    private final Event createSingleSpeciationOrDuplicationEvent() {
-        final Event event = Event.createSingleSpeciationOrDuplicationEvent();
-        ++_speciation_or_duplication_events_sum;
-        return event;
-    }
-
-    private final Event createSpeciationEvent() {
-        final Event event = Event.createSingleSpeciationEvent();
-        ++_speciations_sum;
-        return event;
-    }
-
     // s is the node on the species tree g maps to.
     private final void determineEvent( final PhylogenyNode s, final PhylogenyNode g ) {
         Event event = null;
@@ -155,74 +136,93 @@ public final class GSDI extends SDI {
                 ++sum_g_childs_mapping_to_s;
             }
         }
-        // Determine the sum of traversals.
-        int traversals_sum = 0;
-        int max_traversals = 0;
-        PhylogenyNode max_traversals_node = null;
-        if ( !s.isExternal() ) {
-            for( int i = 0; i < s.getNumberOfDescendants(); ++i ) {
-                final PhylogenyNode current_node = s.getChildNode( i );
-                final int traversals = getTraversalCount( current_node );
-                traversals_sum += traversals;
-                if ( traversals > max_traversals ) {
-                    max_traversals = traversals;
-                    max_traversals_node = current_node;
-                }
-            }
-        }
-        // System.out.println( " sum=" + traversals_sum );
-        // System.out.println( " max=" + max_traversals );
-        // System.out.println( " m=" + sum_g_childs_mapping_to_s );
-        if ( sum_g_childs_mapping_to_s > 0 ) {
-            if ( traversals_sum == 2 ) {
+        if ( g.getLink().getNumberOfDescendants() == 2 ) {
+            if ( sum_g_childs_mapping_to_s > 0 ) {
                 event = createDuplicationEvent();
-                System.out.print( g.toString() );
-                System.out.println( " : ==2" );
-                //  _transversal_counts.clear();
             }
-            else if ( traversals_sum > 2 ) {
-                if ( max_traversals <= 1 ) {
-                    if ( _most_parsimonious_duplication_model ) {
-                        event = createSpeciationEvent();
-                    }
-                    else {
-                        event = createSingleSpeciationOrDuplicationEvent();
+            else {
+                event = createSpeciationEvent();
+            }
+        }
+        else {
+            if ( sum_g_childs_mapping_to_s > 0 ) {
+                boolean multiple = false;
+                Set<PhylogenyNode> set = new HashSet<PhylogenyNode>();
+                for( PhylogenyNode n : g.getChildNode1().getLink().getAllExternalDescendants() ) {
+                    set.add( n );
+                }
+                for( PhylogenyNode n : g.getChildNode2().getLink().getAllExternalDescendants() ) {
+                    if ( set.contains( n ) ) {
+                        multiple = true;
+                        break;
                     }
+                    // else {
+                    //     set.add( n );
+                    // }
                 }
-                else {
+                if ( multiple ) {
                     event = createDuplicationEvent();
-                    //System.out.println( g.toString() );
-                    _transversal_counts.put( max_traversals_node, 1 );
-                    //  _transversal_counts.clear();
+                }
+                else {
+                    event = createSingleSpeciationOrDuplicationEvent();
                 }
             }
             else {
-                event = createDuplicationEvent();
-                //   _transversal_counts.clear();
+                event = createSpeciationEvent();
             }
-            normalizeTcounts( s );
-        }
-        else {
-            event = createSpeciationEvent();
         }
         g.getNodeData().setEvent( event );
     }
 
-    private void normalizeTcounts( final PhylogenyNode s ) {
-        int min_traversals = Integer.MAX_VALUE;
-        for( int i = 0; i < s.getNumberOfDescendants(); ++i ) {
-            final PhylogenyNode current_node = s.getChildNode( i );
-            final int traversals = getTraversalCount( current_node );
-            if ( traversals < min_traversals ) {
-                min_traversals = traversals;
-            }
-        }
-        for( int i = 0; i < s.getNumberOfDescendants(); ++i ) {
-            final PhylogenyNode current_node = s.getChildNode( i );
-            _transversal_counts.put( current_node, getTraversalCount( current_node ) - min_traversals );
-        }
-    }
-
+    //    private final void determineEvent2( final PhylogenyNode s, final PhylogenyNode g ) {
+    //        Event event = null;
+    //        // Determine how many children map to same node as parent.
+    //        int sum_g_childs_mapping_to_s = 0;
+    //        for( int i = 0; i < g.getNumberOfDescendants(); ++i ) {
+    //            final PhylogenyNode c = g.getChildNode( i );
+    //            if ( c.getLink() == s ) {
+    //                ++sum_g_childs_mapping_to_s;
+    //            }
+    //        }
+    //        // Determine the sum of traversals.
+    //        int traversals_sum = 0;
+    //        int max_traversals = 0;
+    //        PhylogenyNode max_traversals_node = null;
+    //        if ( !s.isExternal() ) {
+    //            for( int i = 0; i < s.getNumberOfDescendants(); ++i ) {
+    //                final PhylogenyNode current_node = s.getChildNode( i );
+    //                final int traversals = getTraversalCount( current_node );
+    //                traversals_sum += traversals;
+    //                if ( traversals > max_traversals ) {
+    //                    max_traversals = traversals;
+    //                    max_traversals_node = current_node;
+    //                }
+    //            }
+    //        }
+    //        // System.out.println( " sum=" + traversals_sum );
+    //        // System.out.println( " max=" + max_traversals );
+    //        // System.out.println( " m=" + sum_g_childs_mapping_to_s );
+    //        if ( s.getNumberOfDescendants() == 2 ) {
+    //            if ( sum_g_childs_mapping_to_s == 0 ) {
+    //                event = createSpeciationEvent();
+    //            }
+    //            else {
+    //                event = createDuplicationEvent();
+    //            }
+    //        }
+    //        else {
+    //            if ( sum_g_childs_mapping_to_s == 2 ) {
+    //                event = createDuplicationEvent();
+    //            }
+    //            else if ( sum_g_childs_mapping_to_s == 1 ) {
+    //                event = createSingleSpeciationOrDuplicationEvent();
+    //            }
+    //            else {
+    //                event = createSpeciationEvent();
+    //            }
+    //        }
+    //        g.getNodeData().setEvent( event );
+    //    }
     /**
      * Traverses the subtree of PhylogenyNode g in postorder, calculating the
      * mapping function M, and determines which nodes represent speciation
@@ -235,71 +235,59 @@ public final class GSDI extends SDI {
      * @param g
      *            starting node of a gene tree - normally the root
      */
-    final void geneTreePostOrderTraversal( final PhylogenyNode g ) {
-        if ( !g.isExternal() ) {
-            boolean all_ext = true;
-            for( int i = 0; i < g.getNumberOfDescendants(); ++i ) {
-                if ( g.getChildNode( i ).isInternal() ) {
-                    all_ext = false;
-                    break;
+    final void geneTreePostOrderTraversal() {
+        for( PhylogenyNodeIterator it = getGeneTree().iteratorPostorder(); it.hasNext(); ) {
+            PhylogenyNode g = it.next();
+            if ( !g.isExternal() ) {
+                final PhylogenyNode[] linked_nodes = new PhylogenyNode[ g.getNumberOfDescendants() ];
+                for( int i = 0; i < linked_nodes.length; ++i ) {
+                    if ( g.getChildNode( i ).getLink() == null ) {
+                        System.out.println( "link is null for " + g.getChildNode( i ) );
+                        System.exit( -1 );
+                    }
+                    linked_nodes[ i ] = g.getChildNode( i ).getLink();
                 }
-            }
-            if ( all_ext ) {
-                //_transversal_counts.clear();
-            }
-            for( int i = 0; i < g.getNumberOfDescendants(); ++i ) {
-                geneTreePostOrderTraversal( g.getChildNode( i ) );
-            }
-            final PhylogenyNode[] linked_nodes = new PhylogenyNode[ g.getNumberOfDescendants() ];
-            for( int i = 0; i < linked_nodes.length; ++i ) {
-                if ( g.getChildNode( i ).getLink() == null ) {
-                    System.out.println( "link is null for " + g.getChildNode( i ) );
-                    System.exit( -1 );
+                final int[] min_max = obtainMinMaxIdIndices( linked_nodes );
+                int min_i = min_max[ 0 ];
+                int max_i = min_max[ 1 ];
+                while ( linked_nodes[ min_i ] != linked_nodes[ max_i ] ) {
+                    linked_nodes[ max_i ] = linked_nodes[ max_i ].getParent();
+                    final int[] min_max_ = obtainMinMaxIdIndices( linked_nodes );
+                    min_i = min_max_[ 0 ];
+                    max_i = min_max_[ 1 ];
                 }
-                linked_nodes[ i ] = g.getChildNode( i ).getLink();
-            }
-            final int[] min_max = obtainMinMaxIdIndices( linked_nodes );
-            int min_i = min_max[ 0 ];
-            int max_i = min_max[ 1 ];
-            // initTransversalCounts();
-            while ( linked_nodes[ min_i ] != linked_nodes[ max_i ] ) {
-                increaseTraversalCount( linked_nodes[ max_i ] );
-                linked_nodes[ max_i ] = linked_nodes[ max_i ].getParent();
-                final int[] min_max_ = obtainMinMaxIdIndices( linked_nodes );
-                min_i = min_max_[ 0 ];
-                max_i = min_max_[ 1 ];
+                final PhylogenyNode s = linked_nodes[ max_i ];
+                g.setLink( s );
+                // Determines whether dup. or spec.
+                determineEvent( s, g );
             }
-            final PhylogenyNode s = linked_nodes[ max_i ];
-            g.setLink( s );
-            // Determines whether dup. or spec.
-            determineEvent( s, g );
         }
     }
 
-    public final int getSpeciationOrDuplicationEventsSum() {
-        return _speciation_or_duplication_events_sum;
+    private final Event createDuplicationEvent() {
+        final Event event = Event.createSingleDuplicationEvent();
+        ++_duplications_sum;
+        return event;
     }
 
-    public final int getSpeciationsSum() {
-        return _speciations_sum;
+    private final Event createSingleSpeciationOrDuplicationEvent() {
+        final Event event = Event.createSingleSpeciationOrDuplicationEvent();
+        ++_speciation_or_duplication_events_sum;
+        return event;
     }
 
-    private final int getTraversalCount( final PhylogenyNode node ) {
-        if ( _transversal_counts.containsKey( node ) ) {
-            return _transversal_counts.get( node );
-        }
-        return 0;
+    private final Event createSpeciationEvent() {
+        final Event event = Event.createSingleSpeciationEvent();
+        ++_speciations_sum;
+        return event;
     }
 
-    private final void increaseTraversalCount( final PhylogenyNode node ) {
-        if ( _transversal_counts.containsKey( node ) ) {
-            _transversal_counts.put( node, _transversal_counts.get( node ) + 1 );
-        }
-        else {
-            _transversal_counts.put( node, 1 );
-        }
-        // System.out.println( "count for node " + node.getID() + " is now "
-        // + getTraversalCount( node ) );
+    public final int getSpeciationOrDuplicationEventsSum() {
+        return _speciation_or_duplication_events_sum;
+    }
+
+    public final int getSpeciationsSum() {
+        return _speciations_sum;
     }
 
     /**