Skip to content

Commit

Permalink
Merge pull request #363 from jakirkham/add_gen_dict_n_comps
Browse files: browse the repository at this point in the history
Provide optional number of components argument for dictionary learning
  • Loading branch information
jakirkham committed Feb 18, 2016
2 parents e983110 + 344d41e commit cb75f10
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
31 changes: 30 additions & 1 deletion nanshe/imp/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,10 @@ def preprocess_data(new_data, out=None, **parameters):

@prof.log_call(trace_logger)
@hdf5.record.static_array_debug_recorder
def generate_dictionary(new_data, initial_dictionary=None, **parameters):
def generate_dictionary(new_data,
initial_dictionary=None,
n_components=None,
**parameters):
"""
Generates a dictionary using the data and parameters given for trainDL.
Expand All @@ -769,6 +772,9 @@ def generate_dictionary(new_data, initial_dictionary=None, **parameters):
initial_dictionary(numpy.ndarray): dictionary to start the
algorithm with.
n_components(int): number of components for the
dictionary to use.
**parameters(dict): passed directly to
spams.trainDL.
Expand All @@ -778,6 +784,29 @@ def generate_dictionary(new_data, initial_dictionary=None, **parameters):

import nanshe.box

# Sync the number of components with the method.
if n_components is None:
if "spams.trainDL" in parameters:
n_components = parameters["spams.trainDL"]["K"]
elif "sklearn.decomposition.dict_learning_online" in parameters:
n_components = parameters["sklearn.decomposition.dict_learning_online"]["n_components"]
else:
assert False, "Unknown algorithm must define `n_components`."
else:
if "spams.trainDL" in parameters:
assert parameters["spams.trainDL"].get("K", n_components) == n_components,\
"If `n_components` and `spams.trainDL[\"K\"]` are defined," \
" they should be defined the same."
parameters["spams.trainDL"]["K"] = n_components
elif "sklearn.decomposition.dict_learning_online" in parameters:
assert parameters["sklearn.decomposition.dict_learning_online"].get("n_components", n_components) == n_components,\
"If `n_components` and " \
"`sklearn.decomposition.dict_learning_online[\"n_components\"]`" \
" are defined, they should be defined the same."
parameters["sklearn.decomposition.dict_learning_online"]["n_components"] = n_components
else:
assert False, "Unknown algorithm cannot use `n_components`."

# Needs to be floating point.
# However, it need not be double precision, as there is a single-precision
# function signature.
Expand Down
12 changes: 6 additions & 6 deletions tests/test_nanshe/test_imp/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,12 +1042,12 @@ def test_generate_dictionary_04(self):
d = nanshe.imp.segment.generate_dictionary(
g.astype(numpy.float32),
g.astype(numpy.float32),
len(g),
**{
"spams.trainDL" : {
"gamma2" : 0,
"gamma1" : 0,
"numThreads" : 1,
"K" : len(g),
"iter" : 10,
"modeD" : 0,
"posAlpha" : True,
Expand Down Expand Up @@ -1098,12 +1098,12 @@ def test_generate_dictionary_05(self):
d = nanshe.imp.segment.generate_dictionary(
g.astype(float),
g.astype(float),
len(g),
**{
"spams.trainDL" : {
"gamma2" : 0,
"gamma1" : 0,
"numThreads" : 1,
"K" : len(g),
"iter" : 10,
"modeD" : 0,
"posAlpha" : True,
Expand Down Expand Up @@ -1155,12 +1155,12 @@ def test_generate_dictionary_06(self):
d = nanshe.imp.segment.generate_dictionary(
g.astype(numpy.float32),
g.astype(numpy.float32),
len(g),
**{
"spams.trainDL" : {
"gamma2" : 0,
"gamma1" : 0,
"numThreads" : 1,
"K" : len(g),
"iter" : 10,
"modeD" : 0,
"posAlpha" : True,
Expand Down Expand Up @@ -1212,12 +1212,12 @@ def test_generate_dictionary_07(self):
d = nanshe.imp.segment.generate_dictionary(
g.astype(float),
g.astype(float),
len(g),
**{
"spams.trainDL" : {
"gamma2" : 0,
"gamma1" : 0,
"numThreads" : 1,
"K" : len(g),
"iter" : 10,
"modeD" : 0,
"posAlpha" : True,
Expand Down Expand Up @@ -1359,10 +1359,10 @@ def test_generate_dictionary_10(self):
d = nanshe.imp.segment.generate_dictionary(
g.astype(float),
g.astype(float),
len(g),
**{
"sklearn.decomposition.dict_learning_online" : {
"n_jobs" : 1,
"n_components" : len(g),
"n_iter" : 20,
"batch_size" : 256,
"alpha" : 0.2
Expand Down Expand Up @@ -1408,10 +1408,10 @@ def test_generate_dictionary_11(self):
d = nanshe.imp.segment.generate_dictionary(
g.astype(float),
g.astype(float),
len(g),
**{
"sklearn.decomposition.dict_learning_online" : {
"n_jobs" : 1,
"n_components" : len(g),
"n_iter" : 20,
"batch_size" : 256,
"alpha" : 0.2
Expand Down

0 comments on commit cb75f10

Please sign in to comment.