
Commit 54b06ae

devashishd12 authored and perimosocordiae committed
Added MLKR algorithm (#28)
* Added MLKR algorithm
* Addressed initial comments, changed to pdist
* Addressed 2nd review
* Made changes in computeyhat
1 parent c5087d7 commit 54b06ae

File tree

3 files changed: +141 −1 lines changed


metric_learn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@
 from .nca import NCA
 from .lfda import LFDA
 from .rca import RCA, RCA_Supervised
+from .mlkr import MLKR

metric_learn/mlkr.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+"""
+Metric Learning for Kernel Regression (MLKR), Weinberger et al.,
+
+MLKR is an algorithm for supervised metric learning, which learns a distance
+function by directly minimising the leave-one-out regression error. This
+algorithm can also be viewed as a supervised variation of PCA and can be used
+for dimensionality reduction and high dimensional data visualization.
+"""
+from __future__ import division
+import numpy as np
+from six.moves import xrange
+from scipy.spatial.distance import pdist, squareform
+
+from .base_metric import BaseMetricLearner
+
+class MLKR(BaseMetricLearner):
+    """Metric Learning for Kernel Regression (MLKR)"""
+    def __init__(self, A0=None, epsilon=0.01, alpha=0.0001):
+        """
+        MLKR initialization
+
+        Parameters
+        ----------
+        A0: Initialization of matrix A. Defaults to the identity matrix.
+        epsilon: Step size for gradient descent.
+        alpha: Stopping criterion for loss function in gradient descent.
+        """
+        self.params = {
+            "A0": A0,
+            "epsilon": epsilon,
+            "alpha": alpha
+        }
+
+    def _process_inputs(self, X, y):
+        self.X = np.array(X, copy=False)
+        y = np.array(y, copy=False)
+        if X.ndim == 1:
+            X = X[:, np.newaxis]
+        if y.ndim == 1:
+            y = y[:, np.newaxis]
+        n, d = X.shape
+        if y.shape[0] != n:
+            raise ValueError('Data and label lengths mismatch: %d != %d'
+                             % (n, y.shape[0]))
+        return y, n, d
+
+    def fit(self, X, y):
+        """
+        Fit MLKR model
+
+        Parameters:
+        ----------
+        X : (n x d) array of samples
+        y : (n) data labels
+
+        Returns:
+        -------
+        self: Instance of self
+        """
+        y, n, d = self._process_inputs(X, y)
+        if self.params['A0'] is None:
+            A = np.identity(d)  # Initialize A as eye matrix
+        else:
+            A = self.params['A0']
+            if A.shape != (d, d):
+                raise ValueError('A0 should be a square matrix of dimension'
+                                 ' %d. %s shape was provided' % (d, A.shape))
+        cost = np.Inf
+        # Gradient descent procedure
+        alpha = self.params['alpha']
+        epsilon = self.params['epsilon']
+        while cost > alpha:
+            K = self._computeK(X, A)
+            yhat = self._computeyhat(y, K)
+            cost = np.sum(np.square(yhat - y))
+            # Compute gradient
+            sum_i = 0
+            for i in xrange(n):
+                sum_j = 0
+                for j in xrange(n):
+                    diffK = (yhat[j] - y[j]) * K[i, j]
+                    x_ij = (X[i, :] - X[j, :])[:, np.newaxis]
+                    x_ijT = x_ij.T
+                    sum_j += diffK * x_ij.dot(x_ijT)
+                sum_i += (yhat[i] - y[i]) * sum_j
+            gradient = 4 * A.dot(sum_i)
+            A -= epsilon * gradient
+        self._transformer = A
+        return self
+
+    @staticmethod
+    def _computeK(X, A):
+        """
+        Internal helper function to compute K matrix.
+
+        Parameters:
+        ----------
+        X: (n x d) array of samples
+        A: (d x d) 'A' matrix
+
+        Returns:
+        -------
+        K: (n x n) K matrix where Kij = exp(-distance(x_i, x_j)) where
+        distance is defined as squared L2 norm of (x_i - x_j)
+        """
+        dist_mat = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
+        return np.exp(squareform(-(dist_mat ** 2)))
+
+    @staticmethod
+    def _computeyhat(y, K):
+        """
+        Internal helper function to compute yhat matrix.
+
+        Parameters:
+        ----------
+        y: (n) data labels
+        K: (n x n) K matrix
+
+        Returns:
+        -------
+        yhat: (n x 1) yhat matrix
+        """
+        K_mod = np.copy(K)
+        np.fill_diagonal(K_mod, 0)
+        numerator = K_mod.dot(y)
+        denominator = np.sum(K_mod, 1)[:, np.newaxis]
+        denominator[denominator == 0] = 2.2204e-16  # eps val in octave
+        yhat = numerator / denominator
+        return yhat
+
+    def transformer(self):
+        return self._transformer
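For readers of the diff, the objective the new fit loop minimises can be written out explicitly. The notation below is editorial, not part of the commit: A is the learned linear map, k_ij the kernel weights computed in _computeK, and yhat_i the leave-one-out prediction computed in _computeyhat.

\[
k_{ij} = \exp\!\bigl(-(x_i - x_j)^\top A^\top A\,(x_i - x_j)\bigr),
\qquad
\hat{y}_i = \frac{\sum_{j \neq i} k_{ij}\, y_j}{\sum_{j \neq i} k_{ij}}
\]
\[
\mathcal{L}(A) = \sum_i \bigl(\hat{y}_i - y_i\bigr)^2,
\qquad
\frac{\partial \mathcal{L}}{\partial A}
  = 4A \sum_i \bigl(\hat{y}_i - y_i\bigr) \sum_j \bigl(\hat{y}_j - y_j\bigr)\, k_{ij}\,(x_i - x_j)(x_i - x_j)^\top
\]

The while loop above is plain gradient descent on L(A) with step size epsilon, stopping once the loss drops below alpha.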

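The gradient above is accumulated with two nested Python loops over all n² sample pairs. Below is a minimal vectorized sketch of the same pairwise sum; it is editorial, not part of the commit, and assumes X is an (n, d) array, y and yhat are (n, 1) arrays, K is the symmetric kernel matrix from _computeK, and A is the current (d, d) matrix.

import numpy as np

def mlkr_gradient(X, y, yhat, K, A):
    """Vectorized equivalent of the nested-loop gradient in MLKR.fit (editorial sketch).

    Uses the identity, valid for a symmetric weight matrix W:
        sum_ij W_ij (x_i - x_j)(x_i - x_j)^T = 2 X^T (diag(W 1) - W) X
    """
    d = (yhat - y).ravel()                   # leave-one-out residuals, shape (n,)
    W = K * np.outer(d, d)                   # W[i, j] = (yhat_i - y_i)(yhat_j - y_j) K_ij
    lap = np.diag(W.sum(axis=1)) - W         # Laplacian-style matrix built from W
    pairwise_sum = 2 * X.T.dot(lap).dot(X)   # equals sum_i from the double loop
    return 4 * A.dot(pairwise_sum)           # same scaling as gradient = 4 * A.dot(sum_i)

This trades the n² explicit (d x d) outer products for a few dense matrix multiplications, which is usually much faster for moderate n.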
test/metric_learn_test.py

Lines changed: 8 additions & 1 deletion
@@ -7,7 +7,7 @@
 
 from metric_learn import (
     LMNN, NCA, LFDA, Covariance,
-    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
+    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MLKR)
 # Import this specially for testing.
 from metric_learn.lmnn import python_LMNN
 
@@ -113,6 +113,13 @@ def test_iris(self):
         csep = class_separation(rca.transform(), self.iris_labels)
         self.assertLess(csep, 0.25)
 
+class TestMLKR(MetricTestCase):
+    def test_iris(self):
+        mlkr = MLKR(epsilon=10, alpha=10)  # for faster testing
+        mlkr.fit(self.iris_points, self.iris_labels)
+        csep = class_separation(mlkr.transform(), self.iris_labels)
+        self.assertLess(csep, 0.25)
+
 
 if __name__ == '__main__':
     unittest.main()
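A minimal end-to-end sketch of the API this commit introduces, mirroring the new test. The iris loading and the loose hyper-parameters are illustrative assumptions, not part of the commit:

from sklearn.datasets import load_iris
from metric_learn import MLKR

iris = load_iris()
X, y = iris.data, iris.target

# Loose hyper-parameters, as in the new test, so the demo finishes quickly.
mlkr = MLKR(epsilon=10, alpha=10)
mlkr.fit(X, y)

A = mlkr.transformer()        # learned (4 x 4) linear map
X_mlkr = mlkr.transform()     # training data mapped through A, as in test_iris
print(A.shape, X_mlkr.shape)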
