Skip to content

Commit

Permalink
Merge branch 'master' into memory_usage
Browse files Browse the repository at this point in the history
  • Loading branch information
zhi-bao authored Jan 3, 2025
2 parents a74d40c + 6d3a3b7 commit 9794ced
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# the repo.
# Reviewers list below will be requested for
# review when someone opens a pull request.
* @cjlin1 @sian-chen @Eleven1Liu @henryyang42 @JamesLYC88 @Gordon119
* @cjlin1 @ntumlgroup/libmultilabel_reviewers
2 changes: 1 addition & 1 deletion docs/cli/flags.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ or directly passed as flags. If an option exists in both the config
file and flags, flags take precedent and override the config file.

The config file is a yaml file, examples may be found in
`example_config <https://github.com/ASUS-AICS/LibMultiLabel/tree/master/example_config>`_.
`example_config <https://github.com/ntumlgroup/LibMultiLabel/tree/master/example_config>`_.
In the config file, each key-value pair ``key: value`` corresponds to
passing the flag ``--key value``. The following example sets the training data path
in the config file
Expand Down
4 changes: 2 additions & 2 deletions docs/cli/ov_data_format.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ and then create a virtual enviroment as follows.
conda create -n LibMultiLabel python=3.8
conda activate LibMultiLabel
* Clone `LibMultiLabel <https://github.com/ASUS-AICS/LibMultiLabel>`_.
* Clone `LibMultiLabel <https://github.com/ntumlgroup/LibMultiLabel>`_.

.. code-block:: bash
git clone https://github.com/ASUS-AICS/LibMultiLabel.git
git clone https://github.com/ntumlgroup/LibMultiLabel.git
cd LibMultiLabel
* Install the default dependencies with:
Expand Down
1 change: 0 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ For practical use, please see the `Tutorials <tutorial.html>`_. For Implementati
library_index
tutorial
Implementation Document <https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/libmultilabel_implementation.pdf>
papers


..
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/Parameter_Selection_for_Neural_Networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Direct Trying Some Parameters
-----------------------------

First, train a BiGRU model with the
`default configuration file <https://github.com/ASUS-AICS/LibMultiLabel/blob/master/example_config/EUR-Lex/bigru_lwan.yml>`_
`default configuration file <https://github.com/ntumlgroup/LibMultiLabel/blob/master/example_config/EUR-Lex/bigru_lwan.yml>`_
with a little modification on the learning rate.
Some important parameters are listed as follows.

Expand Down Expand Up @@ -92,7 +92,7 @@ To save time, LibMultiLabel has incorporated some early stopping techniques impl
Here we demonstrate an example of applying an `ASHA (Asynchronous Successive Halving Algorithm) Scheduler <https://arxiv.org/abs/1810.05934>`_.
First, uncomment the following lines in the
`configuration file <https://github.com/ASUS-AICS/LibMultiLabel/blob/master/example_config/EUR-Lex/bigru_lwan_tune.yml>`_:
`configuration file <https://github.com/ntumlgroup/LibMultiLabel/blob/master/example_config/EUR-Lex/bigru_lwan_tune.yml>`_:
.. code-block:: bash
Expand Down
35 changes: 20 additions & 15 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
"""
self.label_map = label_map
self.children = children
self.is_root = False

def isLeaf(self) -> bool:
return len(self.children) == 0
Expand Down Expand Up @@ -58,7 +59,7 @@ def predict_values(
x: sparse.csr_matrix,
beam_width: int = 10,
) -> np.ndarray:
"""Calculates the decision values associated with x.
"""Calculates the probability estimates associated with x.
Args:
x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
Expand All @@ -72,10 +73,10 @@ def predict_values(
return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])

def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
"""Predict with beam search using cached decision values for a single instance.
"""Predict with beam search using cached probability estimates for a single instance.
Args:
instance_preds (np.ndarray): A vector of cached decision values of each node, has dimension number of labels + total number of metalabels.
instance_preds (np.ndarray): A vector of cached probability estimates of each node, has dimension number of labels + total number of metalabels.
beam_width (int): Number of candidates considered.
Returns:
Expand All @@ -94,18 +95,18 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
continue
slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]]
pred = instance_preds[slice]
children_score = score - np.maximum(0, 1 - pred) ** 2
children_score = score - np.square(np.maximum(0, 1 - pred))
next_level.extend(zip(node.children, children_score.tolist()))

cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
next_level = []

num_labels = len(self.root.label_map)
scores = np.full(num_labels, -np.inf)
scores = np.full(num_labels, 0.0)
for node, score in cur_level:
slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]]
pred = instance_preds[slice]
scores[node.label_map] = np.exp(score - np.maximum(0, 1 - pred) ** 2)
scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
return scores


Expand Down Expand Up @@ -134,6 +135,7 @@ def train_tree(
label_representation = (y.T * x).tocsr()
label_representation = sklearn.preprocessing.normalize(label_representation, norm="l2", axis=1)
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
root.is_root = True

num_nodes = 0
# Both type(x) and type(y) are sparse.csr_matrix
Expand All @@ -149,20 +151,23 @@ def count(node):
root.dfs(count)

model_size = get_estimated_model_size(root)
print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
print(f"The estimated tree model size is: {model_size / (1024**3):.3f} GB")

# Calculate the total memory (excluding swap) on the local machine
total_memory = psutil.virtual_memory().total
print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')
total_memory = psutil.virtual_memory().total
print(f"Your system memory is: {total_memory / (1024**3):.3f} GB")

if (total_memory <= model_size):
raise MemoryError(f'Not enough memory to train the model.')
if total_memory <= model_size:
raise MemoryError(f"Not enough memory to train the model.")

pbar = tqdm(total=num_nodes, disable=not verbose)

def visit(node):
relevant_instances = y[:, node.label_map].getnnz(axis=1) > 0
_train_node(y[relevant_instances], x[relevant_instances], options, node)
if node.is_root:
_train_node(y, x, options, node)
else:
relevant_instances = y[:, node.label_map].getnnz(axis=1) > 0
_train_node(y[relevant_instances], x[relevant_instances], options, node)
pbar.update()

root.dfs(visit)
Expand Down Expand Up @@ -216,7 +221,7 @@ def get_estimated_model_size(root):

def collect_stat(node: Node):
nonlocal total_num_weights

if node.isLeaf():
total_num_weights += len(node.label_map) * node.num_features_used
else:
Expand All @@ -226,7 +231,7 @@ def collect_stat(node: Node):

# 16 is because when storing sparse matrices, indices (int64) require 8 bytes and floats require 8 bytes
# Our study showed that among the used features of every binary classification problem, on average no more than 2/3 of weights obtained by the dual coordinate descent method are non-zeros.
return total_num_weights * 16 * 2/3
return total_num_weights * 16 * 2 / 3


def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ numba
pandas>1.3.0
PyYAML
scikit-learn
scipy
scipy<1.14.0
tqdm
psutil
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = libmultilabel
version = 0.7.1
version = 0.7.2
author = LibMultiLabel Team
license = MIT License
license_file = LICENSE
Expand Down Expand Up @@ -30,7 +30,7 @@ install_requires =
pandas>1.3.0
PyYAML
scikit-learn
scipy
scipy<1.14.0
tqdm

python_requires = >=3.8
Expand Down

0 comments on commit 9794ced

Please sign in to comment.