Skip to content

Commit 1e78faa

Browse files
committed
Addressed initial comments, changed to pdist
1 parent 544f6d2 commit 1e78faa

File tree

3 files changed

+35
-27
lines changed

3 files changed

+35
-27
lines changed

metric_learn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from .nca import NCA
1010
from .lfda import LFDA
1111
from .rca import RCA, RCA_Supervised
12+
from .mlkr import MLKR

metric_learn/mlkr.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,39 @@
99
from __future__ import division
1010
import numpy as np
1111
from six.moves import xrange
12+
from scipy.spatial.distance import pdist, squareform
1213

1314
from .base_metric import BaseMetricLearner
1415

1516
class MLKR(BaseMetricLearner):
1617
"""Metric Learning for Kernel Regression (MLKR)"""
17-
def __init__(self, A=None, epsilon=0.01):
18+
def __init__(self, A0=None, epsilon=0.01, alpha=0.0001):
1819
"""
1920
MLKR initialization
2021
2122
Parameters
2223
----------
23-
A: Initialization of matrix A. Defaults to the identity matrix.
24-
epsilon: Step size for gradient descent
24+
A0: Initialization of matrix A. Defaults to the identity matrix.
25+
epsilon: Step size for gradient descent.
26+
alpha: Stopping criterion for loss function in gradient descent.
2527
"""
2628
self.params = {
27-
"A": A,
28-
"epsilon": epsilon
29+
"A0": A0,
30+
"epsilon": epsilon,
31+
"alpha": alpha
2932
}
3033

3134
def _process_inputs(self, X, y):
35+
self.X = np.array(X, copy=False)
36+
y = np.array(y, copy=False)
3237
if X.ndim == 1:
3338
X = X[:, np.newaxis]
3439
if y.ndim == 1:
35-
y == y[:, np.newaxis]
36-
self.X = X
40+
y = y[:, np.newaxis]
3741
n, d = X.shape
38-
assert y.shape[0] == n
42+
if y.shape[0] != n:
43+
raise ValueError('Data and label lengths mismatch: %d != %d'
44+
% (n, y.shape[0]))
3945
return y, n, d
4046

4147
def fit(self, X, y):
@@ -52,54 +58,48 @@ def fit(self, X, y):
5258
self: Instance of self
5359
"""
5460
y, n, d = self._process_inputs(X, y)
55-
alpha = 0.0001 # Stopping criterion
56-
if self.params['A'] is None:
61+
if self.params['A0'] is None:
5762
A = np.identity(d) # Initialize A as eye matrix
5863
else:
59-
A = self.params['A']
64+
A = self.params['A0']
6065
assert A.shape == (d, d)
6166
cost = np.Inf
6267
# Gradient descent procedure
63-
while cost > alpha:
64-
K = self._computeK(X, A, n)
68+
while cost > self.params['alpha']:
69+
K = self._computeK(X, A)
6570
yhat = self._computeyhat(y, K)
6671
sum_i = 0
6772
for i in xrange(n):
6873
sum_j = 0
6974
for j in xrange(n):
70-
sum_j += (yhat[j] - y[j]) * K[i][j] * \
71-
(X[i, :] - X[j, :])[:, np.newaxis].dot \
72-
((X[i, :] - X[j, :])[:, np.newaxis].T)
75+
diffK = (yhat[j] - y[j]) * K[i, j]
76+
x_ij = (X[i, :] - X[j, :])[:, np.newaxis]
77+
x_ijT = x_ij.T
78+
sum_j += diffK * x_ij.dot(x_ijT)
7379
sum_i += (yhat[i] - y[i]) * sum_j
7480
gradient = 4 * A.dot(sum_i)
7581
A -= self.params['epsilon'] * gradient
7682
cost = np.sum(np.square(yhat - y))
7783
self._transformer = A
7884
return self
7985

80-
def _computeK(self, X, A, n):
86+
def _computeK(self, X, A):
8187
"""
8288
Internal helper function to compute K matrix.
8389
8490
Parameters:
8591
----------
8692
X: (n x d) array of samples
8793
A: (d x d) 'A' matrix
88-
n: number of rows in X
8994
9095
Returns:
9196
-------
9297
K: (n x n) K matrix where Kij = exp(-distance(x_i, x_j)) where
9398
distance is defined as squared L2 norm of (x_i - x_j)
9499
"""
95-
dist_mat = np.zeros(shape=(n, n))
96-
for i in xrange(n):
97-
for j in xrange(n):
98-
if i == j:
99-
dist = 0
100-
else:
101-
dist = np.sum(np.square(A.dot((X[i, :] - X[j, :]))))
102-
dist_mat[i, j] = dist
100+
dist_mat = pdist(X, metric='mahalanobis', VI=A.T.dot(A))
101+
dist_mat = np.square(dist_mat)
102+
dist_mat = squareform(dist_mat)
103103
return np.exp(-dist_mat)
104104

105105
def _computeyhat(self, y, K):

test/metric_learn_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from metric_learn import (
99
LMNN, NCA, LFDA, Covariance,
10-
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
10+
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MLKR)
1111
# Import this specially for testing.
1212
from metric_learn.lmnn import python_LMNN
1313

@@ -113,6 +113,13 @@ def test_iris(self):
113113
csep = class_separation(rca.transform(), self.iris_labels)
114114
self.assertLess(csep, 0.25)
115115

116+
class TestMLKR(MetricTestCase):
117+
def test_iris(self):
118+
mlkr = MLKR(epsilon=10, alpha=10) # for faster testing
119+
mlkr.fit(self.iris_points, self.iris_labels)
120+
csep = class_separation(mlkr.transform(), self.iris_labels)
121+
self.assertLess(csep, 0.25)
122+
116123

117124
if __name__ == '__main__':
118125
unittest.main()

0 commit comments

Comments
 (0)