Feat/L2 Regularization (#270)
L-M-Sherlock authored Jan 18, 2025
1 parent 79c105e commit f11c30d
Showing 4 changed files with 83 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check.sh
@@ -13,5 +13,5 @@ unzip *.zip

RUSTDOCFLAGS="-D warnings" cargo doc --release

-cargo install cargo-llvm-cov --locked
+cargo install cargo-llvm-cov@0.6.15 --locked
SKIP_TRAINING=1 cargo llvm-cov --release
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "2.0.1"
version = "2.0.2"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
84 changes: 80 additions & 4 deletions src/training.rs
@@ -25,6 +25,11 @@ use log::info;

use std::sync::{Arc, Mutex};

static PARAMS_STDDEV: [f32; 19] = [
6.61, 9.52, 17.69, 27.74, 0.55, 0.28, 0.67, 0.12, 0.4, 0.18, 0.34, 0.27, 0.08, 0.14, 0.57,
0.25, 1.03, 0.27, 0.39,
];

pub struct BCELoss<B: Backend> {
backend: PhantomData<B>,
}
@@ -70,6 +75,21 @@ impl<B: Backend> Model<B> {
let retention = self.power_forgetting_curve(delta_ts, state.stability);
BCELoss::new().forward(retention, labels.float(), weights, reduce)
}

pub(crate) fn l2_regularization(
&self,
init_w: Tensor<B, 1>,
params_stddev: Tensor<B, 1>,
batch_size: usize,
total_size: usize,
gamma: f64,
) -> Tensor<B, 1> {
(self.w.val() - init_w)
.powi_scalar(2)
.div(params_stddev.powi_scalar(2))
.sum()
.mul_scalar(gamma * batch_size as f64 / total_size as f64)
}
}

impl<B: AutodiffBackend> Model<B> {
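Note: the penalty computed by l2_regularization above is a weighted squared distance from the initial parameters, with each of the 19 parameters normalized by its entry in PARAMS_STDDEV and the sum scaled by the batch fraction and gamma. In LaTeX form:

\text{penalty} = \gamma \cdot \frac{|B|}{N} \sum_{i=0}^{18} \frac{\left(w_i - w_i^{\text{init}}\right)^2}{\sigma_i^2}

where |B| is batch_size, N is total_size, and \sigma_i is PARAMS_STDDEV[i].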
@@ -189,6 +209,8 @@ pub(crate) struct TrainingConfig {
pub learning_rate: f64,
#[config(default = 64)]
pub max_seq_len: usize,
#[config(default = 1.0)]
pub gamma: f64,
}

pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
@@ -345,7 +367,8 @@ fn train<B: AutodiffBackend>(
B::seed(config.seed);

// Training data
-let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs;
+let total_size = train_set.len();
+let iterations = (total_size / config.batch_size + 1) * config.num_epochs;
let batch_dataset = BatchTensorDataset::<B>::new(
FSRSDataset::from(train_set),
config.batch_size,
@@ -356,7 +379,7 @@
let batch_dataset = BatchTensorDataset::<B::InnerBackend>::new(
FSRSDataset::from(test_set.clone()),
config.batch_size,
-device,
+device.clone(),
);
let dataloader_valid = ShuffleDataLoader::new(batch_dataset, config.seed);

@@ -371,6 +394,8 @@
};

let mut model: Model<B> = config.model.init();
let init_w = model.w.val();
let params_stddev = Tensor::from_floats(PARAMS_STDDEV, &device);
let mut optim = config.optimizer.init::<B, Model<B>>();

let mut best_loss = f64::INFINITY;
@@ -380,8 +405,16 @@
let mut iteration = 0;
while let Some(item) = iterator.next() {
iteration += 1;
let real_batch_size = item.delta_ts.shape().dims[0];
let lr = LrScheduler::<B>::step(&mut lr_scheduler);
let progress = iterator.progress();
let penalty = model.l2_regularization(
init_w.clone(),
params_stddev.clone(),
real_batch_size,
total_size,
config.gamma,
);
let loss = model.forward_classification(
item.t_historys,
item.r_historys,
@@ -390,7 +423,7 @@
item.weights,
Reduction::Sum,
);
-let mut gradients = loss.backward();
+let mut gradients = (loss + penalty).backward();
if model.config.freeze_initial_stability {
gradients = model.freeze_initial_stability(gradients);
}
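Since the BCE loss uses Reduction::Sum, the penalty added here is scaled by real_batch_size / total_size; summed over one epoch those factors add up to gamma, so the overall regularization strength is independent of the batch size. A minimal standalone sketch of that arithmetic, with hypothetical sizes that are not taken from this repository:

fn main() {
    // Hypothetical dataset and batch sizes, for illustration only.
    let total_size = 1000usize;
    let batch_sizes = [512usize, 488]; // each item appears once per epoch
    let gamma = 1.0f64;

    // Per-batch penalty scale factors, mirroring l2_regularization().
    let epoch_factor: f64 = batch_sizes
        .iter()
        .map(|&b| gamma * b as f64 / total_size as f64)
        .sum();

    // Summed over the epoch, the factors recover gamma.
    assert!((epoch_factor - gamma).abs() < 1e-12);
}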
@@ -420,6 +453,14 @@
let model_valid = model.valid();
let mut loss_valid = 0.0;
for batch in dataloader_valid.iter() {
let real_batch_size = batch.delta_ts.shape().dims[0];
let penalty = model_valid.l2_regularization(
init_w.valid(),
params_stddev.valid(),
real_batch_size,
total_size,
config.gamma,
);
let loss = model_valid.forward_classification(
batch.t_historys,
batch.r_historys,
@@ -429,7 +470,8 @@
Reduction::Sum,
);
let loss = loss.into_data().convert::<f64>().value[0];
-loss_valid += loss;
+let penalty = penalty.into_data().convert::<f64>().value[0];
+loss_valid += loss + penalty;

if interrupter.should_stop() {
break;
@@ -494,6 +536,8 @@ mod tests {
let device = NdArrayDevice::Cpu;
type B = Autodiff<NdArray<f32>>;
let mut model: Model<B> = config.init();
let init_w = model.w.val();
let params_stddev = Tensor::from_floats(PARAMS_STDDEV, &device);

let item = FSRSBatch {
t_historys: Tensor::from_floats(
Expand Down Expand Up @@ -563,6 +607,38 @@ mod tests {
])
);

let penalty =
model.l2_regularization(init_w.clone(), params_stddev.clone(), 512, 1000, 2.0);
assert_eq!(
penalty.clone().into_data().convert::<f32>().value[0],
0.64689976
);

let gradients = penalty.backward();
let w_grad = model.w.grad(&gradients).unwrap();
Data::from([
0.0018749383,
0.00090389,
0.00026177685,
-0.00010645759,
0.27080965,
-1.0448978,
-0.18249036,
5.688889,
-0.5119995,
2.528395,
-0.7086509,
1.1237301,
-12.799997,
4.179591,
0.25213587,
1.3107198,
-0.07721739,
-1.1237309,
-0.5385926,
])
.assert_approx_eq(&w_grad.clone().into_data(), 5);
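The expected gradient checked here follows from differentiating the penalty with respect to each parameter, using the test's gamma = 2.0, batch_size = 512, and total_size = 1000:

\frac{\partial\, \text{penalty}}{\partial w_i} = 2\gamma \cdot \frac{|B|}{N} \cdot \frac{w_i - w_i^{\text{init}}}{\sigma_i^2}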

let item = FSRSBatch {
t_historys: Tensor::from_floats(
Data::from([
