Fix core WST file
[jabaws.git] / binaries / src / disembl / biopython-1.50 / Bio / kNN.py
1 #!/usr/bin/env python
2
3 """
4 This module provides code for doing k-nearest-neighbors classification.
5
6 k Nearest Neighbors is a supervised learning algorithm that classifies
7 a new observation based the classes in its surrounding neighborhood.
8
9 Glossary:
10 distance   The distance between two points in the feature space.
11 weight     The importance given to each point for classification. 
12
13
14 Classes:
15 kNN           Holds information for a nearest neighbors classifier.
16
17
18 Functions:
19 train        Train a new kNN classifier.
20 calculate    Calculate the probabilities of each class, given an observation.
21 classify     Classify an observation into a class.
22
23     Weighting Functions:
24 equal_weight    Every example is given a weight of 1.
25
26 """
27
28 #TODO - Remove this work around once we drop python 2.3 support
29 try:
30     set = set
31 except NameError:
32     from sets import Set as set
33
34 import numpy
35
36 class kNN:
37     """Holds information necessary to do nearest neighbors classification.
38
39     Members:
40     classes  Set of the possible classes.
41     xs       List of the neighbors.
42     ys       List of the classes that the neighbors belong to.
43     k        Number of neighbors to look at.
44
45     """
46     def __init__(self):
47         """kNN()"""
48         self.classes = set()
49         self.xs = []
50         self.ys = []
51         self.k = None
52
53 def equal_weight(x, y):
54     """equal_weight(x, y) -> 1"""
55     # everything gets 1 vote
56     return 1
57
58 def train(xs, ys, k, typecode=None):
59     """train(xs, ys, k) -> kNN
60     
61     Train a k nearest neighbors classifier on a training set.  xs is a
62     list of observations and ys is a list of the class assignments.
63     Thus, xs and ys should contain the same number of elements.  k is
64     the number of neighbors that should be examined when doing the
65     classification.
66     
67     """
68     knn = kNN()
69     knn.classes = set(ys)
70     knn.xs = numpy.asarray(xs, typecode)
71     knn.ys = ys
72     knn.k = k
73     return knn
74
75 def calculate(knn, x, weight_fn=equal_weight, distance_fn=None):
76     """calculate(knn, x[, weight_fn][, distance_fn]) -> weight dict
77
78     Calculate the probability for each class.  knn is a kNN object.  x
79     is the observed data.  weight_fn is an optional function that
80     takes x and a training example, and returns a weight.  distance_fn
81     is an optional function that takes two points and returns the
82     distance between them.  If distance_fn is None (the default), the
83     Euclidean distance is used.  Returns a dictionary of the class to
84     the weight given to the class.
85     
86     """
87     x = numpy.asarray(x)
88
89     order = []  # list of (distance, index)
90     if distance_fn:
91         for i in range(len(knn.xs)):
92             dist = distance_fn(x, knn.xs[i])
93             order.append((dist, i))
94     else:
95         # Default: Use a fast implementation of the Euclidean distance
96         temp = numpy.zeros(len(x))
97         # Predefining temp allows reuse of this array, making this
98         # function about twice as fast.
99         for i in range(len(knn.xs)):
100             temp[:] = x - knn.xs[i]
101             dist = numpy.sqrt(numpy.dot(temp,temp))
102             order.append((dist, i))
103     order.sort()
104
105     # first 'k' are the ones I want.
106     weights = {}  # class -> number of votes
107     for k in knn.classes:
108         weights[k] = 0.0
109     for dist, i in order[:knn.k]:
110         klass = knn.ys[i]
111         weights[klass] = weights[klass] + weight_fn(x, knn.xs[i])
112
113     return weights
114
115 def classify(knn, x, weight_fn=equal_weight, distance_fn=None):
116     """classify(knn, x[, weight_fn][, distance_fn]) -> class
117
118     Classify an observation into a class.  If not specified, weight_fn will
119     give all neighbors equal weight.  distance_fn is an optional function
120     that takes two points and returns the distance between them.  If
121     distance_fn is None (the default), the Euclidean distance is used.
122     """
123     weights = calculate(
124         knn, x, weight_fn=weight_fn, distance_fn=distance_fn)
125
126     most_class = None
127     most_weight = None
128     for klass, weight in weights.items():
129         if most_class is None or weight > most_weight:
130             most_class = klass
131             most_weight = weight
132     return most_class