forked from tensorly/tensorly
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c042fad
commit 8f95f28
Showing
2 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
from numpy.testing import assert_ | ||
from ..tucker_regression import TuckerRegressor | ||
from ...base import tensor_to_vec, partial_tensor_to_vec | ||
from ...metrics.regression import RMSE | ||
|
||
|
||
def test_TuckerRegressor(): | ||
"""Test for TuckerRegressor""" | ||
|
||
# Parameter of the experiment | ||
image_height = 10 | ||
image_width = 10 | ||
n_channels = 3 | ||
ranks = [5, 5, 2] | ||
tol = 10e-3 | ||
|
||
# Generate random samples | ||
X = np.random.normal(size=(1200, image_height, image_width, n_channels), loc=0, scale=1) | ||
regression_weights = np.zeros((image_height, image_width, n_channels)) | ||
regression_weights[2:-2, 2:-2, 0] = 1 | ||
regression_weights[2:-2, 2:-2, 1] = 2 | ||
regression_weights[2:-2, 2:-2, 2] = -1 | ||
|
||
y = partial_tensor_to_vec(X, skip_begin=1).dot(tensor_to_vec(regression_weights)) | ||
X_train = X[:1000, ...] | ||
X_test = X[1000:, ...] | ||
y_train = y[:1000] | ||
y_test = y[1000:] | ||
|
||
estimator = TuckerRegressor(weight_ranks=ranks, tol=10e-8, reg_W=1, n_iter_max=200, verbose=True) | ||
estimator.fit(X_train, y_train) | ||
y_pred = estimator.predict(X_test) | ||
error = RMSE(y_test, y_pred) | ||
assert_(error <= tol, msg='Tucker Regression : RMSE={} > {}'.format(error, tol)) | ||
|
||
params = estimator.get_params() | ||
assert_(params['weight_ranks'] == [5, 5, 2], msg='get_params did not return the correct parameters') | ||
params['weight_ranks'] = [5, 5, 5] | ||
estimator.set_params(**params) | ||
assert_(estimator.weight_ranks == [5, 5, 5], msg='set_params did not correctly set the given parameters') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from scipy.linalg import solve | ||
import numpy as np | ||
from ..base import unfold, vec_to_tensor | ||
from ..base import partial_tensor_to_vec, partial_unfold | ||
from ..tenalg import norm, kronecker | ||
from ..tucker import tucker_to_tensor, tucker_to_vec | ||
from ..utils import check_random_state | ||
|
||
# Author: Jean Kossaifi <[email protected]> | ||
|
||
|
||
class TuckerRegressor(): | ||
def __init__(self, weight_ranks, tol=10e-7, reg_W=1, n_iter_max=100, random_state=None, verbose=1): | ||
"""Tucker tensor regression | ||
Learns a low rank Tucker weight for the regression | ||
Parameters | ||
---------- | ||
weight_ranks : int list | ||
dimension of each mode of the core Tucker weight | ||
tol : float | ||
convergence value | ||
reg_W : int, optional, default is 1 | ||
regularisation on the weights | ||
n_iter_max : int, optional, default is 100 | ||
maximum number of iteration | ||
random_state : None, int or RandomState, optional, default is None | ||
verbose : int, default is 1 | ||
level of verbosity | ||
""" | ||
self.weight_ranks = weight_ranks | ||
self.tol = tol | ||
self.reg_W = reg_W | ||
self.n_iter_max = n_iter_max | ||
self.random_state = random_state | ||
self.verbose = verbose | ||
|
||
def get_params(self): | ||
"""Returns a dictionary of parameters | ||
""" | ||
params = ['weight_ranks', 'tol', 'reg_W', 'n_iter_max', 'random_state', 'verbose'] | ||
return {param_name: getattr(self, param_name) for param_name in params} | ||
|
||
def set_params(self, **parameters): | ||
"""Sets the value of the provided parameters""" | ||
for parameter, value in parameters.items(): | ||
setattr(self, parameter, value) | ||
return self | ||
|
||
def fit(self, X, y): | ||
"""Fits the model to the data (X, y) | ||
Parameters | ||
---------- | ||
X : ndarray of shape (n_samples, N1, ..., NS) | ||
tensor data | ||
y : array of shape (n_samples) | ||
labels associated with each sample | ||
Returns | ||
------- | ||
self | ||
""" | ||
rng = check_random_state(self.random_state) | ||
|
||
# Initialise randomly the weights | ||
G = rng.randn(*self.weight_ranks) | ||
W = [] | ||
for i in range(1, X.ndim): # First dimension of X = number of samples | ||
W.append(np.random.randn(X.shape[i], G.shape[i - 1])) | ||
|
||
# Norm of the weight tensor at each iteration | ||
norm_W = [] | ||
|
||
for iteration in range(self.n_iter_max): | ||
|
||
# Optimise modes of W | ||
for i in range(len(W)): | ||
phi = partial_tensor_to_vec( | ||
np.dot(partial_unfold(X, i), | ||
np.dot(kronecker(W, skip_matrix=i), | ||
unfold(G, i).T))) | ||
# Regress phi on y: we could call a package here, e.g. scikit-learn | ||
inv_term = np.dot(phi.T, phi) + self.reg_W * np.eye(phi.shape[1]) | ||
W_i = vec_to_tensor(solve(inv_term, phi.T.dot(y)), | ||
(X.shape[i + 1], G.shape[i])) | ||
W[i] = W_i | ||
|
||
phi = partial_tensor_to_vec(X).dot(kronecker(W)) | ||
G = vec_to_tensor(solve(phi.T.dot(phi) + self.reg_W * np.eye(phi.shape[1]), phi.T.dot(y)), G.shape) | ||
|
||
weight_tensor_ = tucker_to_tensor(G, W) | ||
norm_W.append(norm(weight_tensor_, 2)) | ||
|
||
# Convergence check | ||
if iteration > 1: | ||
weight_evolution = abs(norm_W[-1] - norm_W[-2]) / norm_W[-1] | ||
|
||
if (weight_evolution <= self.tol): | ||
if self.verbose: | ||
print('\nConverged in {} iterations'.format(iteration)) | ||
break | ||
|
||
self.weight_tensor_ = weight_tensor_ | ||
self.tucker_W_ = (G, W) | ||
self.vec_W_ = tucker_to_vec(G, W) | ||
self.n_iterations_ = iteration + 1 | ||
self.norm_W_ = norm_W | ||
|
||
return self | ||
|
||
def predict(self, X): | ||
"""Returns the predicted labels for a new data tensor | ||
Parameters | ||
---------- | ||
X : ndarray | ||
tensor data of shape (n_samples, N1, ..., NS) | ||
""" | ||
return np.dot(partial_tensor_to_vec(X), self.vec_W_) |