diff --git a/.github/workflows/check.sh b/.github/workflows/check.sh
index c2c1a64..8958f05 100755
--- a/.github/workflows/check.sh
+++ b/.github/workflows/check.sh
@@ -13,5 +13,5 @@
 
 unzip *.zip
 RUSTDOCFLAGS="-D warnings" cargo doc --release
-cargo install cargo-llvm-cov --locked
+cargo install cargo-llvm-cov@0.6.15 --locked
 SKIP_TRAINING=1 cargo llvm-cov --release
diff --git a/Cargo.lock b/Cargo.lock
index 5c23bb2..0f6a99f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1077,7 +1077,7 @@ dependencies = [
 
 [[package]]
 name = "fsrs"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "burn",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index 5e57c46..8a2cc04 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "2.0.1"
+version = "2.0.2"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
diff --git a/src/training.rs b/src/training.rs
index 1491c88..2b0f066 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -25,6 +25,11 @@
 use log::info;
 use std::sync::{Arc, Mutex};
 
+static PARAMS_STDDEV: [f32; 19] = [
+    6.61, 9.52, 17.69, 27.74, 0.55, 0.28, 0.67, 0.12, 0.4, 0.18, 0.34, 0.27, 0.08, 0.14, 0.57,
+    0.25, 1.03, 0.27, 0.39,
+];
+
 pub struct BCELoss<B: Backend> {
     backend: PhantomData<B>,
 }
@@ -70,6 +75,21 @@ impl<B: Backend> Model<B> {
         let retention = self.power_forgetting_curve(delta_ts, state.stability);
         BCELoss::new().forward(retention, labels.float(), weights, reduce)
     }
+
+    pub(crate) fn l2_regularization(
+        &self,
+        init_w: Tensor<B, 1>,
+        params_stddev: Tensor<B, 1>,
+        batch_size: usize,
+        total_size: usize,
+        gamma: f64,
+    ) -> Tensor<B, 1> {
+        (self.w.val() - init_w)
+            .powi_scalar(2)
+            .div(params_stddev.powi_scalar(2))
+            .sum()
+            .mul_scalar(gamma * batch_size as f64 / total_size as f64)
+    }
 }
 
 impl<B: AutodiffBackend> Model<B> {
@@ -189,6 +209,8 @@ pub(crate) struct TrainingConfig {
     pub learning_rate: f64,
     #[config(default = 64)]
     pub max_seq_len: usize,
+    #[config(default = 1.0)]
+    pub gamma: f64,
 }
 
 pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
@@ -345,7 +367,8 @@ fn train<B: AutodiffBackend>(
     B::seed(config.seed);
 
     // Training data
-    let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs;
+    let total_size = train_set.len();
+    let iterations = (total_size / config.batch_size + 1) * config.num_epochs;
     let batch_dataset = BatchTensorDataset::<B>::new(
         FSRSDataset::from(train_set),
         config.batch_size,
@@ -356,7 +379,7 @@ fn train<B: AutodiffBackend>(
     let batch_dataset = BatchTensorDataset::<B>::new(
         FSRSDataset::from(test_set.clone()),
         config.batch_size,
-        device,
+        device.clone(),
     );
     let dataloader_valid = ShuffleDataLoader::new(batch_dataset, config.seed);
 
@@ -371,6 +394,8 @@ fn train<B: AutodiffBackend>(
     };
 
     let mut model: Model<B> = config.model.init();
+    let init_w = model.w.val();
+    let params_stddev = Tensor::from_floats(PARAMS_STDDEV, &device);
     let mut optim = config.optimizer.init::<B, Model<B>>();
     let mut best_loss = f64::INFINITY;
 
@@ -380,8 +405,16 @@ fn train<B: AutodiffBackend>(
         let mut iteration = 0;
         while let Some(item) = iterator.next() {
             iteration += 1;
+            let real_batch_size = item.delta_ts.shape().dims[0];
             let lr = LrScheduler::<B>::step(&mut lr_scheduler);
             let progress = iterator.progress();
+            let penalty = model.l2_regularization(
+                init_w.clone(),
+                params_stddev.clone(),
+                real_batch_size,
+                total_size,
+                config.gamma,
+            );
             let loss = model.forward_classification(
                 item.t_historys,
                 item.r_historys,
@@ -388,9 +421,9 @@ fn train<B: AutodiffBackend>(
                 item.delta_ts,
                 item.labels,
                 item.weights,
                 Reduction::Sum,
             );
-            let mut gradients = loss.backward();
+            let mut gradients = (loss + penalty).backward();
             if model.config.freeze_initial_stability {
                 gradients = model.freeze_initial_stability(gradients);
             }
@@ -420,6 +453,14 @@ fn train<B: AutodiffBackend>(
         let model_valid = model.valid();
         let mut loss_valid = 0.0;
         for batch in dataloader_valid.iter() {
+            let real_batch_size = batch.delta_ts.shape().dims[0];
+            let penalty = model_valid.l2_regularization(
+                init_w.valid(),
+                params_stddev.valid(),
+                real_batch_size,
+                total_size,
+                config.gamma,
+            );
             let loss = model_valid.forward_classification(
                 batch.t_historys,
                 batch.r_historys,
@@ -429,7 +470,8 @@ fn train<B: AutodiffBackend>(
                 Reduction::Sum,
             );
             let loss = loss.into_data().convert::<f64>().value[0];
-            loss_valid += loss;
+            let penalty = penalty.into_data().convert::<f64>().value[0];
+            loss_valid += loss + penalty;
             if interrupter.should_stop() {
                 break;
             }
@@ -494,6 +536,8 @@ mod tests {
        let device = NdArrayDevice::Cpu;
        type B = Autodiff<NdArray<f32>>;
        let mut model: Model<B> = config.init();
+       let init_w = model.w.val();
+       let params_stddev = Tensor::from_floats(PARAMS_STDDEV, &device);
 
        let item = FSRSBatch {
            t_historys: Tensor::from_floats(
@@ -563,6 +607,38 @@ mod tests {
            ])
        );
 
+       let penalty =
+           model.l2_regularization(init_w.clone(), params_stddev.clone(), 512, 1000, 2.0);
+       assert_eq!(
+           penalty.clone().into_data().convert::<f32>().value[0],
+           0.64689976
+       );
+
+       let gradients = penalty.backward();
+       let w_grad = model.w.grad(&gradients).unwrap();
+       Data::from([
+           0.0018749383,
+           0.00090389,
+           0.00026177685,
+           -0.00010645759,
+           0.27080965,
+           -1.0448978,
+           -0.18249036,
+           5.688889,
+           -0.5119995,
+           2.528395,
+           -0.7086509,
+           1.1237301,
+           -12.799997,
+           4.179591,
+           0.25213587,
+           1.3107198,
+           -0.07721739,
+           -1.1237309,
+           -0.5385926,
+       ])
+       .assert_approx_eq(&w_grad.clone().into_data(), 5);
+
        let item = FSRSBatch {
            t_historys: Tensor::from_floats(
                Data::from([
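
Note on the new penalty term: Model::l2_regularization computes

    gamma * (batch_size / total_size) * sum_i((w[i] - init_w[i])^2 / PARAMS_STDDEV[i]^2)

i.e. a quadratic penalty that pulls each trained parameter back toward its initial value, with each parameter's drift normalized by its expected spread. Below is a minimal standalone sketch of the same arithmetic using plain slices instead of burn tensors; the free function l2_penalty is illustrative only, and the tensor method in src/training.rs above is the actual implementation.

    /// Plain-Rust sketch of the penalty: square each parameter's drift from
    /// its initial value, normalize by that parameter's stddev, sum, then
    /// scale by gamma and the batch's share of the training set.
    fn l2_penalty(
        w: &[f32],
        init_w: &[f32],
        stddev: &[f32],
        batch_size: usize,
        total_size: usize,
        gamma: f64,
    ) -> f64 {
        let sum: f64 = w
            .iter()
            .zip(init_w)
            .zip(stddev)
            .map(|((w, w0), sd)| ((w - w0).powi(2) / sd.powi(2)) as f64)
            .sum();
        sum * gamma * batch_size as f64 / total_size as f64
    }

Because the batch sizes within one epoch sum to total_size, the real_batch_size / total_size factor makes the penalty accumulated over an epoch equal to gamma times the normalized squared drift, independent of how the data is batched.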