JAL-3855 auto select clusters and fix up colouring/propagation on calculation & tree...
[jalview.git] / src / jalview / datamodel / GroupSet.java
1 package jalview.datamodel;
2
3 import java.awt.Color;
4 import java.util.ArrayList;
5 import java.util.Arrays;
6 import java.util.BitSet;
7 import java.util.HashMap;
8 import java.util.List;
9
10 import jalview.analysis.AverageDistanceEngine;
11 import jalview.bin.Console;
12
13 public class GroupSet implements GroupSetI
14 {
15   List<BitSet> groups = Arrays.asList();
16
17   public GroupSet(GroupSet grps)
18   {
19     abs = grps.abs;
20     colorMap = new HashMap<BitSet, Color>(grps.colorMap);
21     groups = new ArrayList<BitSet>(grps.groups);
22     newick = grps.newick;
23     thresh = grps.thresh;
24     treeType = grps.treeType;
25   }
26
27   public GroupSet()
28   {
29     // TODO Auto-generated constructor stub
30   }
31
32   public GroupSet(boolean abs2, float thresh2, List<BitSet> groups2,
33           String treeType2, String newick2)
34   {
35     abs = abs2;
36     thresh = thresh2;
37     groups = groups2;
38     treeType = treeType2;
39     newick = newick2;
40   }
41
42   @Override
43   public boolean hasGroups()
44   {
45     return groups != null;
46   }
47
48   String newick = null;
49
50   @Override
51   public String getNewick()
52   {
53     return newick;
54   }
55
56   @Override
57   public boolean hasTree()
58   {
59     return newick != null && newick.length() > 0;
60   }
61
62   boolean abs = false;
63
64   double thresh = 0;
65
66   String treeType = null;
67
68   @Override
69   public void updateGroups(List<BitSet> colGroups)
70   {
71     if (colGroups != null)
72     {
73       groups = colGroups;
74     }
75   }
76
77   @Override
78   public BitSet getGroupsFor(int column)
79   {
80     if (groups != null)
81     {
82       for (BitSet gp : groups)
83       {
84         if (gp.get(column))
85         {
86           return gp;
87         }
88       }
89     }
90     // return singleton set;
91     BitSet bs = new BitSet();
92     bs.set(column);
93     return bs;
94   }
95
96   HashMap<BitSet, Color> colorMap = new HashMap<>();
97
98   @Override
99   public Color getColourForGroup(BitSet bs)
100   {
101     if (bs == null)
102     {
103       return Color.white;
104     }
105     Color groupCol = colorMap.get(bs);
106     if (groupCol == null)
107     {
108       return Color.white;
109     }
110     return groupCol;
111   }
112
113   @Override
114   public void setColorForGroup(BitSet bs, Color color)
115   {
116     colorMap.put(bs, color);
117   }
118
119   @Override
120   public void restoreGroups(List<BitSet> newgroups, String treeMethod,
121           String tree, double thresh2)
122   {
123     treeType = treeMethod;
124     groups = newgroups;
125     thresh = thresh2;
126     newick = tree;
127
128   }
129
130   @Override
131   public boolean hasCutHeight()
132   {
133     return groups != null && thresh != 0;
134   }
135
136   @Override
137   public double getCutHeight()
138   {
139     return thresh;
140   }
141
142   @Override
143   public String getTreeMethod()
144   {
145     return treeType;
146   }
147
148   public static GroupSet makeGroups(ContactMatrixI matrix, boolean autoCut)
149   {
150     return makeGroups(matrix, autoCut, 0, autoCut);
151   }
152   public static GroupSet makeGroups(ContactMatrixI matrix, boolean auto, float thresh,
153           boolean abs)
154   {
155     AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null,
156             matrix, true);
157     double height = clusterer.findHeight(clusterer.getTopNode());
158     Console.debug("Column tree height: " + height);
159     String newick = new jalview.io.NewickFile(clusterer.getTopNode(), false,
160             true).print();
161     String treeType = "UPGMA";
162     Console.trace("Newick string\n" + newick);
163
164     List<BinaryNode> nodegroups;
165     float cut = -1f;
166     if (auto)
167     {
168       double rootw = 0;
169       int p = 2;
170       BinaryNode bn = clusterer.getTopNode();
171       while (p-- > 0 & bn.left() != null)
172       {
173         if (bn.left() != null)
174         {
175           bn = bn.left();
176         }
177         if (bn.left() != null)
178         {
179           rootw = bn.height;
180         }
181       }
182       thresh = Math.max((float) (rootw / height) - 0.01f, 0);
183       cut = thresh;
184       nodegroups = clusterer.groupNodes(thresh);
185     }
186     else
187     {
188       if (abs ? (height > thresh) : (0 < thresh && thresh < 1))
189       {
190         cut = abs ? thresh : (float) (thresh * height);
191         Console.debug("Threshold " + cut + " for height=" + height);
192         nodegroups = clusterer.groupNodes(cut);
193       }
194       else
195       {
196         nodegroups = new ArrayList<BinaryNode>();
197         nodegroups.add(clusterer.getTopNode());
198       }
199     }
200     
201     List<BitSet> groups = new ArrayList<>();
202     for (BinaryNode root : nodegroups)
203     {
204       BitSet gpset = new BitSet();
205       for (BinaryNode leaf : clusterer.findLeaves(root))
206       {
207         gpset.set((Integer) leaf.element());
208       }
209       groups.add(gpset);
210     }
211     GroupSet grps = new GroupSet(abs, (cut == -1f) ? thresh : cut, groups,
212             treeType, newick);
213     return grps;
214   }
215
216   @Override
217   public List<BitSet> getGroups()
218   {
219     return groups;
220   }
221 }