forked from tensorly/tensorly
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8f95f28
commit cf0d771
Showing
11 changed files
with
309 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
BSD 3-Clause License | ||
|
||
Copyright (c) 2016 The tensorly Developers. | ||
All rights reserved. | ||
|
||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
a. Redistributions of source code must retain the above copyright notice, | ||
this list of conditions and the following disclaimer. | ||
b. Redistributions in binary form must reproduce the above copyright | ||
notice, this list of conditions and the following disclaimer in the | ||
documentation and/or other materials provided with the distribution. | ||
c. Neither the name of the tensorly Developers nor the names of | ||
its contributors may be used to endorse or promote products | ||
derived from this software without specific prior written | ||
permission. | ||
|
||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR | ||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT | ||
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY | ||
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH | ||
DAMAGE. | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from ._khatri_rao import khatri_rao | ||
|
||
|
||
def higher_order_moment(matrix, order=3): | ||
"""Higher order moment of a matrix of observations | ||
Computes the `order`-order moment of `matrix`` | ||
Each row of `matrix` represents a samples | ||
(i.e. an observation) | ||
Parameters | ||
---------- | ||
matrix : 2D-array | ||
array of shape (n_samples, n_features) | ||
i.e. each row is a sample | ||
order : int, optional | ||
order of the moment to compute | ||
""" | ||
matrix = matrix - matrix.mean(axis=0) | ||
n_features = matrix.shape[-1] | ||
t = khatri_rao([matrix.T] * order).mean(axis=1) | ||
return t.reshape([n_features] * order) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import numpy as np | ||
from scipy.linalg import svd | ||
|
||
# Author: Jean Kossaifi | ||
|
||
|
||
def soft_thresholding(tensor, threshold): | ||
"""Soft thresholding operator | ||
sign(tensor)*max(abs(tensor) - threshold, 0) | ||
Parameters | ||
---------- | ||
tensor : ndarray | ||
threshold : float or ndarray with shape tensor.shape | ||
* If float the threshold is applied to the whole tensor | ||
* If ndarray, one theshold is applied per elements, 0 values are ignored | ||
Returns | ||
------- | ||
ndarray | ||
thresholded tensor on which the operator has been applied | ||
Examples | ||
-------- | ||
Basic shrinkage | ||
>>> import numpy as np | ||
>>> from tensorlib.linalg import soft_thresholding | ||
>>> tensor = np.array([[1, -2, 1.5], [-4, 3, -0.5]]) | ||
>>> soft_thresholding(tensor, 1.1) | ||
array([[ 0. , -0.9, 0.4], | ||
[-2.9, 1.9, 0. ]]) | ||
Example with missing values | ||
>>> mask = np.array([[0, 0, 1], [1, 0, 1]]) | ||
>>> soft_thresholding(tensor, mask*1.1) | ||
array([[ 1. , -2. , 0.4], | ||
[-2.9, 3. , 0. ]]) | ||
""" | ||
signs = np.sign(tensor) | ||
values = (signs*tensor - threshold) | ||
return np.where(values > 0, signs*values, 0) | ||
|
||
|
||
def inplace_soft_thresholding(tensor, threshold): | ||
"""Inplace version of the shrinkage operator | ||
Parameters | ||
---------- | ||
tensor : ndarray | ||
threshold : float | ||
Returns | ||
------- | ||
ndarray | ||
tensor on which the operator has been applied inplace | ||
Notes | ||
----- | ||
This version is memory efficient. | ||
For a faster but less memory efficient version, you can use this function: | ||
>>> def soft_thresholding(tensor, threshold): | ||
... return np.maximum(0, tensor - threshold) - np.maximum(0, -tensor - threshold) | ||
""" | ||
index_shrink = ((tensor <= threshold) & (tensor >= -threshold)) | ||
index_more = (tensor > threshold) | ||
index_less = (tensor < -threshold) | ||
tensor[index_shrink] = 0 | ||
tensor[index_more] -= threshold | ||
tensor[index_less] += threshold | ||
return tensor | ||
|
||
|
||
def svd_thresholing(matrix, threshold): | ||
"""Singular value thresholding operator | ||
Parameters | ||
---------- | ||
matrix : ndarray | ||
threshold : float | ||
Returns | ||
------- | ||
ndarray | ||
matrix on which the operator has been applied | ||
""" | ||
U, s, V = svd(matrix, full_matrices=False) | ||
return np.dot(U, soft_thresholding(s, threshold)[:, None]*V) | ||
|
||
|
||
def procrustes(matrix): | ||
"""Procrustes operator | ||
Parameters | ||
---------- | ||
matrix : ndarray | ||
Returns | ||
------- | ||
ndarray | ||
matrix on which the Procrustes operator has been applied | ||
has the same shape as the original tensor | ||
""" | ||
U, _, V = svd(matrix, full_matrices=False) | ||
return np.dot(U, V) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import numpy as np | ||
from numpy.testing import assert_array_almost_equal | ||
from ..higher_order_moment import higher_order_moment | ||
|
||
|
||
def test_higher_order_moment(): | ||
"""Test for higher_order_moment""" | ||
X = np.array([[1, 0],[0, 1], [1, 1], [-1, 1], [1, -1.5]]) | ||
centered_X = X - X.mean(axis=0) | ||
order_2 = centered_X.T.dot(centered_X)/X.shape[0] | ||
order_3 = np.array([[[-0.432, 0.196], | ||
[ 0.196, 0.282]], | ||
[[ 0.196, 0.282], | ||
[ 0.282, -0.966]]]) | ||
order_4 = np.array([[[[ 0.8512, -0.4536], | ||
[-0.4536, 0.4828]], | ||
[[-0.4536, 0.4828], | ||
[ 0.4828, -0.7854]]], | ||
[[[-0.4536, 0.4828], | ||
[ 0.4828, -0.7854]], | ||
[[ 0.4828, -0.7854], | ||
[-0.7854, 2.2452]]]]) | ||
assert_array_almost_equal(higher_order_moment(X, 2), order_2) | ||
assert_array_almost_equal(higher_order_moment(X, 3), order_3) | ||
assert_array_almost_equal(higher_order_moment(X, 4), order_4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import numpy as np | ||
from numpy.testing import assert_array_equal, assert_array_almost_equal | ||
|
||
from ..proximal import svd_thresholing, soft_thresholding, inplace_soft_thresholding | ||
from ..proximal import procrustes | ||
|
||
# Author: Jean Kossaifi | ||
|
||
|
||
def test_soft_thresholding(): | ||
"""Test for shrinkage""" | ||
|
||
# small test | ||
tensor = np.array([[[1, 2, 3], [4.3, -1.2, 3]], | ||
[[0.5, -5, -1.3], [1.2, 3.7, -9]], | ||
[[-2, 0, 1.0], [0.5, -0.5, 1.1]]], dtype=np.float64) | ||
threshold = 1.1 | ||
copy_tensor = np.copy(tensor) | ||
res = soft_thresholding(tensor, threshold) | ||
true_res = np.array([[[0, 0.9, 1.9], [3.2, -0.1, 1.9]], | ||
[[0, -3.9, -0.2], [0.1, 2.6, -7.9]], | ||
[[-0.9, 0, 0], [0, -0, 0]]], dtype=np.float64) | ||
# account for floating point errors: np array have a precision of around 2e-15 | ||
# check np.finfo(np.float64).eps | ||
assert_array_almost_equal(true_res, res, decimal=15) | ||
# Check that we did not change the original tensor | ||
assert_array_equal(copy_tensor, tensor) | ||
|
||
# Another test | ||
tensor = np.array([[1, 2, 1.5], [4, -6, -0.5], [0.2, 1.02, -3.4]]) | ||
copy_tensor = np.copy(tensor) | ||
threshold = 1.1 | ||
true_res = np.array([[0, 0.9, 0.4], [2.9, -4.9, 0], [0, 0, -2.3]]) | ||
res = soft_thresholding(tensor, threshold) | ||
assert_array_almost_equal(true_res, res, decimal=15) | ||
assert_array_equal(copy_tensor, tensor) | ||
|
||
# Test with missing values | ||
tensor = np.array([[1, 2, 1.5], [4, -6, -0.5], [0.2, 1.02, -3.4]]) | ||
copy_tensor = np.copy(tensor) | ||
mask = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]) | ||
threshold = 1.1*mask | ||
true_res = np.array([[1, 0.9, 0.4], [2.9, -6, 0], [0, 0, -3.4]]) | ||
res = soft_thresholding(tensor, threshold) | ||
assert_array_almost_equal(true_res, res, decimal=15) | ||
assert_array_equal(copy_tensor, tensor) | ||
|
||
|
||
def test_inplace_soft_thresholding(): | ||
"""Test for inplace_shrinkage | ||
Notes | ||
----- | ||
Assumes that shrinkage is tested and works as expected | ||
""" | ||
shape = (4, 5, 3, 2) | ||
tensor = np.random.random(shape) | ||
threshold = 0.21 | ||
true_res = soft_thresholding(tensor, threshold) | ||
res = inplace_soft_thresholding(tensor, threshold) | ||
assert_array_almost_equal(true_res, res) | ||
|
||
# Check that the soft thresholding was done inplace | ||
assert (res is tensor) | ||
|
||
|
||
def test_svd_thresholing(): | ||
"""Test for singular_value_thresholding operator""" | ||
U = np.array([[1, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 1]]) | ||
singular_values = [0.4, 2.1, -2] | ||
tensor = U.dot(np.diag(singular_values).dot(U.T)) | ||
shrinked_singular_values = [0, 1.6, -1.5] | ||
true_res = U.dot(np.diag(shrinked_singular_values).dot(U.T)) | ||
res = svd_thresholing(tensor, 0.5) | ||
assert_array_almost_equal(true_res, res) | ||
|
||
|
||
def test_procrustes(): | ||
"""Test for procrustes operator""" | ||
U = np.random.rand(20, 10) | ||
S, _, V = np.linalg.svd(U, full_matrices=False) | ||
true_res = S.dot(V) | ||
res = procrustes(U) | ||
assert_array_almost_equal(true_res, res) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
""" | ||
tensorly version | ||
""" | ||
|
||
__version__ = '0.1.0' |